structure saas with tools
@@ -0,0 +1,5 @@
# Integrations

This folder contains logging integrations for LiteLLM.

e.g. logging to Datadog, Langfuse, Prometheus, S3, GCS Bucket, etc.
@@ -0,0 +1,13 @@
# Slack Alerting on LiteLLM Gateway

This folder contains the Slack Alerting integration for LiteLLM Gateway.

## Folder Structure

- `slack_alerting.py`: The main file; it handles sending the different types of alerts.
- `batching_handler.py`: Handles batching and sending httpx POST requests to Slack. Slack alerts are sent every 10s, or once more than X events are queued, so that LiteLLM keeps good performance under high traffic.
- `types.py`: Contains the `AlertType` enum, which defines the different types of alerts that can be sent to Slack.
- `utils.py`: Common utils used specifically for Slack alerting.

## Further Reading
- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
Binary file not shown.
@@ -0,0 +1,81 @@
"""
Handles Batching + sending Httpx Post requests to slack

Slack alerts are sent every 10s or when events are greater than X events

see custom_batch_logger.py for more details / defaults
"""

from typing import TYPE_CHECKING, Any

from litellm._logging import verbose_proxy_logger

if TYPE_CHECKING:
    from .slack_alerting import SlackAlerting as _SlackAlerting

    SlackAlertingType = _SlackAlerting
else:
    SlackAlertingType = Any


def squash_payloads(queue):
    squashed = {}
    if len(queue) == 0:
        return squashed
    if len(queue) == 1:
        return {"key": {"item": queue[0], "count": 1}}

    for item in queue:
        url = item["url"]
        alert_type = item["alert_type"]
        _key = (url, alert_type)

        if _key in squashed:
            squashed[_key]["count"] += 1
            # Merge the payloads

        else:
            squashed[_key] = {"item": item, "count": 1}

    return squashed


def _print_alerting_payload_warning(
    payload: dict, slackAlertingInstance: SlackAlertingType
):
    """
    Print the payload to the console when
    slackAlertingInstance.alerting_args.log_to_console is True

    Relevant issue: https://github.com/BerriAI/litellm/issues/7372
    """
    if slackAlertingInstance.alerting_args.log_to_console is True:
        verbose_proxy_logger.warning(payload)


async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
    """
    Send a single slack alert to the webhook
    """
    import json

    payload = item.get("payload", {})
    try:
        if count > 1:
            payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"

        response = await slackAlertingInstance.async_http_handler.post(
            url=item["url"],
            headers=item["headers"],
            data=json.dumps(payload),
        )
        if response.status_code != 200:
            verbose_proxy_logger.debug(
                f"Error sending slack alert to url={item['url']}. Error={response.text}"
            )
    except Exception as e:
        verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
    finally:
        _print_alerting_payload_warning(
            payload, slackAlertingInstance=slackAlertingInstance
        )
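A minimal sketch (not part of the commit) of how `squash_payloads` groups queued alerts by `(url, alert_type)`. The import path is assumed from the folder layout, and the webhook URLs and alert-type strings are purely illustrative.

```python
# Illustrative only: shows how squash_payloads collapses duplicate alerts.
from litellm.integrations.SlackAlerting.batching_handler import squash_payloads  # assumed path

queue = [
    {"url": "https://hooks.slack.com/services/T000/B000/AAA", "alert_type": "llm_exceptions", "payload": {"text": "error 1"}},
    {"url": "https://hooks.slack.com/services/T000/B000/AAA", "alert_type": "llm_exceptions", "payload": {"text": "error 2"}},
    {"url": "https://hooks.slack.com/services/T000/B000/BBB", "alert_type": "budget_alerts", "payload": {"text": "over budget"}},
]

squashed = squash_payloads(queue)
# Two keys: one per (url, alert_type) pair; the first entry has count == 2.
for (url, alert_type), entry in squashed.items():
    print(alert_type, entry["count"])
```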
File diff suppressed because it is too large
@@ -0,0 +1,92 @@
"""
Utils used for slack alerting
"""

import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from litellm.proxy._types import AlertType
from litellm.secret_managers.main import get_secret

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _Logging

    Logging = _Logging
else:
    Logging = Any


def process_slack_alerting_variables(
    alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
    """
    Process alert_to_webhook_url

    - check if any urls are set as `os.environ/SLACK_WEBHOOK_URL_1`; if so, read the env var and set the correct value
    """
    if alert_to_webhook_url is None:
        return None

    for alert_type, webhook_urls in alert_to_webhook_url.items():
        if isinstance(webhook_urls, list):
            _webhook_values: List[str] = []
            for webhook_url in webhook_urls:
                if "os.environ/" in webhook_url:
                    _env_value = get_secret(secret_name=webhook_url)
                    if not isinstance(_env_value, str):
                        raise ValueError(
                            f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
                        )
                    _webhook_values.append(_env_value)
                else:
                    _webhook_values.append(webhook_url)

            alert_to_webhook_url[alert_type] = _webhook_values
        else:
            _webhook_value_str: str = webhook_urls
            if "os.environ/" in webhook_urls:
                _env_value = get_secret(secret_name=webhook_urls)
                if not isinstance(_env_value, str):
                    raise ValueError(
                        f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
                    )
                _webhook_value_str = _env_value
            else:
                _webhook_value_str = webhook_urls

            alert_to_webhook_url[alert_type] = _webhook_value_str

    return alert_to_webhook_url


async def _add_langfuse_trace_id_to_alert(
    request_data: Optional[dict] = None,
) -> Optional[str]:
    """
    Returns langfuse trace url

    - check:
    -> existing_trace_id
    -> trace_id
    -> litellm_call_id
    """
    # do nothing for now
    if (
        request_data is not None
        and request_data.get("litellm_logging_obj", None) is not None
    ):
        trace_id: Optional[str] = None
        litellm_logging_obj: Logging = request_data["litellm_logging_obj"]

        for _ in range(3):
            trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
            if trace_id is not None:
                break
            await asyncio.sleep(3)  # wait 3s before retrying for trace id

        _langfuse_object = litellm_logging_obj._get_callback_object(
            service_name="langfuse"
        )
        if _langfuse_object is not None:
            base_url = _langfuse_object.Langfuse.base_url
            return f"{base_url}/trace/{trace_id}"
    return None
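A minimal sketch (not part of the commit) of the `os.environ/` resolution that `process_slack_alerting_variables` performs. The import path is assumed from the folder layout, and the env var name and `AlertType` member are illustrative.

```python
import os

from litellm.proxy._types import AlertType

# Assumed module path based on the folder structure shown above.
from litellm.integrations.SlackAlerting.utils import process_slack_alerting_variables

# Illustrative env var holding the real webhook URL.
os.environ["SLACK_WEBHOOK_URL_1"] = "https://hooks.slack.com/services/T000/B000/AAA"

mapping = {AlertType.llm_exceptions: "os.environ/SLACK_WEBHOOK_URL_1"}
resolved = process_slack_alerting_variables(alert_to_webhook_url=mapping)
# resolved[AlertType.llm_exceptions] now holds the value of SLACK_WEBHOOK_URL_1.
```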
@@ -0,0 +1 @@
|
||||
from . import *
|
||||
Binary file not shown.
@@ -0,0 +1,389 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpanAttributes:
|
||||
OUTPUT_VALUE = "output.value"
|
||||
OUTPUT_MIME_TYPE = "output.mime_type"
|
||||
"""
|
||||
The type of output.value. If unspecified, the type is plain text by default.
|
||||
If type is JSON, the value is a string representing a JSON object.
|
||||
"""
|
||||
INPUT_VALUE = "input.value"
|
||||
INPUT_MIME_TYPE = "input.mime_type"
|
||||
"""
|
||||
The type of input.value. If unspecified, the type is plain text by default.
|
||||
If type is JSON, the value is a string representing a JSON object.
|
||||
"""
|
||||
|
||||
EMBEDDING_EMBEDDINGS = "embedding.embeddings"
|
||||
"""
|
||||
A list of objects containing embedding data, including the vector and represented piece of text.
|
||||
"""
|
||||
EMBEDDING_MODEL_NAME = "embedding.model_name"
|
||||
"""
|
||||
The name of the embedding model.
|
||||
"""
|
||||
|
||||
LLM_FUNCTION_CALL = "llm.function_call"
|
||||
"""
|
||||
For models and APIs that support function calling. Records attributes such as the function
|
||||
name and arguments to the called function.
|
||||
"""
|
||||
LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters"
|
||||
"""
|
||||
Invocation parameters passed to the LLM or API, such as the model name, temperature, etc.
|
||||
"""
|
||||
LLM_INPUT_MESSAGES = "llm.input_messages"
|
||||
"""
|
||||
Messages provided to a chat API.
|
||||
"""
|
||||
LLM_OUTPUT_MESSAGES = "llm.output_messages"
|
||||
"""
|
||||
Messages received from a chat API.
|
||||
"""
|
||||
LLM_MODEL_NAME = "llm.model_name"
|
||||
"""
|
||||
The name of the model being used.
|
||||
"""
|
||||
LLM_PROVIDER = "llm.provider"
|
||||
"""
|
||||
The provider of the model, such as OpenAI, Azure, Google, etc.
|
||||
"""
|
||||
LLM_SYSTEM = "llm.system"
|
||||
"""
|
||||
The AI product as identified by the client or server
|
||||
"""
|
||||
LLM_PROMPTS = "llm.prompts"
|
||||
"""
|
||||
Prompts provided to a completions API.
|
||||
"""
|
||||
LLM_PROMPT_TEMPLATE = "llm.prompt_template.template"
|
||||
"""
|
||||
The prompt template as a Python f-string.
|
||||
"""
|
||||
LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"
|
||||
"""
|
||||
A list of input variables to the prompt template.
|
||||
"""
|
||||
LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version"
|
||||
"""
|
||||
The version of the prompt template being used.
|
||||
"""
|
||||
LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
|
||||
"""
|
||||
Number of tokens in the prompt.
|
||||
"""
|
||||
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
|
||||
"""
|
||||
Number of tokens in the prompt that were written to cache.
|
||||
"""
|
||||
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
|
||||
"""
|
||||
Number of tokens in the prompt that were read from cache.
|
||||
"""
|
||||
LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
|
||||
"""
|
||||
The number of audio input tokens presented in the prompt
|
||||
"""
|
||||
LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
|
||||
"""
|
||||
Number of tokens in the completion.
|
||||
"""
|
||||
LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
|
||||
"""
|
||||
Number of tokens used for reasoning steps in the completion.
|
||||
"""
|
||||
LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
|
||||
"""
|
||||
The number of audio tokens generated by the model
|
||||
"""
|
||||
LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
|
||||
"""
|
||||
Total number of tokens, including both prompt and completion.
|
||||
"""
|
||||
|
||||
LLM_TOOLS = "llm.tools"
|
||||
"""
|
||||
List of tools that are advertised to the LLM to be able to call
|
||||
"""
|
||||
|
||||
TOOL_NAME = "tool.name"
|
||||
"""
|
||||
Name of the tool being used.
|
||||
"""
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
"""
|
||||
Description of the tool's purpose, typically used to select the tool.
|
||||
"""
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
"""
|
||||
Parameters of the tool represented as a dictionary JSON string, e.g.
|
||||
see https://platform.openai.com/docs/guides/gpt/function-calling
|
||||
"""
|
||||
|
||||
RETRIEVAL_DOCUMENTS = "retrieval.documents"
|
||||
|
||||
METADATA = "metadata"
|
||||
"""
|
||||
Metadata attributes are used to store user-defined key-value pairs.
|
||||
For example, LangChain uses metadata to store user-defined attributes for a chain.
|
||||
"""
|
||||
|
||||
TAG_TAGS = "tag.tags"
|
||||
"""
|
||||
Custom categorical tags for the span.
|
||||
"""
|
||||
|
||||
OPENINFERENCE_SPAN_KIND = "openinference.span.kind"
|
||||
|
||||
SESSION_ID = "session.id"
|
||||
"""
|
||||
The id of the session
|
||||
"""
|
||||
USER_ID = "user.id"
|
||||
"""
|
||||
The id of the user
|
||||
"""
|
||||
|
||||
PROMPT_VENDOR = "prompt.vendor"
|
||||
"""
|
||||
The vendor or origin of the prompt, e.g. a prompt library, a specialized service, etc.
|
||||
"""
|
||||
PROMPT_ID = "prompt.id"
|
||||
"""
|
||||
A vendor-specific id used to locate the prompt.
|
||||
"""
|
||||
PROMPT_URL = "prompt.url"
|
||||
"""
|
||||
A vendor-specific url used to locate the prompt.
|
||||
"""
|
||||
|
||||
|
||||
class MessageAttributes:
|
||||
"""
|
||||
Attributes for a message sent to or from an LLM
|
||||
"""
|
||||
|
||||
MESSAGE_ROLE = "message.role"
|
||||
"""
|
||||
The role of the message, such as "user", "agent", "function".
|
||||
"""
|
||||
MESSAGE_CONTENT = "message.content"
|
||||
"""
|
||||
The content of the message to or from the llm, must be a string.
|
||||
"""
|
||||
MESSAGE_CONTENTS = "message.contents"
|
||||
"""
|
||||
The message contents to the llm, it is an array of
|
||||
`message_content` prefixed attributes.
|
||||
"""
|
||||
MESSAGE_NAME = "message.name"
|
||||
"""
|
||||
The name of the message, often used to identify the function
|
||||
that was used to generate the message.
|
||||
"""
|
||||
MESSAGE_TOOL_CALLS = "message.tool_calls"
|
||||
"""
|
||||
The tool calls generated by the model, such as function calls.
|
||||
"""
|
||||
MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name"
|
||||
"""
|
||||
The function name that is a part of the message list.
|
||||
This is populated for role 'function' or 'agent' as a mechanism to identify
|
||||
the function that was called during the execution of a tool.
|
||||
"""
|
||||
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json"
|
||||
"""
|
||||
The JSON string representing the arguments passed to the function
|
||||
during a function call.
|
||||
"""
|
||||
MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
|
||||
"""
|
||||
The id of the tool call.
|
||||
"""
|
||||
|
||||
|
||||
class MessageContentAttributes:
|
||||
"""
|
||||
Attributes for the contents of user messages sent to an LLM.
|
||||
"""
|
||||
|
||||
MESSAGE_CONTENT_TYPE = "message_content.type"
|
||||
"""
|
||||
The type of the content, such as "text" or "image".
|
||||
"""
|
||||
MESSAGE_CONTENT_TEXT = "message_content.text"
|
||||
"""
|
||||
The text content of the message, if the type is "text".
|
||||
"""
|
||||
MESSAGE_CONTENT_IMAGE = "message_content.image"
|
||||
"""
|
||||
The image content of the message, if the type is "image".
|
||||
An image can be made available to the model by passing a link to
|
||||
the image or by passing the base64 encoded image directly in the
|
||||
request.
|
||||
"""
|
||||
|
||||
|
||||
class ImageAttributes:
|
||||
"""
|
||||
Attributes for images
|
||||
"""
|
||||
|
||||
IMAGE_URL = "image.url"
|
||||
"""
|
||||
An http or base64 image url
|
||||
"""
|
||||
|
||||
|
||||
class AudioAttributes:
|
||||
"""
|
||||
Attributes for audio
|
||||
"""
|
||||
|
||||
AUDIO_URL = "audio.url"
|
||||
"""
|
||||
The url to an audio file
|
||||
"""
|
||||
AUDIO_MIME_TYPE = "audio.mime_type"
|
||||
"""
|
||||
The mime type of the audio file
|
||||
"""
|
||||
AUDIO_TRANSCRIPT = "audio.transcript"
|
||||
"""
|
||||
The transcript of the audio file
|
||||
"""
|
||||
|
||||
|
||||
class DocumentAttributes:
|
||||
"""
|
||||
Attributes for a document.
|
||||
"""
|
||||
|
||||
DOCUMENT_ID = "document.id"
|
||||
"""
|
||||
The id of the document.
|
||||
"""
|
||||
DOCUMENT_SCORE = "document.score"
|
||||
"""
|
||||
The score of the document
|
||||
"""
|
||||
DOCUMENT_CONTENT = "document.content"
|
||||
"""
|
||||
The content of the document.
|
||||
"""
|
||||
DOCUMENT_METADATA = "document.metadata"
|
||||
"""
|
||||
The metadata of the document represented as a dictionary
|
||||
JSON string, e.g. `"{ 'title': 'foo' }"`
|
||||
"""
|
||||
|
||||
|
||||
class RerankerAttributes:
|
||||
"""
|
||||
Attributes for a reranker
|
||||
"""
|
||||
|
||||
RERANKER_INPUT_DOCUMENTS = "reranker.input_documents"
|
||||
"""
|
||||
List of documents as input to the reranker
|
||||
"""
|
||||
RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents"
|
||||
"""
|
||||
List of documents as output from the reranker
|
||||
"""
|
||||
RERANKER_QUERY = "reranker.query"
|
||||
"""
|
||||
Query string for the reranker
|
||||
"""
|
||||
RERANKER_MODEL_NAME = "reranker.model_name"
|
||||
"""
|
||||
Model name of the reranker
|
||||
"""
|
||||
RERANKER_TOP_K = "reranker.top_k"
|
||||
"""
|
||||
Top K parameter of the reranker
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingAttributes:
|
||||
"""
|
||||
Attributes for an embedding
|
||||
"""
|
||||
|
||||
EMBEDDING_TEXT = "embedding.text"
|
||||
"""
|
||||
The text represented by the embedding.
|
||||
"""
|
||||
EMBEDDING_VECTOR = "embedding.vector"
|
||||
"""
|
||||
The embedding vector.
|
||||
"""
|
||||
|
||||
|
||||
class ToolCallAttributes:
|
||||
"""
|
||||
Attributes for a tool call
|
||||
"""
|
||||
|
||||
TOOL_CALL_ID = "tool_call.id"
|
||||
"""
|
||||
The id of the tool call.
|
||||
"""
|
||||
TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
|
||||
"""
|
||||
The name of function that is being called during a tool call.
|
||||
"""
|
||||
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
|
||||
"""
|
||||
The JSON string representing the arguments passed to the function
|
||||
during a tool call.
|
||||
"""
|
||||
|
||||
|
||||
class ToolAttributes:
|
||||
"""
|
||||
Attributes for tools
|
||||
"""
|
||||
|
||||
TOOL_JSON_SCHEMA = "tool.json_schema"
|
||||
"""
|
||||
The JSON schema of a tool input. It is RECOMMENDED that this be in the
|
||||
OpenAI tool calling format: https://platform.openai.com/docs/assistants/tools
|
||||
"""
|
||||
|
||||
|
||||
class OpenInferenceSpanKindValues(Enum):
|
||||
TOOL = "TOOL"
|
||||
CHAIN = "CHAIN"
|
||||
LLM = "LLM"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
EMBEDDING = "EMBEDDING"
|
||||
AGENT = "AGENT"
|
||||
RERANKER = "RERANKER"
|
||||
UNKNOWN = "UNKNOWN"
|
||||
GUARDRAIL = "GUARDRAIL"
|
||||
EVALUATOR = "EVALUATOR"
|
||||
|
||||
|
||||
class OpenInferenceMimeTypeValues(Enum):
|
||||
TEXT = "text/plain"
|
||||
JSON = "application/json"
|
||||
|
||||
|
||||
class OpenInferenceLLMSystemValues(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
COHERE = "cohere"
|
||||
MISTRALAI = "mistralai"
|
||||
VERTEXAI = "vertexai"
|
||||
|
||||
|
||||
class OpenInferenceLLMProviderValues(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
COHERE = "cohere"
|
||||
MISTRALAI = "mistralai"
|
||||
GOOGLE = "google"
|
||||
AZURE = "azure"
|
||||
AWS = "aws"
|
||||
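A minimal sketch (not part of the commit) of how these OpenInference constants are typically attached to an OpenTelemetry span; the model name and message content are placeholders, and the `opentelemetry-api` package is assumed to be installed. The import path matches the one used later in this commit.

```python
from opentelemetry import trace

from litellm.integrations._types.open_inference import (
    MessageAttributes,
    OpenInferenceSpanKindValues,
    SpanAttributes,
)

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("llm_call") as span:
    # Mark the span as an LLM span and record basic request metadata.
    span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.LLM.value)
    span.set_attribute(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")  # placeholder model name
    # Per-message attributes are flattened with an index prefix.
    prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.0"
    span.set_attribute(f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", "user")
    span.set_attribute(f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}", "Hello!")
```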
@@ -0,0 +1,36 @@
"""
Base class for Additional Logging Utils for CustomLoggers

- Health Check for the logging util
- Get Request / Response Payload for the logging util
"""

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional

from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus


class AdditionalLoggingUtils(ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """
        Check if the service is healthy
        """
        pass

    @abstractmethod
    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`
        """
        return None
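A minimal sketch (not part of the commit) of a concrete subclass. The import path and the fields passed to `IntegrationHealthCheckStatus` are assumptions and may differ from the actual type.

```python
from datetime import datetime
from typing import Optional

from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils  # assumed path
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus


class MyLoggerUtils(AdditionalLoggingUtils):
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        # Field names here are an assumption about IntegrationHealthCheckStatus.
        return IntegrationHealthCheckStatus(status="healthy", error_message=None)

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        # Look the request up in whatever store the integration writes to.
        return {"request_id": request_id}
```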
@@ -0,0 +1,3 @@
|
||||
from .agentops import AgentOps
|
||||
|
||||
__all__ = ["AgentOps"]
|
||||
Binary file not shown.
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
AgentOps integration for LiteLLM - Provides OpenTelemetry tracing for LLM calls
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
|
||||
@dataclass
|
||||
class AgentOpsConfig:
|
||||
endpoint: str = "https://otlp.agentops.cloud/v1/traces"
|
||||
api_key: Optional[str] = None
|
||||
service_name: Optional[str] = None
|
||||
deployment_environment: Optional[str] = None
|
||||
auth_endpoint: str = "https://api.agentops.ai/v3/auth/token"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls):
|
||||
return cls(
|
||||
endpoint="https://otlp.agentops.cloud/v1/traces",
|
||||
api_key=os.getenv("AGENTOPS_API_KEY"),
|
||||
service_name=os.getenv("AGENTOPS_SERVICE_NAME", "agentops"),
|
||||
deployment_environment=os.getenv("AGENTOPS_ENVIRONMENT", "production"),
|
||||
auth_endpoint="https://api.agentops.ai/v3/auth/token"
|
||||
)
|
||||
|
||||
class AgentOps(OpenTelemetry):
|
||||
"""
|
||||
AgentOps integration - built on top of OpenTelemetry
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
import litellm
|
||||
|
||||
litellm.success_callback = ["agentops"]
|
||||
|
||||
response = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
```
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[AgentOpsConfig] = None,
|
||||
):
|
||||
if config is None:
|
||||
config = AgentOpsConfig.from_env()
|
||||
|
||||
# Prefetch JWT token for authentication
|
||||
jwt_token = None
|
||||
project_id = None
|
||||
if config.api_key:
|
||||
try:
|
||||
response = self._fetch_auth_token(config.api_key, config.auth_endpoint)
|
||||
jwt_token = response.get("token")
|
||||
project_id = response.get("project_id")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
headers = f"Authorization=Bearer {jwt_token}" if jwt_token else None
|
||||
|
||||
otel_config = OpenTelemetryConfig(
|
||||
exporter="otlp_http",
|
||||
endpoint=config.endpoint,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Initialize OpenTelemetry with our config
|
||||
super().__init__(
|
||||
config=otel_config,
|
||||
callback_name="agentops"
|
||||
)
|
||||
|
||||
# Set AgentOps-specific resource attributes
|
||||
resource_attrs = {
|
||||
"service.name": config.service_name or "litellm",
|
||||
"deployment.environment": config.deployment_environment or "production",
|
||||
"telemetry.sdk.name": "agentops",
|
||||
}
|
||||
|
||||
if project_id:
|
||||
resource_attrs["project.id"] = project_id
|
||||
|
||||
self.resource_attributes = resource_attrs
|
||||
|
||||
def _fetch_auth_token(self, api_key: str, auth_endpoint: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch JWT authentication token from AgentOps API
|
||||
|
||||
Args:
|
||||
api_key: AgentOps API key
|
||||
auth_endpoint: Authentication endpoint
|
||||
|
||||
Returns:
|
||||
Dict containing JWT token and project ID
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(
|
||||
url=auth_endpoint,
|
||||
headers=headers,
|
||||
json={"api_key": api_key},
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to fetch auth token: {response.text}")
|
||||
|
||||
return response.json()
|
||||
finally:
|
||||
client.close()
|
||||
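A minimal sketch (not part of the commit) of wiring the integration up with an explicit `AgentOpsConfig` instead of environment variables. The import path is assumed from the package layout, and the API key and service name are placeholders.

```python
from litellm.integrations.agentops.agentops import AgentOps, AgentOpsConfig  # assumed path

config = AgentOpsConfig(
    api_key="placeholder-agentops-key",   # placeholder key
    service_name="my-litellm-proxy",
    deployment_environment="staging",
)

# Constructing the logger fetches a JWT (if the key is valid) and configures
# the OTLP HTTP exporter against the AgentOps traces endpoint.
agentops_logger = AgentOps(config=config)
```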
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
This hook is used to inject cache control directives into the messages of a chat completion.
|
||||
|
||||
Users can define
|
||||
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.types.integrations.anthropic_cache_control_hook import (
|
||||
CacheControlInjectionPoint,
|
||||
CacheControlMessageInjectionPoint,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class AnthropicCacheControlHook(CustomPromptManagement):
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Apply cache control directives based on specified injection points.
|
||||
|
||||
Returns:
|
||||
- model: str - the model to use
|
||||
- messages: List[AllMessageValues] - messages with applied cache controls
|
||||
- non_default_params: dict - params with any global cache controls
|
||||
"""
|
||||
# Extract cache control injection points
|
||||
injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
|
||||
"cache_control_injection_points", []
|
||||
)
|
||||
if not injection_points:
|
||||
return model, messages, non_default_params
|
||||
|
||||
# Create a deep copy of messages to avoid modifying the original list
|
||||
processed_messages = copy.deepcopy(messages)
|
||||
|
||||
# Process message-level cache controls
|
||||
for point in injection_points:
|
||||
if point.get("location") == "message":
|
||||
point = cast(CacheControlMessageInjectionPoint, point)
|
||||
processed_messages = self._process_message_injection(
|
||||
point=point, messages=processed_messages
|
||||
)
|
||||
|
||||
return model, processed_messages, non_default_params
|
||||
|
||||
@staticmethod
|
||||
def _process_message_injection(
|
||||
point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
"""Process message-level cache control injection."""
|
||||
control: ChatCompletionCachedContent = point.get(
|
||||
"control", None
|
||||
) or ChatCompletionCachedContent(type="ephemeral")
|
||||
|
||||
_targetted_index: Optional[Union[int, str]] = point.get("index", None)
|
||||
targetted_index: Optional[int] = None
|
||||
if isinstance(_targetted_index, str):
|
||||
if _targetted_index.isdigit():
|
||||
targetted_index = int(_targetted_index)
|
||||
else:
|
||||
targetted_index = _targetted_index
|
||||
|
||||
targetted_role = point.get("role", None)
|
||||
|
||||
# Case 1: Target by specific index
|
||||
if targetted_index is not None:
|
||||
if 0 <= targetted_index < len(messages):
|
||||
messages[targetted_index] = (
|
||||
AnthropicCacheControlHook._safe_insert_cache_control_in_message(
|
||||
messages[targetted_index], control
|
||||
)
|
||||
)
|
||||
# Case 2: Target by role
|
||||
elif targetted_role is not None:
|
||||
for msg in messages:
|
||||
if msg.get("role") == targetted_role:
|
||||
msg = (
|
||||
AnthropicCacheControlHook._safe_insert_cache_control_in_message(
|
||||
message=msg, control=control
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _safe_insert_cache_control_in_message(
|
||||
message: AllMessageValues, control: ChatCompletionCachedContent
|
||||
) -> AllMessageValues:
|
||||
"""
|
||||
Safe way to insert cache control in a message
|
||||
|
||||
OpenAI Message content can be either:
|
||||
- string
|
||||
- list of objects
|
||||
|
||||
This method handles inserting cache control in both cases.
|
||||
"""
|
||||
message_content = message.get("content", None)
|
||||
|
||||
# 1. if string, insert cache control in the message
|
||||
if isinstance(message_content, str):
|
||||
message["cache_control"] = control # type: ignore
|
||||
# 2. list of objects
|
||||
elif isinstance(message_content, list):
|
||||
for content_item in message_content:
|
||||
if isinstance(content_item, dict):
|
||||
content_item["cache_control"] = control # type: ignore
|
||||
return message
|
||||
|
||||
@property
|
||||
def integration_name(self) -> str:
|
||||
"""Return the integration name for this hook."""
|
||||
return "anthropic_cache_control_hook"
|
||||
|
||||
@staticmethod
|
||||
def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
|
||||
if non_default_params.get("cache_control_injection_points", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_custom_logger_for_anthropic_cache_control_hook(
|
||||
non_default_params: Dict,
|
||||
) -> Optional[CustomLogger]:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
_init_custom_logger_compatible_class,
|
||||
)
|
||||
|
||||
if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
|
||||
non_default_params
|
||||
):
|
||||
return _init_custom_logger_compatible_class(
|
||||
logging_integration="anthropic_cache_control_hook",
|
||||
internal_usage_cache=None,
|
||||
llm_router=None,
|
||||
)
|
||||
return None
|
||||
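A minimal sketch (not part of the commit) of passing `cache_control_injection_points` through `litellm.completion`, as described in the module docstring above; the model name and prompts are placeholders.

```python
import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # placeholder model name
    messages=[
        {"role": "system", "content": "You are a helpful assistant with a very long system prompt."},
        {"role": "user", "content": "Summarize our last conversation."},
    ],
    # Inject an ephemeral cache_control block into every system message;
    # omitting "control" falls back to {"type": "ephemeral"} per the hook above.
    cache_control_injection_points=[
        {"location": "message", "role": "system"},
    ],
)
```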
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Send logs to Argilla for annotation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.argilla import (
|
||||
SUPPORTED_PAYLOAD_FIELDS,
|
||||
ArgillaCredentialsObject,
|
||||
ArgillaItem,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
non_serializable_types = (
|
||||
types.CoroutineType,
|
||||
types.FunctionType,
|
||||
types.GeneratorType,
|
||||
BaseModel,
|
||||
)
|
||||
return not isinstance(value, non_serializable_types)
|
||||
|
||||
|
||||
class ArgillaLogger(CustomBatchLogger):
|
||||
def __init__(
|
||||
self,
|
||||
argilla_api_key: Optional[str] = None,
|
||||
argilla_dataset_name: Optional[str] = None,
|
||||
argilla_base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if litellm.argilla_transformation_object is None:
|
||||
raise Exception(
|
||||
"'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
|
||||
)
|
||||
self.validate_argilla_transformation_object(
|
||||
litellm.argilla_transformation_object
|
||||
)
|
||||
self.argilla_transformation_object = litellm.argilla_transformation_object
|
||||
self.default_credentials = self.get_credentials_from_env(
|
||||
argilla_api_key=argilla_api_key,
|
||||
argilla_dataset_name=argilla_dataset_name,
|
||||
argilla_base_url=argilla_base_url,
|
||||
)
|
||||
self.sampling_rate: float = (
|
||||
float(os.getenv("ARGILLA_SAMPLING_RATE")) # type: ignore
|
||||
if os.getenv("ARGILLA_SAMPLING_RATE") is not None
|
||||
and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit() # type: ignore
|
||||
else 1.0
|
||||
)
|
||||
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
_batch_size = (
|
||||
os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
|
||||
)
|
||||
if _batch_size:
|
||||
self.batch_size = int(_batch_size)
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
||||
|
||||
def validate_argilla_transformation_object(
|
||||
self, argilla_transformation_object: Dict[str, Any]
|
||||
):
|
||||
if not isinstance(argilla_transformation_object, dict):
|
||||
raise Exception(
|
||||
"'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
|
||||
)
|
||||
|
||||
for v in argilla_transformation_object.values():
|
||||
if v not in SUPPORTED_PAYLOAD_FIELDS:
|
||||
raise Exception(
|
||||
f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
|
||||
)
|
||||
|
||||
def get_credentials_from_env(
|
||||
self,
|
||||
argilla_api_key: Optional[str],
|
||||
argilla_dataset_name: Optional[str],
|
||||
argilla_base_url: Optional[str],
|
||||
) -> ArgillaCredentialsObject:
|
||||
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
|
||||
if _credentials_api_key is None:
|
||||
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
|
||||
|
||||
_credentials_base_url = (
|
||||
argilla_base_url
|
||||
or os.getenv("ARGILLA_BASE_URL")
|
||||
or "http://localhost:6900/"
|
||||
)
|
||||
if _credentials_base_url is None:
|
||||
raise Exception(
|
||||
"Invalid Argilla Base URL given. _credentials_base_url=None."
|
||||
)
|
||||
|
||||
_credentials_dataset_name = (
|
||||
argilla_dataset_name
|
||||
or os.getenv("ARGILLA_DATASET_NAME")
|
||||
or "litellm-completion"
|
||||
)
|
||||
if _credentials_dataset_name is None:
|
||||
raise Exception("Invalid Argilla Dataset give. Value=None.")
|
||||
else:
|
||||
dataset_response = litellm.module_level_client.get(
|
||||
url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
|
||||
headers={"X-Argilla-Api-Key": _credentials_api_key},
|
||||
)
|
||||
json_response = dataset_response.json()
|
||||
if (
|
||||
"items" in json_response
|
||||
and isinstance(json_response["items"], list)
|
||||
and len(json_response["items"]) > 0
|
||||
):
|
||||
_credentials_dataset_name = json_response["items"][0]["id"]
|
||||
|
||||
return ArgillaCredentialsObject(
|
||||
ARGILLA_API_KEY=_credentials_api_key,
|
||||
ARGILLA_BASE_URL=_credentials_base_url,
|
||||
ARGILLA_DATASET_NAME=_credentials_dataset_name,
|
||||
)
|
||||
|
||||
def get_chat_messages(
|
||||
self, payload: StandardLoggingPayload
|
||||
) -> List[Dict[str, Any]]:
|
||||
payload_messages = payload.get("messages", None)
|
||||
|
||||
if payload_messages is None:
|
||||
raise Exception("No chat messages found in payload.")
|
||||
|
||||
if (
|
||||
isinstance(payload_messages, list)
|
||||
and len(payload_messages) > 0
|
||||
and isinstance(payload_messages[0], dict)
|
||||
):
|
||||
return payload_messages
|
||||
elif isinstance(payload_messages, dict):
|
||||
return [payload_messages]
|
||||
else:
|
||||
raise Exception(f"Invalid chat messages format: {payload_messages}")
|
||||
|
||||
def get_str_response(self, payload: StandardLoggingPayload) -> str:
|
||||
response = payload["response"]
|
||||
|
||||
if response is None:
|
||||
raise Exception("No response found in payload.")
|
||||
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
elif isinstance(response, dict):
|
||||
return (
|
||||
response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Invalid response format: {response}")
|
||||
|
||||
def _prepare_log_data(
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
) -> Optional[ArgillaItem]:
|
||||
try:
|
||||
# Ensure everything in the payload is converted to str
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if payload is None:
|
||||
raise Exception("Error logging request payload. Payload=none.")
|
||||
|
||||
argilla_message = self.get_chat_messages(payload)
|
||||
argilla_response = self.get_str_response(payload)
|
||||
argilla_item: ArgillaItem = {"fields": {}}
|
||||
for k, v in self.argilla_transformation_object.items():
|
||||
if v == "messages":
|
||||
argilla_item["fields"][k] = argilla_message
|
||||
elif v == "response":
|
||||
argilla_item["fields"][k] = argilla_response
|
||||
else:
|
||||
argilla_item["fields"][k] = payload.get(v, None)
|
||||
|
||||
return argilla_item
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _send_batch(self):
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
|
||||
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
|
||||
|
||||
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
|
||||
|
||||
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
|
||||
|
||||
headers = {"X-Argilla-Api-Key": argilla_api_key}
|
||||
|
||||
try:
|
||||
response = litellm.module_level_client.post(
|
||||
url=url,
|
||||
json=self.log_queue,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
verbose_logger.error(
|
||||
f"Argilla Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Batch of {len(self.log_queue)} runs successfully created"
|
||||
)
|
||||
|
||||
self.log_queue.clear()
|
||||
except Exception:
|
||||
verbose_logger.exception("Argilla Layer Error - Error sending batch.")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = self.sampling_rate  # computed from ARGILLA_SAMPLING_RATE in __init__
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.debug(
|
||||
"Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
if data is None:
|
||||
return
|
||||
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
|
||||
)
|
||||
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
self._send_batch()
|
||||
|
||||
except Exception:
|
||||
verbose_logger.exception("Langsmith Layer Error - log_success_event error")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = self.sampling_rate
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.debug(
|
||||
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
|
||||
## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger):
|
||||
try:
|
||||
if data is None:
|
||||
break
|
||||
data = await callback.async_dataset_hook(data, payload)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
if data is None:
|
||||
return
|
||||
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Argilla Layer Error - error logging async success event."
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
sampling_rate = self.sampling_rate
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.info("Langsmith Failure Event Logging!")
|
||||
try:
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Langsmith Layer Error - error logging async failure event."
|
||||
)
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
sends runs to /batch endpoint
|
||||
|
||||
Sends runs from self.log_queue
|
||||
|
||||
Returns: None
|
||||
|
||||
Raises: Does not raise an exception, will only verbose_logger.exception()
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
|
||||
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
|
||||
|
||||
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
|
||||
|
||||
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
|
||||
|
||||
headers = {"X-Argilla-Api-Key": argilla_api_key}
|
||||
|
||||
try:
|
||||
response = await self.async_httpx_client.put(
|
||||
url=url,
|
||||
data=json.dumps(
|
||||
{
|
||||
"items": self.log_queue,
|
||||
}
|
||||
),
|
||||
headers=headers,
|
||||
timeout=60000,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code >= 300:
|
||||
verbose_logger.error(
|
||||
f"Argilla Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
"Batch of %s runs successfully created", len(self.log_queue)
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
verbose_logger.exception("Argilla HTTP Error")
|
||||
except Exception:
|
||||
verbose_logger.exception("Argilla Layer Error")
|
||||
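A minimal sketch (not part of the commit) of configuring the Argilla logger. The dataset field names (`user_input`, `llm_output`) are illustrative, the env var values are placeholders mirroring what `get_credentials_from_env` reads, and the exact callback registration string is an assumption based on the LiteLLM docs.

```python
import os
import litellm

# Credentials read by ArgillaLogger.get_credentials_from_env (placeholder values).
os.environ["ARGILLA_API_KEY"] = "argilla.apikey"
os.environ["ARGILLA_BASE_URL"] = "http://localhost:6900"
os.environ["ARGILLA_DATASET_NAME"] = "litellm-completion"

# Maps Argilla dataset fields -> supported payload fields ("messages", "response", ...).
litellm.argilla_transformation_object = {
    "user_input": "messages",
    "llm_output": "response",
}

litellm.callbacks = ["argilla"]  # assumed registration name

response = litellm.completion(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
)
```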
Binary file not shown.
@@ -0,0 +1,287 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
def cast_as_primitive_value_type(value) -> Union[str, bool, int, float]:
|
||||
"""
|
||||
Converts a value to an OTEL-supported primitive for Arize/Phoenix observability.
|
||||
"""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, (str, bool, int, float)):
|
||||
return value
|
||||
try:
|
||||
return str(value)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def safe_set_attribute(span: Span, key: str, value: Any):
|
||||
"""
|
||||
Sets a span attribute safely with OTEL-compliant primitive typing for Arize/Phoenix.
|
||||
"""
|
||||
primitive_value = cast_as_primitive_value_type(value)
|
||||
span.set_attribute(key, primitive_value)
|
||||
|
||||
|
||||
def set_attributes(span: Span, kwargs, response_obj): # noqa: PLR0915
|
||||
"""
|
||||
Populates span with OpenInference-compliant LLM attributes for Arize and Phoenix tracing.
|
||||
"""
|
||||
from litellm.integrations._types.open_inference import (
|
||||
MessageAttributes,
|
||||
OpenInferenceSpanKindValues,
|
||||
SpanAttributes,
|
||||
ToolCallAttributes,
|
||||
)
|
||||
|
||||
try:
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
|
||||
#############################################
|
||||
############ LLM CALL METADATA ##############
|
||||
#############################################
|
||||
|
||||
# Set custom metadata for observability and trace enrichment.
|
||||
metadata = (
|
||||
standard_logging_payload.get("metadata")
|
||||
if standard_logging_payload
|
||||
else None
|
||||
)
|
||||
if metadata is not None:
|
||||
safe_set_attribute(span, SpanAttributes.METADATA, safe_dumps(metadata))
|
||||
|
||||
#############################################
|
||||
########## LLM Request Attributes ###########
|
||||
#############################################
|
||||
|
||||
# The name of the LLM a request is being made to.
|
||||
if kwargs.get("model"):
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.LLM_MODEL_NAME,
|
||||
kwargs.get("model"),
|
||||
)
|
||||
|
||||
# The LLM request type.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.request.type",
|
||||
standard_logging_payload["call_type"],
|
||||
)
|
||||
|
||||
# The Generative AI Provider: Azure, OpenAI, etc.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.LLM_PROVIDER,
|
||||
litellm_params.get("custom_llm_provider", "Unknown"),
|
||||
)
|
||||
|
||||
# The maximum number of tokens the LLM generates for a request.
|
||||
if optional_params.get("max_tokens"):
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.request.max_tokens",
|
||||
optional_params.get("max_tokens"),
|
||||
)
|
||||
|
||||
# The temperature setting for the LLM request.
|
||||
if optional_params.get("temperature"):
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.request.temperature",
|
||||
optional_params.get("temperature"),
|
||||
)
|
||||
|
||||
# The top_p sampling setting for the LLM request.
|
||||
if optional_params.get("top_p"):
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.request.top_p",
|
||||
optional_params.get("top_p"),
|
||||
)
|
||||
|
||||
# Indicates whether response is streamed.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.is_streaming",
|
||||
str(optional_params.get("stream", False)),
|
||||
)
|
||||
|
||||
# Logs the user ID if present.
|
||||
if optional_params.get("user"):
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.user",
|
||||
optional_params.get("user"),
|
||||
)
|
||||
|
||||
# The unique identifier for the completion.
|
||||
if response_obj and response_obj.get("id"):
|
||||
safe_set_attribute(span, "llm.response.id", response_obj.get("id"))
|
||||
|
||||
# The model used to generate the response.
|
||||
if response_obj and response_obj.get("model"):
|
||||
safe_set_attribute(
|
||||
span,
|
||||
"llm.response.model",
|
||||
response_obj.get("model"),
|
||||
)
|
||||
|
||||
# Required by OpenInference to mark span as LLM kind.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND,
|
||||
OpenInferenceSpanKindValues.LLM.value,
|
||||
)
|
||||
messages = kwargs.get("messages")
|
||||
|
||||
# for /chat/completions
|
||||
# https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.INPUT_VALUE,
|
||||
last_message.get("content", ""),
|
||||
)
|
||||
|
||||
# LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page.
|
||||
for idx, msg in enumerate(messages):
|
||||
prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}"
|
||||
# Set the role per message.
|
||||
safe_set_attribute(
|
||||
span, f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", msg.get("role")
|
||||
)
|
||||
# Set the content per message.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
|
||||
msg.get("content", ""),
|
||||
)
|
||||
|
||||
# Capture tools (function definitions) used in the LLM call.
|
||||
tools = optional_params.get("tools")
|
||||
if tools:
|
||||
for idx, tool in enumerate(tools):
|
||||
function = tool.get("function")
|
||||
if not function:
|
||||
continue
|
||||
prefix = f"{SpanAttributes.LLM_TOOLS}.{idx}"
|
||||
safe_set_attribute(
|
||||
span, f"{prefix}.{SpanAttributes.TOOL_NAME}", function.get("name")
|
||||
)
|
||||
safe_set_attribute(
|
||||
span,
|
||||
f"{prefix}.{SpanAttributes.TOOL_DESCRIPTION}",
|
||||
function.get("description"),
|
||||
)
|
||||
safe_set_attribute(
|
||||
span,
|
||||
f"{prefix}.{SpanAttributes.TOOL_PARAMETERS}",
|
||||
json.dumps(function.get("parameters")),
|
||||
)
|
||||
|
||||
# Capture tool calls made during function-calling LLM flows.
|
||||
functions = optional_params.get("functions")
|
||||
if functions:
|
||||
for idx, function in enumerate(functions):
|
||||
prefix = f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{idx}"
|
||||
safe_set_attribute(
|
||||
span,
|
||||
f"{prefix}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
|
||||
function.get("name"),
|
||||
)
|
||||
|
||||
# Capture invocation parameters and user ID if available.
|
||||
model_params = (
|
||||
standard_logging_payload.get("model_parameters")
|
||||
if standard_logging_payload
|
||||
else None
|
||||
)
|
||||
if model_params:
|
||||
# The Generative AI Provider: Azure, OpenAI, etc.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.LLM_INVOCATION_PARAMETERS,
|
||||
safe_dumps(model_params),
|
||||
)
|
||||
|
||||
if model_params.get("user"):
|
||||
user_id = model_params.get("user")
|
||||
if user_id is not None:
|
||||
safe_set_attribute(span, SpanAttributes.USER_ID, user_id)
|
||||
|
||||
#############################################
|
||||
########## LLM Response Attributes ##########
|
||||
#############################################
|
||||
|
||||
# Captures response tokens, message, and content.
|
||||
if hasattr(response_obj, "get"):
|
||||
for idx, choice in enumerate(response_obj.get("choices", [])):
|
||||
response_message = choice.get("message", {})
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.OUTPUT_VALUE,
|
||||
response_message.get("content", ""),
|
||||
)
|
||||
|
||||
# This shows up under `output_messages` tab on the span page.
|
||||
prefix = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}"
|
||||
safe_set_attribute(
|
||||
span,
|
||||
f"{prefix}.{MessageAttributes.MESSAGE_ROLE}",
|
||||
response_message.get("role"),
|
||||
)
|
||||
safe_set_attribute(
|
||||
span,
|
||||
f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
|
||||
response_message.get("content", ""),
|
||||
)
|
||||
|
||||
# Token usage info.
|
||||
usage = response_obj and response_obj.get("usage")
|
||||
if usage:
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
|
||||
usage.get("total_tokens"),
|
||||
)
|
||||
|
||||
# The number of tokens used in the LLM response (completion).
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
|
||||
usage.get("completion_tokens"),
|
||||
)
|
||||
|
||||
# The number of tokens used in the LLM prompt.
|
||||
safe_set_attribute(
|
||||
span,
|
||||
SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
|
||||
usage.get("prompt_tokens"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}"
|
||||
)
|
||||
if hasattr(span, "record_exception"):
|
||||
span.record_exception(e)
|
||||
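A minimal sketch (not part of the commit) showing the primitive-coercion behaviour of the helpers above; the span here is a simple stand-in object rather than a real OTEL span.

```python
from litellm.integrations.arize._utils import safe_set_attribute


class FakeSpan:
    """Stand-in for an OTEL span; records attributes in a dict."""

    def __init__(self):
        self.attributes = {}

    def set_attribute(self, key, value):
        self.attributes[key] = value


span = FakeSpan()
safe_set_attribute(span, "llm.request.max_tokens", 256)          # kept as int
safe_set_attribute(span, "llm.request.metadata", {"team": "a"})  # coerced to str
safe_set_attribute(span, "llm.request.user", None)               # coerced to ""

print(span.attributes)
```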
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Arize AI is OTEL compatible

This file has Arize AI specific helper functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm.integrations.arize import _utils
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
from litellm.types.integrations.arize import ArizeConfig
|
||||
from litellm.types.services import ServiceLoggerPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
Protocol = _Protocol
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Protocol = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class ArizeLogger(OpenTelemetry):
|
||||
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
|
||||
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def set_arize_attributes(span: Span, kwargs, response_obj):
|
||||
_utils.set_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_arize_config() -> ArizeConfig:
|
||||
"""
|
||||
Helper function to get Arize configuration.
|
||||
|
||||
Returns:
|
||||
ArizeConfig: A Pydantic model containing Arize configuration.
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are not set.
|
||||
"""
|
||||
space_key = os.environ.get("ARIZE_SPACE_KEY")
|
||||
api_key = os.environ.get("ARIZE_API_KEY")
|
||||
|
||||
grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
|
||||
http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
|
||||
|
||||
endpoint = None
|
||||
protocol: Protocol = "otlp_grpc"
|
||||
|
||||
if grpc_endpoint:
|
||||
protocol = "otlp_grpc"
|
||||
endpoint = grpc_endpoint
|
||||
elif http_endpoint:
|
||||
protocol = "otlp_http"
|
||||
endpoint = http_endpoint
|
||||
else:
|
||||
protocol = "otlp_grpc"
|
||||
endpoint = "https://otlp.arize.com/v1"
|
||||
|
||||
return ArizeConfig(
|
||||
space_key=space_key,
|
||||
api_key=api_key,
|
||||
protocol=protocol,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
async def async_service_success_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[datetime, float]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
|
||||
pass
|
||||
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
error: Optional[str] = "",
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[float, datetime]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
|
||||
pass
|
||||
|
||||
def create_litellm_proxy_request_started_span(
|
||||
self,
|
||||
start_time: datetime,
|
||||
headers: dict,
|
||||
):
|
||||
"""Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
|
||||
pass
|
||||
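A minimal sketch (not part of the commit) of the endpoint/protocol selection performed by `get_arize_config`; the env values are placeholders and the import path is assumed from the package layout.

```python
import os

from litellm.integrations.arize.arize import ArizeLogger  # assumed module path

os.environ["ARIZE_SPACE_KEY"] = "my-space-key"  # placeholder
os.environ["ARIZE_API_KEY"] = "my-api-key"      # placeholder

# With neither ARIZE_ENDPOINT nor ARIZE_HTTP_ENDPOINT set, the config falls
# back to gRPC against https://otlp.arize.com/v1.
config = ArizeLogger.get_arize_config()
print(config.protocol, config.endpoint)
```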
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.arize import _utils
|
||||
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
|
||||
|
||||
Protocol = _Protocol
|
||||
OpenTelemetryConfig = _OpenTelemetryConfig
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Protocol = Any
|
||||
OpenTelemetryConfig = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"
|
||||
|
||||
|
||||
class ArizePhoenixLogger:
|
||||
@staticmethod
|
||||
def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
|
||||
_utils.set_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_arize_phoenix_config() -> ArizePhoenixConfig:
|
||||
"""
|
||||
Retrieves the Arize Phoenix configuration based on environment variables.
|
||||
|
||||
Returns:
|
||||
ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration.
|
||||
"""
|
||||
api_key = os.environ.get("PHOENIX_API_KEY", None)
|
||||
grpc_endpoint = os.environ.get("PHOENIX_COLLECTOR_ENDPOINT", None)
|
||||
http_endpoint = os.environ.get("PHOENIX_COLLECTOR_HTTP_ENDPOINT", None)
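        # Endpoint/protocol selection below: unlike the Arize config above, Phoenix prefers the
        # HTTP endpoint when both are set and defaults to the hosted Phoenix collector over OTLP/HTTP.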
|
||||
|
||||
endpoint = None
|
||||
protocol: Protocol = "otlp_http"
|
||||
|
||||
if http_endpoint:
|
||||
endpoint = http_endpoint
|
||||
protocol = "otlp_http"
|
||||
elif grpc_endpoint:
|
||||
endpoint = grpc_endpoint
|
||||
protocol = "otlp_grpc"
|
||||
else:
|
||||
endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT
|
||||
protocol = "otlp_http"
|
||||
verbose_logger.debug(
|
||||
f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}"
|
||||
)
|
||||
|
||||
otlp_auth_headers = None
|
||||
        # If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header, since it currently uses
|
||||
# a slightly different auth header format than self hosted phoenix
|
||||
if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used."
|
||||
)
|
||||
otlp_auth_headers = f"api_key={api_key}"
|
||||
elif api_key is not None:
|
||||
# api_key/auth is optional for self hosted phoenix
|
||||
otlp_auth_headers = f"Authorization=Bearer {api_key}"
|
||||
|
||||
return ArizePhoenixConfig(
|
||||
otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
import datetime
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class AthinaLogger:
|
||||
def __init__(self):
|
||||
import os
|
||||
|
||||
self.athina_api_key = os.getenv("ATHINA_API_KEY")
|
||||
self.headers = {
|
||||
"athina-api-key": self.athina_api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.athina_logging_url = (
|
||||
os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
|
||||
+ "/api/v1/log/inference"
|
||||
)
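        # Metadata keys that, when present in litellm_params.metadata, are copied onto the Athina log payload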
|
||||
self.additional_keys = [
|
||||
"environment",
|
||||
"prompt_slug",
|
||||
"customer_id",
|
||||
"customer_user_id",
|
||||
"session_id",
|
||||
"external_reference_id",
|
||||
"context",
|
||||
"expected_response",
|
||||
"user_query",
|
||||
"tags",
|
||||
"user_feedback",
|
||||
"model_options",
|
||||
"custom_attributes",
|
||||
]
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
import json
|
||||
import traceback
|
||||
|
||||
try:
|
||||
is_stream = kwargs.get("stream", False)
|
||||
if is_stream:
|
||||
if "complete_streaming_response" in kwargs:
|
||||
# Log the completion response in streaming mode
|
||||
completion_response = kwargs["complete_streaming_response"]
|
||||
response_json = (
|
||||
completion_response.model_dump() if completion_response else {}
|
||||
)
|
||||
else:
|
||||
# Skip logging if the completion response is not available
|
||||
return
|
||||
else:
|
||||
# Log the completion response in non streaming mode
|
||||
response_json = response_obj.model_dump() if response_obj else {}
|
||||
data = {
|
||||
"language_model_id": kwargs.get("model"),
|
||||
"request": kwargs,
|
||||
"response": response_json,
|
||||
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
|
||||
"completion_tokens": response_json.get("usage", {}).get(
|
||||
"completion_tokens"
|
||||
),
|
||||
"total_tokens": response_json.get("usage", {}).get("total_tokens"),
|
||||
}
|
||||
|
||||
if (
|
||||
type(end_time) is datetime.datetime
|
||||
and type(start_time) is datetime.datetime
|
||||
):
|
||||
data["response_time"] = int(
|
||||
(end_time - start_time).total_seconds() * 1000
|
||||
)
|
||||
|
||||
if "messages" in kwargs:
|
||||
data["prompt"] = kwargs.get("messages", None)
|
||||
|
||||
# Directly add tools or functions if present
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
data.update(
|
||||
(k, v)
|
||||
for k, v in optional_params.items()
|
||||
if k in ["tools", "functions"]
|
||||
)
|
||||
|
||||
# Add additional metadata keys
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
if metadata:
|
||||
for key in self.additional_keys:
|
||||
if key in metadata:
|
||||
data[key] = metadata[key]
|
||||
response = litellm.module_level_client.post(
|
||||
self.athina_logging_url,
|
||||
headers=self.headers,
|
||||
data=json.dumps(data, default=str),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print_verbose(
|
||||
f"Athina Logger Error - {response.text}, {response.status_code}"
|
||||
)
|
||||
else:
|
||||
print_verbose(f"Athina Logger Succeeded - {response.text}")
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
Binary file not shown.
@@ -0,0 +1,400 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import _DEFAULT_TTL_FOR_HTTPX_CLIENTS, AZURE_STORAGE_MSFT_VERSION
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.azure.common_utils import get_azure_ad_token_from_entra_id
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class AzureBlobStorageLogger(CustomBatchLogger):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"AzureBlobStorageLogger: in init azure blob storage logger"
|
||||
)
|
||||
|
||||
# Env Variables used for Azure Storage Authentication
|
||||
self.tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID")
|
||||
self.client_id = os.getenv("AZURE_STORAGE_CLIENT_ID")
|
||||
self.client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET")
|
||||
self.azure_storage_account_key: Optional[str] = os.getenv(
|
||||
"AZURE_STORAGE_ACCOUNT_KEY"
|
||||
)
|
||||
|
||||
# Required Env Variables for Azure Storage
|
||||
_azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
|
||||
if not _azure_storage_account_name:
|
||||
raise ValueError(
|
||||
"Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME"
|
||||
)
|
||||
self.azure_storage_account_name: str = _azure_storage_account_name
|
||||
_azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM")
|
||||
if not _azure_storage_file_system:
|
||||
raise ValueError(
|
||||
"Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM"
|
||||
)
|
||||
self.azure_storage_file_system: str = _azure_storage_file_system
|
||||
self._service_client = None
|
||||
            # Time at which the Azure service client expires, so the connection pool can be reset and kept fresh
|
||||
self._service_client_timeout: Optional[float] = None
|
||||
|
||||
# Internal variables used for Token based authentication
|
||||
self.azure_auth_token: Optional[str] = (
|
||||
None # the Azure AD token to use for Azure Storage API requests
|
||||
)
|
||||
self.token_expiry: Optional[datetime] = (
|
||||
                None  # the expiry time of the current Azure AD token
|
||||
)
|
||||
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
self.log_queue: List[StandardLoggingPayload] = []
|
||||
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Async Log success events to Azure Blob Storage
|
||||
|
||||
Raises:
|
||||
            Does not raise; errors are logged (non-blocking) via verbose_logger.exception
|
||||
"""
|
||||
try:
|
||||
self._premium_user_check()
|
||||
verbose_logger.debug(
|
||||
"AzureBlobStorageLogger: Logging - Enters logging function for model %s",
|
||||
kwargs,
|
||||
)
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_payload is not set")
|
||||
|
||||
self.log_queue.append(standard_logging_payload)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Async Log failure events to Azure Blob Storage
|
||||
|
||||
Raises:
|
||||
            Does not raise; errors are logged (non-blocking) via verbose_logger.exception
|
||||
"""
|
||||
try:
|
||||
self._premium_user_check()
|
||||
verbose_logger.debug(
|
||||
"AzureBlobStorageLogger: Logging - Enters logging function for model %s",
|
||||
kwargs,
|
||||
)
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_payload is not set")
|
||||
|
||||
self.log_queue.append(standard_logging_payload)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
|
||||
pass
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Sends the in memory logs queue to Azure Blob Storage
|
||||
|
||||
Raises:
|
||||
            Does not raise; errors are logged (non-blocking) via verbose_logger.exception
|
||||
"""
|
||||
try:
|
||||
if not self.log_queue:
|
||||
                verbose_logger.exception("AzureBlobStorageLogger: log_queue does not exist")
|
||||
return
|
||||
|
||||
verbose_logger.debug(
|
||||
"AzureBlobStorageLogger - about to flush %s events",
|
||||
len(self.log_queue),
|
||||
)
|
||||
|
||||
for payload in self.log_queue:
|
||||
await self.async_upload_payload_to_azure_blob_storage(payload=payload)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"AzureBlobStorageLogger Error sending batch API - {str(e)}"
|
||||
)
|
||||
|
||||
async def async_upload_payload_to_azure_blob_storage(
|
||||
self, payload: StandardLoggingPayload
|
||||
):
|
||||
"""
|
||||
Uploads the payload to Azure Blob Storage using a 3-step process:
|
||||
1. Create file resource
|
||||
2. Append data
|
||||
3. Flush the data
|
||||
"""
|
||||
try:
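            # Two auth paths: a storage account key goes through the Azure SDK helper below,
            # otherwise an Entra ID (Azure AD) bearer token is used with the Data Lake REST API.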
|
||||
if self.azure_storage_account_key:
|
||||
await self.upload_to_azure_data_lake_with_azure_account_key(
|
||||
payload=payload
|
||||
)
|
||||
else:
|
||||
# Get a valid token instead of always requesting a new one
|
||||
await self.set_valid_azure_ad_token()
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
json_payload = (
|
||||
json.dumps(payload) + "\n"
|
||||
) # Add newline for each log entry
|
||||
payload_bytes = json_payload.encode("utf-8")
|
||||
filename = f"{payload.get('id') or str(uuid.uuid4())}.json"
|
||||
base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}"
|
||||
|
||||
# Execute the 3-step upload process
|
||||
await self._create_file(async_client, base_url)
|
||||
await self._append_data(async_client, base_url, json_payload)
|
||||
await self._flush_data(async_client, base_url, len(payload_bytes))
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Successfully uploaded log to Azure Blob Storage: {filename}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def _create_file(self, client: AsyncHTTPHandler, base_url: str):
|
||||
"""Helper method to create the file resource"""
|
||||
try:
|
||||
verbose_logger.debug(f"Creating file resource at: {base_url}")
|
||||
headers = {
|
||||
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
|
||||
"Content-Length": "0",
|
||||
"Authorization": f"Bearer {self.azure_auth_token}",
|
||||
}
|
||||
response = await client.put(f"{base_url}?resource=file", headers=headers)
|
||||
response.raise_for_status()
|
||||
verbose_logger.debug("Successfully created file resource")
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating file resource: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _append_data(
|
||||
self, client: AsyncHTTPHandler, base_url: str, json_payload: str
|
||||
):
|
||||
"""Helper method to append data to the file"""
|
||||
try:
|
||||
verbose_logger.debug(f"Appending data to file: {base_url}")
|
||||
headers = {
|
||||
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.azure_auth_token}",
|
||||
}
|
||||
response = await client.patch(
|
||||
f"{base_url}?action=append&position=0",
|
||||
headers=headers,
|
||||
data=json_payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
verbose_logger.debug("Successfully appended data")
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error appending data: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int):
|
||||
"""Helper method to flush the data"""
|
||||
try:
|
||||
verbose_logger.debug(f"Flushing data at position {position}")
|
||||
headers = {
|
||||
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
|
||||
"Content-Length": "0",
|
||||
"Authorization": f"Bearer {self.azure_auth_token}",
|
||||
}
|
||||
response = await client.patch(
|
||||
f"{base_url}?action=flush&position={position}", headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
verbose_logger.debug("Successfully flushed data")
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error flushing data: {str(e)}")
|
||||
raise
|
||||
|
||||
####### Helper methods to managing Authentication to Azure Storage #######
|
||||
##########################################################################
|
||||
|
||||
async def set_valid_azure_ad_token(self):
|
||||
"""
|
||||
Wrapper to set self.azure_auth_token to a valid Azure AD token, refreshing if necessary
|
||||
|
||||
Refreshes the token when:
|
||||
- Token is expired
|
||||
- Token is not set
|
||||
"""
|
||||
# Check if token needs refresh
|
||||
if self._azure_ad_token_is_expired() or self.azure_auth_token is None:
|
||||
verbose_logger.debug("Azure AD token needs refresh")
|
||||
self.azure_auth_token = self.get_azure_ad_token_from_azure_storage(
|
||||
tenant_id=self.tenant_id,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
# Token typically expires in 1 hour
|
||||
self.token_expiry = datetime.now() + timedelta(hours=1)
|
||||
verbose_logger.debug(f"New token will expire at {self.token_expiry}")
|
||||
|
||||
def get_azure_ad_token_from_azure_storage(
|
||||
self,
|
||||
tenant_id: Optional[str],
|
||||
client_id: Optional[str],
|
||||
client_secret: Optional[str],
|
||||
) -> str:
|
||||
"""
|
||||
Gets Azure AD token to use for Azure Storage API requests
|
||||
"""
|
||||
verbose_logger.debug("Getting Azure AD Token from Azure Storage")
|
||||
verbose_logger.debug(
|
||||
"tenant_id %s, client_id %s, client_secret %s",
|
||||
tenant_id,
|
||||
client_id,
|
||||
client_secret,
|
||||
)
|
||||
if tenant_id is None:
|
||||
raise ValueError(
|
||||
"Missing required environment variable: AZURE_STORAGE_TENANT_ID"
|
||||
)
|
||||
if client_id is None:
|
||||
raise ValueError(
|
||||
"Missing required environment variable: AZURE_STORAGE_CLIENT_ID"
|
||||
)
|
||||
if client_secret is None:
|
||||
raise ValueError(
|
||||
"Missing required environment variable: AZURE_STORAGE_CLIENT_SECRET"
|
||||
)
|
||||
|
||||
token_provider = get_azure_ad_token_from_entra_id(
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scope="https://storage.azure.com/.default",
|
||||
)
|
||||
token = token_provider()
|
||||
|
||||
verbose_logger.debug("azure auth token %s", token)
|
||||
|
||||
return token
|
||||
|
||||
def _azure_ad_token_is_expired(self):
|
||||
"""
|
||||
Returns True if Azure AD token is expired, False otherwise
|
||||
"""
|
||||
if self.azure_auth_token and self.token_expiry:
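            # treat the token as expired 5 minutes early to avoid using a token that lapses mid-request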
|
||||
if datetime.now() + timedelta(minutes=5) >= self.token_expiry:
|
||||
verbose_logger.debug("Azure AD token is expired. Requesting new token")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _premium_user_check(self):
|
||||
"""
|
||||
Checks if the user is a premium user, raises an error if not
|
||||
"""
|
||||
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"AzureBlobStorageLogger is only available for premium users. {CommonProxyErrors.not_premium_user}"
|
||||
)
|
||||
|
||||
async def get_service_client(self):
|
||||
from azure.storage.filedatalake.aio import DataLakeServiceClient
|
||||
|
||||
# expire old clients to recover from connection issues
|
||||
if (
|
||||
self._service_client_timeout
|
||||
and self._service_client
|
||||
            and time.time() > self._service_client_timeout  # client TTL has elapsed - refresh it
|
||||
):
|
||||
await self._service_client.close()
|
||||
self._service_client = None
|
||||
if not self._service_client:
|
||||
self._service_client = DataLakeServiceClient(
|
||||
account_url=f"https://{self.azure_storage_account_name}.dfs.core.windows.net",
|
||||
credential=self.azure_storage_account_key,
|
||||
)
|
||||
self._service_client_timeout = time.time() + _DEFAULT_TTL_FOR_HTTPX_CLIENTS
|
||||
return self._service_client
|
||||
|
||||
async def upload_to_azure_data_lake_with_azure_account_key(
|
||||
self, payload: StandardLoggingPayload
|
||||
):
|
||||
"""
|
||||
Uploads the payload to Azure Data Lake using the Azure SDK
|
||||
|
||||
        This is used when an Azure Storage Account Key is set - the account key cannot be used directly with the Azure REST API
|
||||
"""
|
||||
|
||||
# Create an async service client
|
||||
|
||||
service_client = await self.get_service_client()
|
||||
# Get file system client
|
||||
file_system_client = service_client.get_file_system_client(
|
||||
file_system=self.azure_storage_file_system
|
||||
)
|
||||
|
||||
try:
|
||||
# Create directory with today's date
|
||||
from datetime import datetime
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
directory_client = file_system_client.get_directory_client(today)
|
||||
|
||||
# check if the directory exists
|
||||
if not await directory_client.exists():
|
||||
await directory_client.create_directory()
|
||||
verbose_logger.debug(f"Created directory: {today}")
|
||||
|
||||
# Create a file client
|
||||
file_name = f"{payload.get('id') or str(uuid.uuid4())}.json"
|
||||
file_client = directory_client.get_file_client(file_name)
|
||||
|
||||
# Create the file
|
||||
await file_client.create_file()
|
||||
|
||||
# Content to append
|
||||
content = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Append content to the file
|
||||
await file_client.append_data(data=content, offset=0, length=len(content))
|
||||
|
||||
# Flush the content to finalize the file
|
||||
await file_client.flush_data(position=len(content), offset=0)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Successfully uploaded and wrote to {today}/{file_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error occurred: {str(e)}")
|
||||
@@ -0,0 +1,450 @@
|
||||
# What is this?
|
||||
## Log success + failure events to Braintrust
|
||||
|
||||
import copy
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.utils import print_verbose
|
||||
|
||||
global_braintrust_http_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
global_braintrust_sync_http_handler = HTTPHandler()
|
||||
API_BASE = "https://api.braintrustdata.com/v1"
|
||||
|
||||
|
||||
def get_utc_datetime():
|
||||
import datetime as dt
|
||||
from datetime import datetime
|
||||
|
||||
if hasattr(dt, "UTC"):
|
||||
return datetime.now(dt.UTC) # type: ignore
|
||||
else:
|
||||
return datetime.utcnow() # type: ignore
|
||||
|
||||
|
||||
class BraintrustLogger(CustomLogger):
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.validate_environment(api_key=api_key)
|
||||
self.api_base = api_base or API_BASE
|
||||
self.default_project_id = None
|
||||
self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY") # type: ignore
|
||||
self.headers = {
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self._project_id_cache: Dict[
|
||||
str, str
|
||||
] = {} # Cache mapping project names to IDs
|
||||
|
||||
def validate_environment(self, api_key: Optional[str]):
|
||||
"""
|
||||
Expects
|
||||
BRAINTRUST_API_KEY
|
||||
|
||||
in the environment
|
||||
"""
|
||||
missing_keys = []
|
||||
if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None:
|
||||
missing_keys.append("BRAINTRUST_API_KEY")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||
|
||||
def get_project_id_sync(self, project_name: str) -> str:
|
||||
"""
|
||||
Get project ID from name, using cache if available.
|
||||
If project doesn't exist, creates it.
|
||||
"""
|
||||
if project_name in self._project_id_cache:
|
||||
return self._project_id_cache[project_name]
|
||||
|
||||
try:
|
||||
response = global_braintrust_sync_http_handler.post(
|
||||
f"{self.api_base}/project",
|
||||
headers=self.headers,
|
||||
json={"name": project_name},
|
||||
)
|
||||
project_dict = response.json()
|
||||
project_id = project_dict["id"]
|
||||
self._project_id_cache[project_name] = project_id
|
||||
return project_id
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"Failed to register project: {e.response.text}")
|
||||
|
||||
async def get_project_id_async(self, project_name: str) -> str:
|
||||
"""
|
||||
Async version of get_project_id_sync
|
||||
"""
|
||||
if project_name in self._project_id_cache:
|
||||
return self._project_id_cache[project_name]
|
||||
|
||||
try:
|
||||
response = await global_braintrust_http_handler.post(
|
||||
f"{self.api_base}/project/register",
|
||||
headers=self.headers,
|
||||
json={"name": project_name},
|
||||
)
|
||||
project_dict = response.json()
|
||||
project_id = project_dict["id"]
|
||||
self._project_id_cache[project_name] = project_id
|
||||
return project_id
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"Failed to register project: {e.response.text}")
|
||||
|
||||
@staticmethod
|
||||
def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
|
||||
"""
|
||||
        Adds metadata from proxy request headers to Braintrust logging if keys start with "braintrust"
|
||||
and overwrites litellm_params.metadata if already included.
|
||||
|
||||
        For example, any request header whose name starts with "braintrust" is copied into the metadata, e.g. send
|
||||
        `headers: { ..., braintrust_custom_field: some-value }` via proxy request.
|
||||
"""
|
||||
if litellm_params is None:
|
||||
return metadata
|
||||
|
||||
if litellm_params.get("proxy_server_request") is None:
|
||||
return metadata
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
proxy_headers = (
|
||||
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
|
||||
)
|
||||
|
||||
for metadata_param_key in proxy_headers:
|
||||
if metadata_param_key.startswith("braintrust"):
|
||||
trace_param_key = metadata_param_key.replace("braintrust", "", 1)
|
||||
if trace_param_key in metadata:
|
||||
verbose_logger.warning(
|
||||
f"Overwriting Braintrust `{trace_param_key}` from request header"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Found Braintrust `{trace_param_key}` in request header"
|
||||
)
|
||||
metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
|
||||
|
||||
return metadata
|
||||
|
||||
async def create_default_project_and_experiment(self):
|
||||
project = await global_braintrust_http_handler.post(
|
||||
f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
|
||||
)
|
||||
|
||||
project_dict = project.json()
|
||||
|
||||
self.default_project_id = project_dict["id"]
|
||||
|
||||
def create_sync_default_project_and_experiment(self):
|
||||
project = global_braintrust_sync_http_handler.post(
|
||||
f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
|
||||
)
|
||||
|
||||
project_dict = project.json()
|
||||
|
||||
self.default_project_id = project_dict["id"]
|
||||
|
||||
def log_success_event( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
):
|
||||
verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
|
||||
try:
|
||||
litellm_call_id = kwargs.get("litellm_call_id")
|
||||
prompt = {"messages": kwargs.get("messages")}
|
||||
output = None
|
||||
choices = []
|
||||
if response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
choices = response_obj["choices"]
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
output = response_obj.choices[0].text
|
||||
choices = response_obj.choices
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
output = response_obj["data"]
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
            )  # handle litellm_params['metadata'] being None
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
clean_metadata = {}
|
||||
try:
|
||||
metadata = copy.deepcopy(
|
||||
metadata
|
||||
) # Avoid modifying the original metadata
|
||||
except Exception:
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
or isinstance(value, str)
|
||||
or isinstance(value, int)
|
||||
or isinstance(value, float)
|
||||
):
|
||||
new_metadata[key] = copy.deepcopy(value)
|
||||
metadata = new_metadata
|
||||
|
||||
# Get project_id from metadata or create default if needed
|
||||
project_id = metadata.get("project_id")
|
||||
if project_id is None:
|
||||
project_name = metadata.get("project_name")
|
||||
project_id = (
|
||||
self.get_project_id_sync(project_name) if project_name else None
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
if self.default_project_id is None:
|
||||
self.create_sync_default_project_and_experiment()
|
||||
project_id = self.default_project_id
|
||||
|
||||
tags = []
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
                    # generate tags - litellm.langfuse_default_tags (the proxy's default tag list) is reused for Braintrust
|
||||
if (
|
||||
litellm.langfuse_default_tags is not None
|
||||
and isinstance(litellm.langfuse_default_tags, list)
|
||||
and key in litellm.langfuse_default_tags
|
||||
):
|
||||
tags.append(f"{key}:{value}")
|
||||
|
||||
# clean litellm metadata before logging
|
||||
if key in [
|
||||
"headers",
|
||||
"endpoint",
|
||||
"caching_groups",
|
||||
"previous_models",
|
||||
]:
|
||||
continue
|
||||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
if cost is not None:
|
||||
clean_metadata["litellm_response_cost"] = cost
|
||||
|
||||
metrics: Optional[dict] = None
|
||||
usage_obj = getattr(response_obj, "usage", None)
|
||||
if usage_obj and isinstance(usage_obj, litellm.Usage):
|
||||
litellm.utils.get_logging_id(start_time, response_obj)
|
||||
metrics = {
|
||||
"prompt_tokens": usage_obj.prompt_tokens,
|
||||
"completion_tokens": usage_obj.completion_tokens,
|
||||
"total_tokens": usage_obj.total_tokens,
|
||||
"total_cost": cost,
|
||||
"time_to_first_token": end_time.timestamp()
|
||||
- start_time.timestamp(),
|
||||
"start": start_time.timestamp(),
|
||||
"end": end_time.timestamp(),
|
||||
}
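                # NOTE: time_to_first_token above is total request latency in this sync path;
                # the async path derives true TTFT from completion_start_time.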
|
||||
|
||||
request_data = {
|
||||
"id": litellm_call_id,
|
||||
"input": prompt["messages"],
|
||||
"metadata": clean_metadata,
|
||||
"tags": tags,
|
||||
"span_attributes": {"name": "Chat Completion", "type": "llm"},
|
||||
}
|
||||
            if choices:  # empty for embedding / image responses - fall back to `output`
|
||||
request_data["output"] = [choice.dict() for choice in choices]
|
||||
else:
|
||||
request_data["output"] = output
|
||||
|
||||
if metrics is not None:
|
||||
request_data["metrics"] = metrics
|
||||
|
||||
try:
|
||||
print_verbose(
|
||||
f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
|
||||
)
|
||||
global_braintrust_sync_http_handler.post(
|
||||
url=f"{self.api_base}/project_logs/{project_id}/insert",
|
||||
json={"events": [request_data]},
|
||||
headers=self.headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(e.response.text)
|
||||
except Exception as e:
|
||||
            raise e  # re-raise; don't swallow the exception by logging it with verbose_logger.exception
|
||||
|
||||
async def async_log_success_event( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
):
|
||||
verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
|
||||
try:
|
||||
litellm_call_id = kwargs.get("litellm_call_id")
|
||||
prompt = {"messages": kwargs.get("messages")}
|
||||
output = None
|
||||
choices = []
|
||||
if response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
choices = response_obj["choices"]
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
output = response_obj.choices[0].text
|
||||
choices = response_obj.choices
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
output = response_obj["data"]
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
            )  # handle litellm_params['metadata'] being None
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
clean_metadata = {}
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, list)
|
||||
or isinstance(value, str)
|
||||
or isinstance(value, int)
|
||||
or isinstance(value, float)
|
||||
):
|
||||
new_metadata[key] = value
|
||||
elif isinstance(value, BaseModel):
|
||||
new_metadata[key] = value.model_dump_json()
|
||||
elif isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
if isinstance(v, datetime):
|
||||
value[k] = v.isoformat()
|
||||
new_metadata[key] = value
|
||||
|
||||
# Get project_id from metadata or create default if needed
|
||||
project_id = metadata.get("project_id")
|
||||
if project_id is None:
|
||||
project_name = metadata.get("project_name")
|
||||
project_id = (
|
||||
await self.get_project_id_async(project_name)
|
||||
if project_name
|
||||
else None
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
if self.default_project_id is None:
|
||||
await self.create_default_project_and_experiment()
|
||||
project_id = self.default_project_id
|
||||
|
||||
tags = []
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
                    # generate tags - litellm.langfuse_default_tags (the proxy's default tag list) is reused for Braintrust
|
||||
if (
|
||||
litellm.langfuse_default_tags is not None
|
||||
and isinstance(litellm.langfuse_default_tags, list)
|
||||
and key in litellm.langfuse_default_tags
|
||||
):
|
||||
tags.append(f"{key}:{value}")
|
||||
|
||||
# clean litellm metadata before logging
|
||||
if key in [
|
||||
"headers",
|
||||
"endpoint",
|
||||
"caching_groups",
|
||||
"previous_models",
|
||||
]:
|
||||
continue
|
||||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
if cost is not None:
|
||||
clean_metadata["litellm_response_cost"] = cost
|
||||
|
||||
metrics: Optional[dict] = None
|
||||
usage_obj = getattr(response_obj, "usage", None)
|
||||
if usage_obj and isinstance(usage_obj, litellm.Usage):
|
||||
litellm.utils.get_logging_id(start_time, response_obj)
|
||||
metrics = {
|
||||
"prompt_tokens": usage_obj.prompt_tokens,
|
||||
"completion_tokens": usage_obj.completion_tokens,
|
||||
"total_tokens": usage_obj.total_tokens,
|
||||
"total_cost": cost,
|
||||
"start": start_time.timestamp(),
|
||||
"end": end_time.timestamp(),
|
||||
}
|
||||
|
||||
api_call_start_time = kwargs.get("api_call_start_time")
|
||||
completion_start_time = kwargs.get("completion_start_time")
|
||||
|
||||
if (
|
||||
api_call_start_time is not None
|
||||
and completion_start_time is not None
|
||||
):
|
||||
metrics["time_to_first_token"] = (
|
||||
completion_start_time.timestamp()
|
||||
- api_call_start_time.timestamp()
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"id": litellm_call_id,
|
||||
"input": prompt["messages"],
|
||||
"output": output,
|
||||
"metadata": clean_metadata,
|
||||
"tags": tags,
|
||||
"span_attributes": {"name": "Chat Completion", "type": "llm"},
|
||||
}
|
||||
            if choices:  # empty for embedding / image responses - fall back to `output`
|
||||
request_data["output"] = [choice.dict() for choice in choices]
|
||||
else:
|
||||
request_data["output"] = output
|
||||
|
||||
if metrics is not None:
|
||||
request_data["metrics"] = metrics
|
||||
|
||||
try:
|
||||
await global_braintrust_http_handler.post(
|
||||
url=f"{self.api_base}/project_logs/{project_id}/insert",
|
||||
json={"events": [request_data]},
|
||||
headers=self.headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(e.response.text)
|
||||
except Exception as e:
|
||||
            raise e  # re-raise; don't swallow the exception by logging it with verbose_logger.exception
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
return super().log_failure_event(kwargs, response_obj, start_time, end_time)
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Custom Logger that handles batching logic
|
||||
|
||||
Use this if you want your logs to be stored in memory and flushed periodically.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class CustomBatchLogger(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
flush_lock: Optional[asyncio.Lock] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
flush_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
|
||||
"""
|
||||
self.log_queue: List = []
|
||||
self.flush_interval = flush_interval or litellm.DEFAULT_FLUSH_INTERVAL_SECONDS
|
||||
self.batch_size: int = batch_size or litellm.DEFAULT_BATCH_SIZE
|
||||
self.last_flush_time = time.time()
|
||||
self.flush_lock = flush_lock
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def periodic_flush(self):
|
||||
while True:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
verbose_logger.debug(
|
||||
f"CustomLogger periodic flush after {self.flush_interval} seconds"
|
||||
)
|
||||
await self.flush_queue()
|
||||
|
||||
async def flush_queue(self):
|
||||
if self.flush_lock is None:
|
||||
return
|
||||
|
||||
async with self.flush_lock:
|
||||
if self.log_queue:
|
||||
verbose_logger.debug(
|
||||
"CustomLogger: Flushing batch of %s events", len(self.log_queue)
|
||||
)
|
||||
await self.async_send_batch()
|
||||
self.log_queue.clear()
|
||||
self.last_flush_time = time.time()
|
||||
|
||||
async def async_send_batch(self, *args, **kwargs):
|
||||
pass
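

# Illustrative sketch: a minimal subclass of CustomBatchLogger. The class name and the
# logging destination are hypothetical; only async_send_batch() needs a real implementation,
# while the base class handles the in-memory queue and the periodic flush.
class _ExampleBatchLogger(CustomBatchLogger):
    def __init__(self, **kwargs) -> None:
        self.flush_lock = asyncio.Lock()
        asyncio.create_task(self.periodic_flush())  # flush every `flush_interval` seconds
        super().__init__(**kwargs, flush_lock=self.flush_lock)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # queue the standardized payload; flush_queue() clears the queue after each send
        self.log_queue.append(kwargs.get("standard_logging_object"))

    async def async_send_batch(self):
        # replace this with a real sink, e.g. an HTTP POST of self.log_queue
        verbose_logger.debug("example batch logger - would ship %s events", len(self.log_queue))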
|
||||
@@ -0,0 +1,273 @@
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks
|
||||
from litellm.types.utils import StandardLoggingGuardrailInformation
|
||||
|
||||
|
||||
class CustomGuardrail(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
guardrail_name: Optional[str] = None,
|
||||
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
|
||||
event_hook: Optional[
|
||||
Union[GuardrailEventHooks, List[GuardrailEventHooks]]
|
||||
] = None,
|
||||
default_on: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the CustomGuardrail class
|
||||
|
||||
Args:
|
||||
guardrail_name: The name of the guardrail. This is the name used in your requests.
|
||||
supported_event_hooks: The event hooks that the guardrail supports
|
||||
event_hook: The event hook to run the guardrail on
|
||||
default_on: If True, the guardrail will be run by default on all requests
|
||||
"""
|
||||
self.guardrail_name = guardrail_name
|
||||
self.supported_event_hooks = supported_event_hooks
|
||||
self.event_hook: Optional[
|
||||
Union[GuardrailEventHooks, List[GuardrailEventHooks]]
|
||||
] = event_hook
|
||||
self.default_on: bool = default_on
|
||||
|
||||
if supported_event_hooks:
|
||||
## validate event_hook is in supported_event_hooks
|
||||
self._validate_event_hook(event_hook, supported_event_hooks)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _validate_event_hook(
|
||||
self,
|
||||
event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]],
|
||||
supported_event_hooks: List[GuardrailEventHooks],
|
||||
) -> None:
|
||||
if event_hook is None:
|
||||
return
|
||||
if isinstance(event_hook, list):
|
||||
for hook in event_hook:
|
||||
if hook not in supported_event_hooks:
|
||||
raise ValueError(
|
||||
f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
|
||||
)
|
||||
elif isinstance(event_hook, GuardrailEventHooks):
|
||||
if event_hook not in supported_event_hooks:
|
||||
raise ValueError(
|
||||
f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
|
||||
)
|
||||
|
||||
def get_guardrail_from_metadata(
|
||||
self, data: dict
|
||||
) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
|
||||
"""
|
||||
Returns the guardrail(s) to be run from the metadata
|
||||
"""
|
||||
metadata = data.get("metadata") or {}
|
||||
requested_guardrails = metadata.get("guardrails") or []
|
||||
return requested_guardrails
|
||||
|
||||
def _guardrail_is_in_requested_guardrails(
|
||||
self,
|
||||
requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
|
||||
) -> bool:
|
||||
for _guardrail in requested_guardrails:
|
||||
if isinstance(_guardrail, dict):
|
||||
if self.guardrail_name in _guardrail:
|
||||
return True
|
||||
elif isinstance(_guardrail, str):
|
||||
if self.guardrail_name == _guardrail:
|
||||
return True
|
||||
return False
|
||||
|
||||
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
||||
"""
|
||||
Returns True if the guardrail should be run on the event_type
|
||||
"""
|
||||
requested_guardrails = self.get_guardrail_from_metadata(data)
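        # Decision order: a default_on guardrail runs whenever the event type matches its hook;
        # otherwise the guardrail must be explicitly requested (except for logging_only events)
        # and the event type must still match the configured hook.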
|
||||
|
||||
verbose_logger.debug(
|
||||
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
|
||||
self.guardrail_name,
|
||||
event_type,
|
||||
self.event_hook,
|
||||
requested_guardrails,
|
||||
self.default_on,
|
||||
)
|
||||
|
||||
if self.default_on is True:
|
||||
if self._event_hook_is_event_type(event_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
if (
|
||||
self.event_hook
|
||||
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
|
||||
and event_type.value != "logging_only"
|
||||
):
|
||||
return False
|
||||
|
||||
if not self._event_hook_is_event_type(event_type):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
|
||||
"""
|
||||
Returns True if the event_hook is the same as the event_type
|
||||
|
||||
eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
|
||||
eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
|
||||
"""
|
||||
|
||||
if self.event_hook is None:
|
||||
return True
|
||||
if isinstance(self.event_hook, list):
|
||||
return event_type.value in self.event_hook
|
||||
return self.event_hook == event_type.value
|
||||
|
||||
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
|
||||
"""
|
||||
Returns `extra_body` to be added to the request body for the Guardrail API call
|
||||
|
||||
Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.
|
||||
|
||||
```
|
||||
[{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
|
||||
```
|
||||
|
||||
        Will return, for guardrail=`lakera_guard`:
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
|
||||
Args:
|
||||
request_data: The original `request_data` passed to LiteLLM Proxy
|
||||
"""
|
||||
requested_guardrails = self.get_guardrail_from_metadata(request_data)
|
||||
|
||||
# Look for the guardrail configuration matching self.guardrail_name
|
||||
for guardrail in requested_guardrails:
|
||||
if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
|
||||
# Get the configuration for this guardrail
|
||||
guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
|
||||
**guardrail[self.guardrail_name]
|
||||
)
|
||||
if self._validate_premium_user() is not True:
|
||||
return {}
|
||||
|
||||
# Return the extra_body if it exists, otherwise empty dict
|
||||
return guardrail_config.get("extra_body", {})
|
||||
|
||||
return {}
|
||||
|
||||
def _validate_premium_user(self) -> bool:
|
||||
"""
|
||||
Returns True if the user is a premium user
|
||||
"""
|
||||
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
verbose_logger.warning(
|
||||
f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def add_standard_logging_guardrail_information_to_request_data(
|
||||
self,
|
||||
guardrail_json_response: Union[Exception, str, dict],
|
||||
request_data: dict,
|
||||
guardrail_status: Literal["success", "failure"],
|
||||
) -> None:
|
||||
"""
|
||||
Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
verbose_logger.warning(
|
||||
f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}"
|
||||
)
|
||||
return
|
||||
if isinstance(guardrail_json_response, Exception):
|
||||
guardrail_json_response = str(guardrail_json_response)
|
||||
slg = StandardLoggingGuardrailInformation(
|
||||
guardrail_name=self.guardrail_name,
|
||||
guardrail_mode=self.event_hook,
|
||||
guardrail_response=guardrail_json_response,
|
||||
guardrail_status=guardrail_status,
|
||||
)
|
||||
if "metadata" in request_data:
|
||||
request_data["metadata"]["standard_logging_guardrail_information"] = slg
|
||||
elif "litellm_metadata" in request_data:
|
||||
request_data["litellm_metadata"][
|
||||
"standard_logging_guardrail_information"
|
||||
] = slg
|
||||
else:
|
||||
verbose_logger.warning(
|
||||
"unable to log guardrail information. No metadata found in request_data"
|
||||
)
|
||||
|
||||
|
||||
def log_guardrail_information(func):
|
||||
"""
|
||||
Decorator to add standard logging guardrail information to any function
|
||||
|
||||
Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.
|
||||
|
||||
Logs for:
|
||||
- pre_call
|
||||
- during_call
|
||||
- TODO: log post_call. This is more involved since the logs are sent to DD, s3 before the guardrail is even run
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
def process_response(self, response, request_data):
|
||||
self.add_standard_logging_guardrail_information_to_request_data(
|
||||
guardrail_json_response=response,
|
||||
request_data=request_data,
|
||||
guardrail_status="success",
|
||||
)
|
||||
return response
|
||||
|
||||
def process_error(self, e, request_data):
|
||||
self.add_standard_logging_guardrail_information_to_request_data(
|
||||
guardrail_json_response=e,
|
||||
request_data=request_data,
|
||||
guardrail_status="failure",
|
||||
)
|
||||
raise e
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
self: CustomGuardrail = args[0]
|
||||
request_data: Optional[dict] = (
|
||||
kwargs.get("data") or kwargs.get("request_data") or {}
|
||||
)
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
return process_response(self, response, request_data)
|
||||
except Exception as e:
|
||||
return process_error(self, e, request_data)
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
self: CustomGuardrail = args[0]
|
||||
request_data: Optional[dict] = (
|
||||
kwargs.get("data") or kwargs.get("request_data") or {}
|
||||
)
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
return process_response(self, response, request_data)
|
||||
except Exception as e:
|
||||
return process_error(self, e, request_data)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper(*args, **kwargs)
|
||||
return sync_wrapper(*args, **kwargs)
|
||||
|
||||
return wrapper
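

# Illustrative sketch: a guardrail hook wrapped with log_guardrail_information so its outcome
# is attached to the request's standard logging metadata. The class name and the blocked-phrase
# check are hypothetical.
class _ExampleGuardrail(CustomGuardrail):
    @log_guardrail_information
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        for message in data.get("messages") or []:
            if "blocked phrase" in str(message.get("content", "")):
                # raising marks the guardrail result as "failure" (the decorator logs it, then re-raises)
                raise ValueError("Request rejected by _ExampleGuardrail")
        return data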
|
||||
@@ -0,0 +1,387 @@
|
||||
#### What this does ####
|
||||
# On success, logs events to Promptlayer
|
||||
import traceback
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.integrations.argilla import ArgillaItem
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
|
||||
from litellm.types.utils import (
|
||||
AdapterCompletionStreamWrapper,
|
||||
LLMResponseTypes,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
# Class variables or attributes
|
||||
def __init__(self, message_logging: bool = True) -> None:
|
||||
self.message_logging = message_logging
|
||||
pass
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### ASYNC ####
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### PROMPT MANAGEMENT HOOKS ####
|
||||
|
||||
async def async_get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Returns:
|
||||
- model: str - the model to use (can be pulled from prompt management tool)
|
||||
- messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
|
||||
- non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
|
||||
"""
|
||||
return model, messages, non_default_params
|
||||
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Returns:
|
||||
- model: str - the model to use (can be pulled from prompt management tool)
|
||||
- messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
|
||||
- non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
|
||||
"""
|
||||
return model, messages, non_default_params
|
||||
|
||||
#### PRE-CALL CHECKS - router/proxy only ####
|
||||
"""
|
||||
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
|
||||
"""
|
||||
|
||||
async def async_filter_deployments(
|
||||
self,
|
||||
model: str,
|
||||
healthy_deployments: List,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
request_kwargs: Optional[dict] = None,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
) -> List[dict]:
|
||||
return healthy_deployments
|
||||
|
||||
async def async_pre_call_check(
|
||||
self, deployment: dict, parent_otel_span: Optional[Span]
|
||||
) -> Optional[dict]:
|
||||
pass
|
||||
|
||||
def pre_call_check(self, deployment: dict) -> Optional[dict]:
|
||||
pass
|
||||
|
||||
#### Fallback Events - router/proxy only ####
|
||||
async def log_model_group_rate_limit_error(
|
||||
self, exception: Exception, original_model_group: Optional[str], kwargs: dict
|
||||
):
|
||||
pass
|
||||
|
||||
async def log_success_fallback_event(
|
||||
self, original_model_group: str, kwargs: dict, original_exception: Exception
|
||||
):
|
||||
pass
|
||||
|
||||
async def log_failure_fallback_event(
|
||||
self, original_model_group: str, kwargs: dict, original_exception: Exception
|
||||
):
|
||||
pass
|
||||
|
||||
#### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls
|
||||
|
||||
def translate_completion_input_params(
|
||||
self, kwargs
|
||||
) -> Optional[ChatCompletionRequest]:
|
||||
"""
|
||||
Translates the input params, from the provider's native format to the litellm.completion() format.
|
||||
"""
|
||||
pass
|
||||
|
||||
def translate_completion_output_params(
|
||||
self, response: ModelResponse
|
||||
) -> Optional[BaseModel]:
|
||||
"""
|
||||
Translates the output params, from the OpenAI format to the custom format.
|
||||
"""
|
||||
pass
|
||||
|
||||
def translate_completion_output_params_streaming(
|
||||
self, completion_stream: Any
|
||||
) -> Optional[AdapterCompletionStreamWrapper]:
|
||||
"""
|
||||
Translates the streaming chunk, from the OpenAI format to the custom format.
|
||||
"""
|
||||
pass
|
||||
|
||||
### DATASET HOOKS #### - currently only used for Argilla
|
||||
|
||||
async def async_dataset_hook(
|
||||
self,
|
||||
logged_item: ArgillaItem,
|
||||
standard_logging_payload: Optional[StandardLoggingPayload],
|
||||
) -> Optional[ArgillaItem]:
|
||||
"""
|
||||
- Decide if the result should be logged to Argilla.
|
||||
- Modify the result before logging to Argilla.
|
||||
- Return None if the result should not be logged to Argilla.
|
||||
"""
|
||||
raise NotImplementedError("async_dataset_hook not implemented")
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
"""
|
||||
    Modify incoming / outgoing data before calling the model
|
||||
"""
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
],
|
||||
) -> Optional[
|
||||
Union[Exception, str, dict]
|
||||
    ]:  # raise an exception if invalid; return a str for the user to receive if rejected; or return a modified dict to pass into litellm
|
||||
pass
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: LLMResponseTypes,
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
async def async_logging_hook(
|
||||
self, kwargs: dict, result: Any, call_type: str
|
||||
) -> Tuple[dict, Any]:
|
||||
"""For masking logged request/response. Return a modified version of the request/result."""
|
||||
return kwargs, result
|
||||
|
||||
def logging_hook(
|
||||
self, kwargs: dict, result: Any, call_type: str
|
||||
) -> Tuple[dict, Any]:
|
||||
"""For masking logged request/response. Return a modified version of the request/result."""
|
||||
return kwargs, result
|
||||
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
async def async_post_call_streaming_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: str,
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
async def async_post_call_streaming_iterator_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Any,
|
||||
request_data: dict,
|
||||
) -> AsyncGenerator[ModelResponseStream, None]:
|
||||
async for item in response:
|
||||
yield item
|
||||
|
||||
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
||||
|
||||
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["log_event_type"] = "pre_api_call"
|
||||
callback_func(
|
||||
kwargs,
|
||||
)
|
||||
print_verbose(f"Custom Logger - model call details: {kwargs}")
|
||||
except Exception:
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
async def async_log_input_event(
|
||||
self, model, messages, kwargs, print_verbose, callback_func
|
||||
):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["log_event_type"] = "pre_api_call"
|
||||
await callback_func(
|
||||
kwargs,
|
||||
)
|
||||
print_verbose(f"Custom Logger - model call details: {kwargs}")
|
||||
except Exception:
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
def log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
|
||||
):
|
||||
# Method definition
|
||||
try:
|
||||
kwargs["log_event_type"] = "post_api_call"
|
||||
callback_func(
|
||||
kwargs, # kwargs to func
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
except Exception:
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
async def async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
|
||||
):
|
||||
# Method definition
|
||||
try:
|
||||
kwargs["log_event_type"] = "post_api_call"
|
||||
await callback_func(
|
||||
kwargs, # kwargs to func
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
except Exception:
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
# Useful helpers for custom logger classes
|
||||
|
||||
def truncate_standard_logging_payload_content(
|
||||
self,
|
||||
standard_logging_object: StandardLoggingPayload,
|
||||
):
|
||||
"""
|
||||
Truncate error strings and message content in logging payload
|
||||
|
||||
Some loggers, like DataDog / GCS Bucket, have a limit on the payload size (1MB).
|
||||
|
||||
This function truncates the error string and the message content if they exceed a certain length.
|
||||
"""
|
||||
MAX_STR_LENGTH = 10_000
|
||||
|
||||
# Truncate fields that might exceed max length
|
||||
fields_to_truncate = ["error_str", "messages", "response"]
|
||||
for field in fields_to_truncate:
|
||||
self._truncate_field(
|
||||
standard_logging_object=standard_logging_object,
|
||||
field_name=field,
|
||||
max_length=MAX_STR_LENGTH,
|
||||
)
|
||||
|
||||
def _truncate_field(
|
||||
self,
|
||||
standard_logging_object: StandardLoggingPayload,
|
||||
field_name: str,
|
||||
max_length: int,
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to truncate a field in the logging payload
|
||||
|
||||
This converts the field to a string and then truncates it if it exceeds the max length.
|
||||
|
||||
Why convert to string ?
|
||||
1. A user was sending a poorly formatted list for the `messages` field, and we could not predict where they would place the content
|
||||
- Converting to string and then truncating the logged content catches this
|
||||
2. We want to avoid modifying the original `messages`, `response`, and `error_str` in the logging payload since these are in kwargs and could be returned to the user
|
||||
"""
|
||||
field_value = standard_logging_object.get(field_name) # type: ignore
|
||||
if field_value:
|
||||
str_value = str(field_value)
|
||||
if len(str_value) > max_length:
|
||||
standard_logging_object[field_name] = self._truncate_text( # type: ignore
|
||||
text=str_value, max_length=max_length
|
||||
)
|
||||
|
||||
def _truncate_text(self, text: str, max_length: int) -> str:
|
||||
"""Truncate text if it exceeds max_length"""
|
||||
return (
|
||||
text[:max_length]
|
||||
+ "...truncated by litellm, this logger does not support large content"
|
||||
if len(text) > max_length
|
||||
else text
|
||||
)
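# Illustrative sketch (not part of the original file): a minimal CustomLogger
# subclass wiring up `async_pre_call_hook`. The class name and the banned-phrase
# check are assumptions made purely for illustration; the commented registration
# line is one possible setup, confirm against your proxy/SDK config.
class ExampleGuardrailHook(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        # returning a str rejects the request and sends that str back to the user
        for message in data.get("messages", []):
            if "forbidden" in str(message.get("content", "")):
                return "Request rejected by ExampleGuardrailHook"
        # returning the (optionally modified) dict lets the call proceed
        data.setdefault("metadata", {})["guardrail"] = "example_guardrail"
        return data


# litellm.callbacks = [ExampleGuardrailHook()]  # e.g. registered on the proxy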
|
||||
@@ -0,0 +1,49 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class CustomPromptManagement(CustomLogger, PromptManagementBase):
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Returns:
|
||||
- model: str - the model to use (can be pulled from prompt management tool)
|
||||
- messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
|
||||
- non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
|
||||
"""
|
||||
return model, messages, non_default_params
|
||||
|
||||
@property
|
||||
def integration_name(self) -> str:
|
||||
return "custom-prompt-management"
|
||||
|
||||
def should_run_prompt_management(
|
||||
self,
|
||||
prompt_id: str,
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def _compile_prompt_helper(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> PromptManagementClient:
|
||||
raise NotImplementedError(
|
||||
"Custom prompt management does not support compile prompt helper"
|
||||
)
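# Illustrative sketch (not part of the original file): a minimal subclass that
# resolves prompts from an in-memory dict keyed by `prompt_id`. The
# `LOCAL_PROMPTS` table and the system-message prefixing are assumptions made
# purely for illustration.
LOCAL_PROMPTS = {"greet": "You are a friendly assistant."}


class InMemoryPromptManagement(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        if prompt_id and prompt_id in LOCAL_PROMPTS:
            # prepend the stored system prompt to the request messages
            messages = [
                {"role": "system", "content": LOCAL_PROMPTS[prompt_id]}
            ] + list(messages)
        return model, messages, non_default_params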
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,579 @@
|
||||
"""
|
||||
DataDog Integration - sends logs to /api/v2/logs
|
||||
|
||||
DD Reference API: https://docs.datadoghq.com/api/latest/logs
|
||||
|
||||
`async_log_success_event` - used by litellm proxy to send logs to datadog
|
||||
`log_success_event` - sync version of logging to DataDog, only used by the litellm Python SDK, if the user opts in to using sync functions
|
||||
|
||||
async_log_success_event: stores a batch of up to DD_MAX_BATCH_SIZE logs in memory and flushes to Datadog once the batch reaches DD_MAX_BATCH_SIZE or every 5 seconds
|
||||
|
||||
async_service_failure_hook: Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog
|
||||
|
||||
For batching specific details see CustomBatchLogger class
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime as datetimeObj
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.datadog import *
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
from ..additional_logging_utils import AdditionalLoggingUtils
|
||||
|
||||
# max number of logs the DD API can accept per request (see DD_MAX_BATCH_SIZE, imported above via litellm.types.integrations.datadog)
|
||||
|
||||
|
||||
# specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types)
|
||||
DD_LOGGED_SUCCESS_SERVICE_TYPES = [
|
||||
ServiceTypes.RESET_BUDGET_JOB,
|
||||
]
|
||||
|
||||
|
||||
class DataDogLogger(
|
||||
CustomBatchLogger,
|
||||
AdditionalLoggingUtils,
|
||||
):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the datadog logger, checks if the correct env variables are set
|
||||
|
||||
Required environment variables:
|
||||
`DD_API_KEY` - your datadog api key
|
||||
`DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug("Datadog: in init datadog logger")
|
||||
# check if the correct env variables are set
|
||||
if os.getenv("DD_API_KEY", None) is None:
|
||||
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
|
||||
if os.getenv("DD_SITE", None) is None:
|
||||
raise Exception("DD_SITE is not set in .env, set 'DD_SITE=<>")
|
||||
self.async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
self.DD_API_KEY = os.getenv("DD_API_KEY")
|
||||
self.intake_url = (
|
||||
f"https://http-intake.logs.{os.getenv('DD_SITE')}/api/v2/logs"
|
||||
)
|
||||
|
||||
###################################
|
||||
# OPTIONAL - only used for testing
|
||||
dd_base_url: Optional[str] = (
|
||||
os.getenv("_DATADOG_BASE_URL")
|
||||
or os.getenv("DATADOG_BASE_URL")
|
||||
or os.getenv("DD_BASE_URL")
|
||||
)
|
||||
if dd_base_url is not None:
|
||||
self.intake_url = f"{dd_base_url}/api/v2/logs"
|
||||
###################################
|
||||
self.sync_client = _get_httpx_client()
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(
|
||||
**kwargs, flush_lock=self.flush_lock, batch_size=DD_MAX_BATCH_SIZE
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog: Got exception on init Datadog client {str(e)}"
|
||||
)
|
||||
raise e
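# Illustrative setup sketch (not part of the original file): the environment
# variables checked above and one way to enable this logger. Values are
# placeholders; "datadog" as a callback name reflects how LiteLLM documents
# this integration, but confirm against the docs for your version.
#
#   export DD_API_KEY="<your-datadog-api-key>"
#   export DD_SITE="us5.datadoghq.com"
#
#   import litellm
#   litellm.callbacks = ["datadog"]  # or `callbacks: ["datadog"]` in the proxy config.yaml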
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Async Log success events to Datadog
|
||||
|
||||
- Creates a Datadog payload
|
||||
- Adds the Payload to the in memory logs queue
|
||||
- Payload is flushed every 10 seconds or when batch size is greater than 100
|
||||
|
||||
|
||||
Raises:
|
||||
Raises a NON Blocking verbose_logger.exception if an error occurs
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"Datadog: Logging - Enters logging function for model %s", kwargs
|
||||
)
|
||||
await self._log_async_event(kwargs, response_obj, start_time, end_time)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"Datadog: Logging - Enters logging function for model %s", kwargs
|
||||
)
|
||||
await self._log_async_event(kwargs, response_obj, start_time, end_time)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Sends the in memory logs queue to datadog api
|
||||
|
||||
Logs sent to /api/v2/logs
|
||||
|
||||
DD Ref: https://docs.datadoghq.com/api/latest/logs/
|
||||
|
||||
Raises:
|
||||
Raises a NON Blocking verbose_logger.exception if an error occurs
|
||||
"""
|
||||
try:
|
||||
if not self.log_queue:
|
||||
verbose_logger.exception("Datadog: log_queue does not exist")
|
||||
return
|
||||
|
||||
verbose_logger.debug(
|
||||
"Datadog - about to flush %s events on %s",
|
||||
len(self.log_queue),
|
||||
self.intake_url,
|
||||
)
|
||||
|
||||
response = await self.async_send_compressed_data(self.log_queue)
|
||||
if response.status_code == 413:
|
||||
verbose_logger.exception(DD_ERRORS.DATADOG_413_ERROR.value)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if response.status_code != 202:
|
||||
raise Exception(
|
||||
f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"Datadog: Response from datadog API status_code: %s, text: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog Error sending batch API - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Sync Log success events to Datadog
|
||||
|
||||
- Creates a Datadog payload
|
||||
- instantly logs it on DD API
|
||||
"""
|
||||
try:
|
||||
if litellm.datadog_use_v1 is True:
|
||||
dd_payload = self._create_v0_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
else:
|
||||
dd_payload = self.create_datadog_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
response = self.sync_client.post(
|
||||
url=self.intake_url,
|
||||
json=dd_payload, # type: ignore
|
||||
headers={
|
||||
"DD-API-KEY": self.DD_API_KEY,
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
if response.status_code != 202:
|
||||
raise Exception(
|
||||
f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"Datadog: Response from datadog API status_code: %s, text: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
pass
|
||||
|
||||
async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
|
||||
dd_payload = self.create_datadog_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
self.log_queue.append(dd_payload)
|
||||
verbose_logger.debug(
|
||||
f"Datadog, event added to queue. Will flush in {self.flush_interval} seconds..."
|
||||
)
|
||||
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.async_send_batch()
|
||||
|
||||
def _create_datadog_logging_payload_helper(
|
||||
self,
|
||||
standard_logging_object: StandardLoggingPayload,
|
||||
status: DataDogStatus,
|
||||
) -> DatadogPayload:
|
||||
json_payload = json.dumps(standard_logging_object, default=str)
|
||||
verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
|
||||
dd_payload = DatadogPayload(
|
||||
ddsource=self._get_datadog_source(),
|
||||
ddtags=self._get_datadog_tags(
|
||||
standard_logging_object=standard_logging_object
|
||||
),
|
||||
hostname=self._get_datadog_hostname(),
|
||||
message=json_payload,
|
||||
service=self._get_datadog_service(),
|
||||
status=status,
|
||||
)
|
||||
return dd_payload
|
||||
|
||||
def create_datadog_logging_payload(
|
||||
self,
|
||||
kwargs: Union[dict, Any],
|
||||
response_obj: Any,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
) -> DatadogPayload:
|
||||
"""
|
||||
Helper function to create a datadog payload for logging
|
||||
|
||||
Args:
|
||||
kwargs (Union[dict, Any]): request kwargs
|
||||
response_obj (Any): llm api response
|
||||
start_time (datetime.datetime): start time of request
|
||||
end_time (datetime.datetime): end time of request
|
||||
|
||||
Returns:
|
||||
DatadogPayload: defined in types.py
|
||||
"""
|
||||
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if standard_logging_object is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
|
||||
status = DataDogStatus.INFO
|
||||
if standard_logging_object.get("status") == "failure":
|
||||
status = DataDogStatus.ERROR
|
||||
|
||||
# Build the initial payload
|
||||
self.truncate_standard_logging_payload_content(standard_logging_object)
|
||||
|
||||
dd_payload = self._create_datadog_logging_payload_helper(
|
||||
standard_logging_object=standard_logging_object,
|
||||
status=status,
|
||||
)
|
||||
return dd_payload
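# Illustrative sketch (not part of the original file): the shape of one
# DatadogPayload produced above. All values are placeholders.
#
#   {
#       "ddsource": "litellm",
#       "ddtags": "env:prod,service:litellm,version:1.0.0,...",
#       "hostname": "pod-1",
#       "message": "<json-serialized StandardLoggingPayload>",
#       "service": "litellm-server",
#       "status": DataDogStatus.INFO,
#   }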
|
||||
|
||||
async def async_send_compressed_data(self, data: List) -> Response:
|
||||
"""
|
||||
Async helper to send compressed data to datadog self.intake_url
|
||||
|
||||
Datadog recommends using gzip to compress data
|
||||
https://docs.datadoghq.com/api/latest/logs/
|
||||
|
||||
"Datadog recommends sending your logs compressed. Add the Content-Encoding: gzip header to the request when sending"
|
||||
"""
|
||||
|
||||
import gzip
|
||||
import json
|
||||
|
||||
compressed_data = gzip.compress(json.dumps(data, default=str).encode("utf-8"))
|
||||
response = await self.async_client.post(
|
||||
url=self.intake_url,
|
||||
data=compressed_data, # type: ignore
|
||||
headers={
|
||||
"DD-API-KEY": self.DD_API_KEY,
|
||||
"Content-Encoding": "gzip",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
error: Optional[str] = "",
|
||||
parent_otel_span: Optional[Any] = None,
|
||||
start_time: Optional[Union[datetimeObj, float]] = None,
|
||||
end_time: Optional[Union[float, datetimeObj]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog
|
||||
|
||||
- example - Redis is failing / erroring, will be logged on DataDog
|
||||
"""
|
||||
try:
|
||||
_payload_dict = payload.model_dump()
|
||||
_payload_dict.update(event_metadata or {})
|
||||
_dd_message_str = json.dumps(_payload_dict, default=str)
|
||||
_dd_payload = DatadogPayload(
|
||||
ddsource=self._get_datadog_source(),
|
||||
ddtags=self._get_datadog_tags(),
|
||||
hostname=self._get_datadog_hostname(),
|
||||
message=_dd_message_str,
|
||||
service=self._get_datadog_service(),
|
||||
status=DataDogStatus.WARN,
|
||||
)
|
||||
|
||||
self.log_queue.append(_dd_payload)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_service_success_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
error: Optional[str] = "",
|
||||
parent_otel_span: Optional[Any] = None,
|
||||
start_time: Optional[Union[datetimeObj, float]] = None,
|
||||
end_time: Optional[Union[float, datetimeObj]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Logs success from Redis, Postgres (Adjacent systems), as 'INFO' on DataDog
|
||||
|
||||
No user has asked for this so far, and it might be spammy on Datadog. If the need arises we can implement this
|
||||
"""
|
||||
try:
|
||||
# intentionally done. Don't want to log all service types to DD
|
||||
if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
|
||||
return
|
||||
|
||||
_payload_dict = payload.model_dump()
|
||||
_payload_dict.update(event_metadata or {})
|
||||
|
||||
_dd_message_str = json.dumps(_payload_dict, default=str)
|
||||
_dd_payload = DatadogPayload(
|
||||
ddsource=self._get_datadog_source(),
|
||||
ddtags=self._get_datadog_tags(),
|
||||
hostname=self._get_datadog_hostname(),
|
||||
message=_dd_message_str,
|
||||
service=self._get_datadog_service(),
|
||||
status=DataDogStatus.INFO,
|
||||
)
|
||||
|
||||
self.log_queue.append(_dd_payload)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
|
||||
)
|
||||
|
||||
def _create_v0_logging_payload(
|
||||
self,
|
||||
kwargs: Union[dict, Any],
|
||||
response_obj: Any,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
) -> DatadogPayload:
|
||||
"""
|
||||
Note: This is the legacy DataDog logging payload format
|
||||
|
||||
|
||||
(Not Recommended) If you want this to get logged set `litellm.datadog_use_v1 = True`
|
||||
"""
|
||||
import json
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
messages = kwargs.get("messages")
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
call_type = kwargs.get("call_type", "litellm.completion")
|
||||
cache_hit = kwargs.get("cache_hit", False)
|
||||
usage = response_obj["usage"]
|
||||
id = response_obj.get("id", str(uuid.uuid4()))
|
||||
usage = dict(usage)
|
||||
try:
|
||||
response_time = (end_time - start_time).total_seconds() * 1000
|
||||
except Exception:
|
||||
response_time = None
|
||||
|
||||
try:
|
||||
response_obj = dict(response_obj)
|
||||
except Exception:
|
||||
response_obj = response_obj
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
# we clean out all extra litellm metadata params before logging
|
||||
clean_metadata = {}
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
# clean litellm metadata before logging
|
||||
if key in [
|
||||
"endpoint",
|
||||
"caching_groups",
|
||||
"previous_models",
|
||||
]:
|
||||
continue
|
||||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
# Build the initial payload
|
||||
payload = {
|
||||
"id": id,
|
||||
"call_type": call_type,
|
||||
"cache_hit": cache_hit,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"response_time": response_time,
|
||||
"model": kwargs.get("model", ""),
|
||||
"user": kwargs.get("user", ""),
|
||||
"model_parameters": optional_params,
|
||||
"spend": kwargs.get("response_cost", 0),
|
||||
"messages": messages,
|
||||
"response": response_obj,
|
||||
"usage": usage,
|
||||
"metadata": clean_metadata,
|
||||
}
|
||||
|
||||
json_payload = json.dumps(payload, default=str)
|
||||
|
||||
verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
|
||||
|
||||
dd_payload = DatadogPayload(
|
||||
ddsource=self._get_datadog_source(),
|
||||
ddtags=self._get_datadog_tags(),
|
||||
hostname=self._get_datadog_hostname(),
|
||||
message=json_payload,
|
||||
service=self._get_datadog_service(),
|
||||
status=DataDogStatus.INFO,
|
||||
)
|
||||
return dd_payload
|
||||
|
||||
@staticmethod
|
||||
def _get_datadog_tags(
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the datadog tags for the request
|
||||
|
||||
DD tags need to be as follows:
|
||||
- tags: ["user_handle:dog@gmail.com", "app_version:1.0.0"]
|
||||
"""
|
||||
base_tags = {
|
||||
"env": os.getenv("DD_ENV", "unknown"),
|
||||
"service": os.getenv("DD_SERVICE", "litellm"),
|
||||
"version": os.getenv("DD_VERSION", "unknown"),
|
||||
"HOSTNAME": DataDogLogger._get_datadog_hostname(),
|
||||
"POD_NAME": os.getenv("POD_NAME", "unknown"),
|
||||
}
|
||||
|
||||
tags = [f"{k}:{v}" for k, v in base_tags.items()]
|
||||
|
||||
if standard_logging_object:
|
||||
_request_tags: List[str] = (
|
||||
standard_logging_object.get("request_tags", []) or []
|
||||
)
|
||||
request_tags = [f"request_tag:{tag}" for tag in _request_tags]
|
||||
tags.extend(request_tags)
|
||||
|
||||
return ",".join(tags)
|
||||
|
||||
@staticmethod
|
||||
def _get_datadog_source():
|
||||
return os.getenv("DD_SOURCE", "litellm")
|
||||
|
||||
@staticmethod
|
||||
def _get_datadog_service():
|
||||
return os.getenv("DD_SERVICE", "litellm-server")
|
||||
|
||||
@staticmethod
|
||||
def _get_datadog_hostname():
|
||||
return os.getenv("HOSTNAME", "")
|
||||
|
||||
@staticmethod
|
||||
def _get_datadog_env():
|
||||
return os.getenv("DD_ENV", "unknown")
|
||||
|
||||
@staticmethod
|
||||
def _get_datadog_pod_name():
|
||||
return os.getenv("POD_NAME", "unknown")
|
||||
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
|
||||
"""
|
||||
Check if the service is healthy
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
create_dummy_standard_logging_payload,
|
||||
)
|
||||
|
||||
standard_logging_object = create_dummy_standard_logging_payload()
|
||||
dd_payload = self._create_datadog_logging_payload_helper(
|
||||
standard_logging_object=standard_logging_object,
|
||||
status=DataDogStatus.INFO,
|
||||
)
|
||||
log_queue = [dd_payload]
|
||||
response = await self.async_send_compressed_data(log_queue)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
return IntegrationHealthCheckStatus(
|
||||
status="healthy",
|
||||
error_message=None,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
return IntegrationHealthCheckStatus(
|
||||
status="unhealthy",
|
||||
error_message=e.response.text,
|
||||
)
|
||||
except Exception as e:
|
||||
return IntegrationHealthCheckStatus(
|
||||
status="unhealthy",
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
async def get_request_response_payload(
|
||||
self,
|
||||
request_id: str,
|
||||
start_time_utc: Optional[datetimeObj],
|
||||
end_time_utc: Optional[datetimeObj],
|
||||
) -> Optional[dict]:
|
||||
pass
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Implements logging integration with Datadog's LLM Observability Service
|
||||
|
||||
|
||||
API Reference: https://docs.datadoghq.com/llm_observability/setup/api/?tab=example#api-standards
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.datadog.datadog import DataDogLogger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
handle_any_messages_to_chat_completion_str_messages_conversion,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.datadog_llm_obs import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class DataDogLLMObsLogger(DataDogLogger, CustomBatchLogger):
|
||||
def __init__(self, **kwargs):
|
||||
try:
|
||||
verbose_logger.debug("DataDogLLMObs: Initializing logger")
|
||||
if os.getenv("DD_API_KEY", None) is None:
|
||||
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>'")
|
||||
if os.getenv("DD_SITE", None) is None:
|
||||
raise Exception(
|
||||
"DD_SITE is not set, set 'DD_SITE=<>', example sit = `us5.datadoghq.com`"
|
||||
)
|
||||
|
||||
self.async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
self.DD_API_KEY = os.getenv("DD_API_KEY")
|
||||
self.DD_SITE = os.getenv("DD_SITE")
|
||||
self.intake_url = (
|
||||
f"https://api.{self.DD_SITE}/api/intake/llm-obs/v1/trace/spans"
|
||||
)
|
||||
|
||||
# testing base url
|
||||
dd_base_url = os.getenv("DD_BASE_URL")
|
||||
if dd_base_url:
|
||||
self.intake_url = f"{dd_base_url}/api/intake/llm-obs/v1/trace/spans"
|
||||
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
self.log_queue: List[LLMObsPayload] = []
|
||||
CustomBatchLogger.__init__(self, **kwargs, flush_lock=self.flush_lock)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"DataDogLLMObs: Error initializing - {str(e)}")
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"DataDogLLMObs: Logging success event for model {kwargs.get('model', 'unknown')}"
|
||||
)
|
||||
payload = self.create_llm_obs_payload(
|
||||
kwargs, response_obj, start_time, end_time
|
||||
)
|
||||
verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}")
|
||||
self.log_queue.append(payload)
|
||||
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.async_send_batch()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"DataDogLLMObs: Error logging success event - {str(e)}"
|
||||
)
|
||||
|
||||
async def async_send_batch(self):
|
||||
try:
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
verbose_logger.debug(
|
||||
f"DataDogLLMObs: Flushing {len(self.log_queue)} events"
|
||||
)
|
||||
|
||||
# Prepare the payload
|
||||
payload = {
|
||||
"data": DDIntakePayload(
|
||||
type="span",
|
||||
attributes=DDSpanAttributes(
|
||||
ml_app=self._get_datadog_service(),
|
||||
tags=[self._get_datadog_tags()],
|
||||
spans=self.log_queue,
|
||||
),
|
||||
),
|
||||
}
|
||||
verbose_logger.debug("payload %s", json.dumps(payload, indent=4))
|
||||
response = await self.async_client.post(
|
||||
url=self.intake_url,
|
||||
json=payload,
|
||||
headers={
|
||||
"DD-API-KEY": self.DD_API_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 202:
|
||||
raise Exception(
|
||||
f"DataDogLLMObs: Unexpected response - status_code: {response.status_code}, text: {response.text}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"DataDogLLMObs: Successfully sent batch - status_code: {response.status_code}"
|
||||
)
|
||||
self.log_queue.clear()
|
||||
except httpx.HTTPStatusError as e:
|
||||
verbose_logger.exception(
|
||||
f"DataDogLLMObs: Error sending batch - {e.response.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"DataDogLLMObs: Error sending batch - {str(e)}")
|
||||
|
||||
def create_llm_obs_payload(
|
||||
self, kwargs: Dict, response_obj: Any, start_time: datetime, end_time: datetime
|
||||
) -> LLMObsPayload:
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise Exception("DataDogLLMObs: standard_logging_object is not set")
|
||||
|
||||
messages = standard_logging_payload["messages"]
|
||||
messages = self._ensure_string_content(messages=messages)
|
||||
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
|
||||
input_meta = InputMeta(
|
||||
messages=handle_any_messages_to_chat_completion_str_messages_conversion(
|
||||
messages
|
||||
)
|
||||
)
|
||||
output_meta = OutputMeta(messages=self._get_response_messages(response_obj))
|
||||
|
||||
meta = Meta(
|
||||
kind="llm",
|
||||
input=input_meta,
|
||||
output=output_meta,
|
||||
metadata=self._get_dd_llm_obs_payload_metadata(standard_logging_payload),
|
||||
)
|
||||
|
||||
# Calculate metrics (you may need to adjust these based on available data)
|
||||
metrics = LLMMetrics(
|
||||
input_tokens=float(standard_logging_payload.get("prompt_tokens", 0)),
|
||||
output_tokens=float(standard_logging_payload.get("completion_tokens", 0)),
|
||||
total_tokens=float(standard_logging_payload.get("total_tokens", 0)),
|
||||
)
|
||||
|
||||
return LLMObsPayload(
|
||||
parent_id=metadata.get("parent_id", "undefined"),
|
||||
trace_id=metadata.get("trace_id", str(uuid.uuid4())),
|
||||
span_id=metadata.get("span_id", str(uuid.uuid4())),
|
||||
name=metadata.get("name", "litellm_llm_call"),
|
||||
meta=meta,
|
||||
start_ns=int(start_time.timestamp() * 1e9),
|
||||
duration=int((end_time - start_time).total_seconds() * 1e9),
|
||||
metrics=metrics,
|
||||
tags=[
|
||||
self._get_datadog_tags(standard_logging_object=standard_logging_payload)
|
||||
],
|
||||
)
|
||||
|
||||
def _get_response_messages(self, response_obj: Any) -> List[Any]:
|
||||
"""
|
||||
Get the messages from the response object
|
||||
|
||||
for now this handles logging /chat/completions responses
|
||||
"""
|
||||
if isinstance(response_obj, litellm.ModelResponse):
|
||||
return [response_obj["choices"][0]["message"].json()]
|
||||
return []
|
||||
|
||||
def _ensure_string_content(
|
||||
self, messages: Optional[Union[str, List[Any], Dict[Any, Any]]]
|
||||
) -> List[Any]:
|
||||
if messages is None:
|
||||
return []
|
||||
if isinstance(messages, str):
|
||||
return [messages]
|
||||
elif isinstance(messages, list):
|
||||
return [message for message in messages]
|
||||
elif isinstance(messages, dict):
|
||||
return [str(messages.get("content", ""))]
|
||||
return []
|
||||
|
||||
def _get_dd_llm_obs_payload_metadata(
|
||||
self, standard_logging_payload: StandardLoggingPayload
|
||||
) -> Dict:
|
||||
_metadata = {
|
||||
"model_name": standard_logging_payload.get("model", "unknown"),
|
||||
"model_provider": standard_logging_payload.get(
|
||||
"custom_llm_provider", "unknown"
|
||||
),
|
||||
}
|
||||
_standard_logging_metadata: dict = (
|
||||
dict(standard_logging_payload.get("metadata", {})) or {}
|
||||
)
|
||||
_metadata.update(_standard_logging_metadata)
|
||||
return _metadata
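# Illustrative sketch (not part of the original file): the span timing fields
# built in `create_llm_obs_payload` above, computed with plain datetimes.
# Values are placeholders.
#
#   from datetime import datetime
#   start_time = datetime(2024, 1, 1, 12, 0, 0)
#   end_time = datetime(2024, 1, 1, 12, 0, 2)
#   start_ns = int(start_time.timestamp() * 1e9)                     # nanoseconds since the epoch
#   duration = int((end_time - start_time).total_seconds() * 1e9)    # 2_000_000_000 ns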
|
||||
@@ -0,0 +1,89 @@
|
||||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class DyanmoDBLogger:
|
||||
# Class variables or attributes
|
||||
|
||||
def __init__(self):
|
||||
# Instance variables
|
||||
import boto3
|
||||
|
||||
self.dynamodb: Any = boto3.resource(
|
||||
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
|
||||
)
|
||||
if litellm.dynamodb_table_name is None:
|
||||
raise ValueError(
|
||||
"LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
|
||||
)
|
||||
self.table_name = litellm.dynamodb_table_name
|
||||
|
||||
async def _async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose
|
||||
):
|
||||
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
try:
|
||||
print_verbose(
|
||||
f"DynamoDB Logging - Enters logging function for model {kwargs}"
|
||||
)
|
||||
|
||||
# construct payload to send to DynamoDB
|
||||
# follows the same params as langfuse.py
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
messages = kwargs.get("messages")
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
call_type = kwargs.get("call_type", "litellm.completion")
|
||||
usage = response_obj["usage"]
|
||||
id = response_obj.get("id", str(uuid.uuid4()))
|
||||
|
||||
# Build the initial payload
|
||||
payload = {
|
||||
"id": id,
|
||||
"call_type": call_type,
|
||||
"startTime": start_time,
|
||||
"endTime": end_time,
|
||||
"model": kwargs.get("model", ""),
|
||||
"user": kwargs.get("user", ""),
|
||||
"modelParameters": optional_params,
|
||||
"messages": messages,
|
||||
"response": response_obj,
|
||||
"usage": usage,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
# Ensure everything in the payload is converted to str
|
||||
for key, value in payload.items():
|
||||
try:
|
||||
payload[key] = str(value)
|
||||
except Exception:
|
||||
# non-blocking if it can't be cast to a str
|
||||
pass
|
||||
|
||||
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
|
||||
|
||||
# put data in DynamoDB
|
||||
table = self.dynamodb.Table(self.table_name)
|
||||
# Assuming log_data is a dictionary with log information
|
||||
response = table.put_item(Item=payload)
|
||||
|
||||
print_verbose(f"Response from DynamoDB:{str(response)}")
|
||||
|
||||
print_verbose(
|
||||
f"DynamoDB Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
return response
|
||||
except Exception:
|
||||
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
|
||||
pass
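# Illustrative usage sketch (not part of the original file): enabling the
# DynamoDB logger from the SDK. The table name "litellm-logs" is a placeholder;
# the table must already exist and AWS_REGION_NAME must be set in the env.
#
#   import litellm
#   litellm.dynamodb_table_name = "litellm-logs"
#   litellm.success_callback = ["dynamodb"]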
|
||||
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Functions for sending Email Alerts
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.proxy._types import WebhookEvent
|
||||
|
||||
# we use this for the email header; if you change it, please send a test email and verify it looks good
|
||||
LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
LITELLM_SUPPORT_CONTACT = "support@berri.ai"
|
||||
|
||||
|
||||
async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
|
||||
verbose_logger.debug(
|
||||
"Email Alerting: Getting all team members for team_id=%s", team_id
|
||||
)
|
||||
if team_id is None:
|
||||
return []
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise Exception("Not connected to DB!")
|
||||
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={
|
||||
"team_id": team_id,
|
||||
}
|
||||
)
|
||||
|
||||
if team_row is None:
|
||||
return []
|
||||
|
||||
_team_members = team_row.members_with_roles
|
||||
verbose_logger.debug(
|
||||
"Email Alerting: Got team members for team_id=%s Team Members: %s",
|
||||
team_id,
|
||||
_team_members,
|
||||
)
|
||||
_team_member_user_ids: List[str] = []
|
||||
for member in _team_members:
|
||||
if member and isinstance(member, dict):
|
||||
_user_id = member.get("user_id")
|
||||
if _user_id and isinstance(_user_id, str):
|
||||
_team_member_user_ids.append(_user_id)
|
||||
|
||||
sql_query = """
|
||||
SELECT user_email
|
||||
FROM "LiteLLM_UserTable"
|
||||
WHERE user_id = ANY($1::TEXT[]);
|
||||
"""
|
||||
|
||||
_result = await prisma_client.db.query_raw(sql_query, _team_member_user_ids)
|
||||
|
||||
verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", _result)
|
||||
|
||||
if _result is None:
|
||||
return []
|
||||
|
||||
emails = []
|
||||
for user in _result:
|
||||
if user and isinstance(user, dict) and user.get("user_email", None) is not None:
|
||||
emails.append(user.get("user_email"))
|
||||
return emails
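# Illustrative usage sketch (not part of the original file): this helper is
# awaited from an async context on the proxy, where `prisma_client` is already
# connected. The team_id value and returned emails are placeholders.
#
#   emails = await get_all_team_member_emails(team_id="team-1234")
#   # -> e.g. ["alice@example.com", "bob@example.com"]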
|
||||
|
||||
|
||||
async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool:
|
||||
"""
|
||||
Send an Email Alert to All Team Members when the Team Budget is crossed
|
||||
Returns -> True if sent, False if not.
|
||||
"""
|
||||
from litellm.proxy.utils import send_email
|
||||
|
||||
_team_id = webhook_event.team_id
|
||||
team_alias = webhook_event.team_alias
|
||||
verbose_logger.debug(
|
||||
"Email Alerting: Sending Team Budget Alert for team=%s", team_alias
|
||||
)
|
||||
|
||||
email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None))
|
||||
email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
|
||||
|
||||
# await self._check_if_using_premium_email_feature(
|
||||
# premium_user, email_logo_url, email_support_contact
|
||||
# )
|
||||
|
||||
if email_logo_url is None:
|
||||
email_logo_url = LITELLM_LOGO_URL
|
||||
if email_support_contact is None:
|
||||
email_support_contact = LITELLM_SUPPORT_CONTACT
|
||||
recipient_emails = await get_all_team_member_emails(_team_id)
|
||||
recipient_emails_str: str = ",".join(recipient_emails)
|
||||
verbose_logger.debug(
|
||||
"Email Alerting: Sending team budget alert to %s", recipient_emails_str
|
||||
)
|
||||
|
||||
event_name = webhook_event.event_message
|
||||
max_budget = webhook_event.max_budget
|
||||
email_html_content = "Alert from LiteLLM Server"
|
||||
|
||||
if not recipient_emails_str:
|
||||
verbose_proxy_logger.warning(
|
||||
"Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s",
|
||||
recipient_emails_str,
|
||||
)
|
||||
|
||||
email_html_content = f"""
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/>
|
||||
|
||||
Budget Crossed for Team <b> {team_alias} </b> <br/> <br/>
|
||||
|
||||
Your Team's LLM API usage has crossed its <b> budget of ${max_budget} </b>, current spend is <b>${webhook_event.spend}</b><br /> <br />
|
||||
|
||||
API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br />
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
|
||||
email_event = {
|
||||
"to": recipient_emails_str,
|
||||
"subject": f"LiteLLM {event_name} for Team {team_alias}",
|
||||
"html": email_html_content,
|
||||
}
|
||||
|
||||
await send_email(
|
||||
receiver_email=email_event["to"],
|
||||
subject=email_event["subject"],
|
||||
html=email_event["html"],
|
||||
)
|
||||
|
||||
return True
|
||||
Binary file not shown.
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Email Templates used by the LiteLLM Email Service in slack_alerting.py
|
||||
"""
|
||||
|
||||
KEY_CREATED_EMAIL_TEMPLATE = """
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
||||
|
||||
<p> Hi {recipient_email}, <br/>
|
||||
|
||||
I'm happy to provide you with an OpenAI Proxy API Key, loaded with ${key_budget} per month. <br /> <br />
|
||||
|
||||
<b>
|
||||
Key: <pre>{key_token}</pre> <br>
|
||||
</b>
|
||||
|
||||
<h2>Usage Example</h2>
|
||||
|
||||
Detailed Documentation on <a href="https://docs.litellm.ai/docs/proxy/user_keys">Usage with OpenAI Python SDK, Langchain, LlamaIndex, Curl</a>
|
||||
|
||||
<pre>
|
||||
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="{key_token}",
|
||||
base_url={{base_url}}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo", # model to send to the proxy
|
||||
messages = [
|
||||
{{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}}
|
||||
]
|
||||
)
|
||||
|
||||
</pre>
|
||||
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
|
||||
|
||||
USER_INVITED_EMAIL_TEMPLATE = """
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
||||
|
||||
<p> Hi {recipient_email}, <br/>
|
||||
|
||||
You were invited to use OpenAI Proxy API for team {team_name} <br /> <br />
|
||||
|
||||
<a href="{base_url}" style="display: inline-block; padding: 10px 20px; background-color: #87ceeb; color: #fff; text-decoration: none; border-radius: 20px;">Get Started here</a> <br /> <br />
|
||||
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
# from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
|
||||
class LLMResponse(BaseModel):
|
||||
latency_ms: int
|
||||
status_code: int
|
||||
input_text: str
|
||||
output_text: str
|
||||
node_type: str
|
||||
model: str
|
||||
num_input_tokens: int
|
||||
num_output_tokens: int
|
||||
output_logprobs: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional. When available, logprobs are used to compute Uncertainty.",
|
||||
)
|
||||
created_at: str = Field(
|
||||
..., description='timestamp constructed in "%Y-%m-%dT%H:%M:%S" format'
|
||||
)
|
||||
tags: Optional[List[str]] = None
|
||||
user_metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GalileoObserve(CustomLogger):
|
||||
def __init__(self) -> None:
|
||||
self.in_memory_records: List[dict] = []
|
||||
self.batch_size = 1
|
||||
self.base_url = os.getenv("GALILEO_BASE_URL", None)
|
||||
self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
|
||||
self.headers: Optional[Dict[str, str]] = None
|
||||
self.async_httpx_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
pass
|
||||
|
||||
def set_galileo_headers(self):
|
||||
# following https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#logging-your-records
|
||||
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
galileo_login_response = litellm.module_level_client.post(
|
||||
url=f"{self.base_url}/login",
|
||||
headers=headers,
|
||||
data={
|
||||
"username": os.getenv("GALILEO_USERNAME"),
|
||||
"password": os.getenv("GALILEO_PASSWORD"),
|
||||
},
|
||||
)
|
||||
|
||||
access_token = galileo_login_response.json()["access_token"]
|
||||
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
def get_output_str_from_response(self, response_obj, kwargs):
|
||||
output = None
|
||||
if response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
output = response_obj.choices[0].text
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
output = response_obj["data"]
|
||||
|
||||
return output
|
||||
|
||||
async def async_log_success_event(
|
||||
self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
|
||||
):
|
||||
verbose_logger.debug("On Async Success")
|
||||
|
||||
_latency_ms = int((end_time - start_time).total_seconds() * 1000)
|
||||
_call_type = kwargs.get("call_type", "litellm")
|
||||
input_text = litellm.utils.get_formatted_prompt(
|
||||
data=kwargs, call_type=_call_type
|
||||
)
|
||||
|
||||
_usage = response_obj.get("usage", {}) or {}
|
||||
num_input_tokens = _usage.get("prompt_tokens", 0)
|
||||
num_output_tokens = _usage.get("completion_tokens", 0)
|
||||
|
||||
output_text = self.get_output_str_from_response(
|
||||
response_obj=response_obj, kwargs=kwargs
|
||||
)
|
||||
|
||||
if output_text is not None:
|
||||
request_record = LLMResponse(
|
||||
latency_ms=_latency_ms,
|
||||
status_code=200,
|
||||
input_text=input_text,
|
||||
output_text=output_text,
|
||||
node_type=_call_type,
|
||||
model=kwargs.get("model", "-"),
|
||||
num_input_tokens=num_input_tokens,
|
||||
num_output_tokens=num_output_tokens,
|
||||
created_at=start_time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S"
|
||||
), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
|
||||
)
|
||||
|
||||
# dump to dict
|
||||
request_dict = request_record.model_dump()
|
||||
self.in_memory_records.append(request_dict)
|
||||
|
||||
if len(self.in_memory_records) >= self.batch_size:
|
||||
await self.flush_in_memory_records()
|
||||
|
||||
async def flush_in_memory_records(self):
|
||||
verbose_logger.debug("flushing in memory records")
|
||||
response = await self.async_httpx_handler.post(
|
||||
url=f"{self.base_url}/projects/{self.project_id}/observe/ingest",
|
||||
headers=self.headers,
|
||||
json={"records": self.in_memory_records},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
verbose_logger.debug(
|
||||
"Galileo Logger:successfully flushed in memory records"
|
||||
)
|
||||
self.in_memory_records = []
|
||||
else:
|
||||
verbose_logger.debug("Galileo Logger: failed to flush in memory records")
|
||||
verbose_logger.debug(
|
||||
"Galileo Logger error=%s, status code=%s",
|
||||
response.text,
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
verbose_logger.debug("On Async Failure")
|
||||
@@ -0,0 +1,12 @@
|
||||
# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway
|
||||
|
||||
This folder contains the GCS Bucket Logging integration for LiteLLM Gateway.
|
||||
|
||||
## Folder Structure
|
||||
|
||||
- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
|
||||
- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets
|
||||
|
||||
## Further Reading
|
||||
- [Doc setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/bucket)
|
||||
- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,234 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
|
||||
|
||||
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
super().__init__(bucket_name=bucket_name)
|
||||
|
||||
# Init Batch logging settings
|
||||
self.log_queue: List[GCSLogQueueItem] = []
|
||||
self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
|
||||
self.flush_interval = int(
|
||||
os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
|
||||
)
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(
|
||||
flush_lock=self.flush_lock,
|
||||
batch_size=self.batch_size,
|
||||
flush_interval=self.flush_interval,
|
||||
)
|
||||
AdditionalLoggingUtils.__init__(self)
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
||||
#### ASYNC ####
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
# Add to logging queue - this will be flushed periodically
|
||||
self.log_queue.append(
|
||||
GCSLogQueueItem(
|
||||
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
|
||||
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
# Add to logging queue - this will be flushed periodically
|
||||
self.log_queue.append(
|
||||
GCSLogQueueItem(
|
||||
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Process queued logs in batch - sends logs to GCS Bucket
|
||||
|
||||
|
||||
GCS Bucket does not have a Batch endpoint to batch upload logs
|
||||
|
||||
Instead, we
|
||||
- collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
|
||||
- during async_send_batch, we make 1 POST request per log to GCS Bucket
|
||||
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
for log_item in self.log_queue:
|
||||
logging_payload = log_item["payload"]
|
||||
kwargs = log_item["kwargs"]
|
||||
response_obj = log_item.get("response_obj", None) or {}
|
||||
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs
|
||||
)
|
||||
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
object_name = self._get_object_name(kwargs, logging_payload, response_obj)
|
||||
|
||||
try:
|
||||
await self._log_json_data_on_gcs(
|
||||
headers=headers,
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
logging_payload=logging_payload,
|
||||
)
|
||||
except Exception as e:
|
||||
# don't let one log item fail the entire batch
|
||||
verbose_logger.exception(
|
||||
f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
|
||||
)
|
||||
pass
|
||||
|
||||
# Clear the queue after processing
|
||||
self.log_queue.clear()
|
||||
|
||||
def _get_object_name(
|
||||
self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
|
||||
) -> str:
|
||||
"""
|
||||
Get the object name to use for the current payload
|
||||
"""
|
||||
current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
|
||||
if logging_payload.get("error_str", None) is not None:
|
||||
object_name = self._generate_failure_object_name(
|
||||
request_date_str=current_date,
|
||||
)
|
||||
else:
|
||||
object_name = self._generate_success_object_name(
|
||||
request_date_str=current_date,
|
||||
response_id=response_obj.get("id", ""),
|
||||
)
|
||||
|
||||
# used for testing
|
||||
_litellm_params = kwargs.get("litellm_params", None) or {}
|
||||
_metadata = _litellm_params.get("metadata", None) or {}
|
||||
if "gcs_log_id" in _metadata:
|
||||
object_name = _metadata["gcs_log_id"]
|
||||
|
||||
return object_name
|
||||
|
||||
async def get_request_response_payload(
|
||||
self,
|
||||
request_id: str,
|
||||
start_time_utc: Optional[datetime],
|
||||
end_time_utc: Optional[datetime],
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the request and response payload for a given `request_id`
|
||||
Tries current day, next day, and previous day until it finds the payload
|
||||
"""
|
||||
if start_time_utc is None:
|
||||
raise ValueError(
|
||||
"start_time_utc is required for getting a payload from GCS Bucket"
|
||||
)
|
||||
|
||||
# Try current day, next day, and previous day
|
||||
dates_to_try = [
|
||||
start_time_utc,
|
||||
start_time_utc + timedelta(days=1),
|
||||
start_time_utc - timedelta(days=1),
|
||||
]
|
||||
date_str = None
|
||||
for date in dates_to_try:
|
||||
try:
|
||||
date_str = self._get_object_date_from_datetime(datetime_obj=date)
|
||||
object_name = self._generate_success_object_name(
|
||||
request_date_str=date_str,
|
||||
response_id=request_id,
|
||||
)
|
||||
encoded_object_name = quote(object_name, safe="")
|
||||
response = await self.download_gcs_object(encoded_object_name)
|
||||
|
||||
if response is not None:
|
||||
loaded_response = json.loads(response)
|
||||
return loaded_response
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Failed to fetch payload for date {date_str}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _generate_success_object_name(
|
||||
self,
|
||||
request_date_str: str,
|
||||
response_id: str,
|
||||
) -> str:
|
||||
return f"{request_date_str}/{response_id}"
|
||||
|
||||
def _generate_failure_object_name(
|
||||
self,
|
||||
request_date_str: str,
|
||||
) -> str:
|
||||
return f"{request_date_str}/failure-{uuid.uuid4().hex}"
|
||||
|
||||
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
|
||||
return datetime_obj.strftime("%Y-%m-%d")
|
||||
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
|
||||
raise NotImplementedError("GCS Bucket does not support health check")
|
||||
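A minimal standalone sketch of the day-window object lookup that get_request_response_payload performs above; this is not part of the diff, and the request id and timestamp here are illustrative:

from datetime import datetime, timedelta, timezone
from urllib.parse import quote

def candidate_object_names(request_id: str, start_time_utc: datetime) -> list:
    """Mirror the lookup order used above: current day, then next day, then previous day."""
    names = []
    for day in (start_time_utc, start_time_utc + timedelta(days=1), start_time_utc - timedelta(days=1)):
        date_str = day.strftime("%Y-%m-%d")
        # success objects are stored as "<YYYY-MM-DD>/<response_id>", URL-encoded for the GCS object API
        names.append(quote(f"{date_str}/{request_id}", safe=""))
    return names

print(candidate_object_names("chatcmpl-123", datetime(2024, 6, 1, 23, 55, tzinfo=timezone.utc)))
# ['2024-06-01%2Fchatcmpl-123', '2024-06-02%2Fchatcmpl-123', '2024-05-31%2Fchatcmpl-123']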
@@ -0,0 +1,326 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
IAM_AUTH_KEY = "IAM_AUTH"
|
||||
|
||||
|
||||
class GCSBucketBase(CustomBatchLogger):
|
||||
def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
_path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
|
||||
_bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
|
||||
self.path_service_account_json: Optional[str] = _path_service_account
|
||||
self.BUCKET_NAME: Optional[str] = _bucket_name
|
||||
self.vertex_instances: Dict[str, VertexBase] = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def construct_request_headers(
|
||||
self,
|
||||
service_account_json: Optional[str],
|
||||
vertex_instance: Optional[VertexBase] = None,
|
||||
) -> Dict[str, str]:
|
||||
from litellm import vertex_chat_completion
|
||||
|
||||
if vertex_instance is None:
|
||||
vertex_instance = vertex_chat_completion
|
||||
|
||||
_auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
|
||||
credentials=service_account_json,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
auth_header, _ = vertex_instance._get_token_and_url(
|
||||
model="gcs-bucket",
|
||||
auth_header=_auth_header,
|
||||
vertex_credentials=service_account_json,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=None,
|
||||
gemini_api_key=None,
|
||||
stream=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
api_base=None,
|
||||
)
|
||||
verbose_logger.debug("constructed auth_header %s", auth_header)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}", # auth_header
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
return headers
|
||||
|
||||
def sync_construct_request_headers(self) -> Dict[str, str]:
|
||||
from litellm import vertex_chat_completion
|
||||
|
||||
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
|
||||
credentials=self.path_service_account_json,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
||||
model="gcs-bucket",
|
||||
auth_header=_auth_header,
|
||||
vertex_credentials=self.path_service_account_json,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=None,
|
||||
gemini_api_key=None,
|
||||
stream=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
api_base=None,
|
||||
)
|
||||
verbose_logger.debug("constructed auth_header %s", auth_header)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}", # auth_header
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
return headers
|
||||
|
||||
def _handle_folders_in_bucket_name(
|
||||
self,
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Handles when the user passes a bucket name with a folder postfix
|
||||
|
||||
|
||||
Example:
|
||||
- Bucket name: "my-bucket/my-folder/dev"
|
||||
- Object name: "my-object"
|
||||
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
|
||||
|
||||
"""
|
||||
if "/" in bucket_name:
|
||||
bucket_name, prefix = bucket_name.split("/", 1)
|
||||
object_name = f"{prefix}/{object_name}"
|
||||
return bucket_name, object_name
|
||||
return bucket_name, object_name
|
||||
|
||||
async def get_gcs_logging_config(
|
||||
self, kwargs: Optional[Dict[str, Any]] = {}
|
||||
) -> GCSLoggingConfig:
|
||||
"""
|
||||
This function is used to get the GCS logging config for the GCS Bucket Logger.
|
||||
It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
|
||||
If no dynamic parameters are provided, it uses the default values.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
standard_callback_dynamic_params: Optional[
|
||||
StandardCallbackDynamicParams
|
||||
] = kwargs.get("standard_callback_dynamic_params", None)
|
||||
|
||||
bucket_name: str
|
||||
path_service_account: Optional[str]
|
||||
if standard_callback_dynamic_params is not None:
|
||||
verbose_logger.debug("Using dynamic GCS logging")
|
||||
verbose_logger.debug(
|
||||
"standard_callback_dynamic_params: %s", standard_callback_dynamic_params
|
||||
)
|
||||
|
||||
_bucket_name: Optional[str] = (
|
||||
standard_callback_dynamic_params.get("gcs_bucket_name", None)
|
||||
or self.BUCKET_NAME
|
||||
)
|
||||
_path_service_account: Optional[str] = (
|
||||
standard_callback_dynamic_params.get("gcs_path_service_account", None)
|
||||
or self.path_service_account_json
|
||||
)
|
||||
|
||||
if _bucket_name is None:
|
||||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
bucket_name = _bucket_name
|
||||
path_service_account = _path_service_account
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
credentials=path_service_account
|
||||
)
|
||||
else:
|
||||
# If no dynamic parameters, use the default instance
|
||||
if self.BUCKET_NAME is None:
|
||||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
bucket_name = self.BUCKET_NAME
|
||||
path_service_account = self.path_service_account_json
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
credentials=path_service_account
|
||||
)
|
||||
|
||||
return GCSLoggingConfig(
|
||||
bucket_name=bucket_name,
|
||||
vertex_instance=vertex_instance,
|
||||
path_service_account=path_service_account,
|
||||
)
|
||||
|
||||
async def get_or_create_vertex_instance(
|
||||
self, credentials: Optional[str]
|
||||
) -> VertexBase:
|
||||
"""
|
||||
This function is used to get the Vertex instance for the GCS Bucket Logger.
|
||||
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
|
||||
"""
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
|
||||
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
|
||||
if _in_memory_key not in self.vertex_instances:
|
||||
vertex_instance = VertexBase()
|
||||
await vertex_instance._ensure_access_token_async(
|
||||
credentials=credentials,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
self.vertex_instances[_in_memory_key] = vertex_instance
|
||||
return self.vertex_instances[_in_memory_key]
|
||||
|
||||
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
|
||||
"""
|
||||
Returns key to use for caching the Vertex instance in-memory.
|
||||
|
||||
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
|
||||
|
||||
- If a credentials string is provided, it is used as the key.
|
||||
- If no credentials string is provided, "IAM_AUTH" is used as the key.
|
||||
"""
|
||||
return credentials or IAM_AUTH_KEY
|
||||
|
||||
async def download_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
Download an object from GCS.
|
||||
|
||||
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs=kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
|
||||
|
||||
# Send the GET request to download the object
|
||||
response = await self.async_httpx_client.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error(
|
||||
"GCS object download error: %s", str(response.text)
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_logger.debug(
|
||||
"GCS object download response status code: %s", response.status_code
|
||||
)
|
||||
|
||||
# Return the content of the downloaded object
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("GCS object download error: %s", str(e))
|
||||
return None
|
||||
|
||||
async def delete_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
Delete an object from GCS.
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs=kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
|
||||
|
||||
# Send the DELETE request to delete the object
|
||||
response = await self.async_httpx_client.delete(url=url, headers=headers)
|
||||
|
||||
            if response.status_code not in (200, 204):
|
||||
verbose_logger.error(
|
||||
"GCS object delete error: %s, status code: %s",
|
||||
str(response.text),
|
||||
response.status_code,
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_logger.debug(
|
||||
"GCS object delete response status code: %s, response: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
|
||||
            # Return the response body from the delete request
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("GCS object download error: %s", str(e))
|
||||
return None
|
||||
|
||||
async def _log_json_data_on_gcs(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
logging_payload: Union[StandardLoggingPayload, str],
|
||||
):
|
||||
"""
|
||||
Helper function to make POST request to GCS Bucket in the specified bucket.
|
||||
"""
|
||||
if isinstance(logging_payload, str):
|
||||
json_logged_payload = logging_payload
|
||||
else:
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
response = await self.async_httpx_client.post(
|
||||
headers=headers,
|
||||
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
|
||||
data=json_logged_payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
|
||||
|
||||
verbose_logger.debug("GCS Bucket response %s", response)
|
||||
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
|
||||
verbose_logger.debug("GCS Bucket response.text %s", response.text)
|
||||
|
||||
return response.json()
|
||||
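For reference, a small self-contained sketch of the bucket/folder split that _handle_folders_in_bucket_name implements above; not part of the diff:

from typing import Tuple

def split_bucket_and_prefix(bucket_name: str, object_name: str) -> Tuple[str, str]:
    """If the configured bucket contains a folder path, move that path onto the object name."""
    if "/" in bucket_name:
        bucket, prefix = bucket_name.split("/", 1)
        return bucket, f"{prefix}/{object_name}"
    return bucket_name, object_name

assert split_bucket_and_prefix("my-bucket/my-folder/dev", "my-object") == (
    "my-bucket",
    "my-folder/dev/my-object",
)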
Binary file not shown.
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
BETA
|
||||
|
||||
This is the PubSub logger for GCS PubSub; it sends LiteLLM SpendLogs payloads to a GCS PubSub topic.
|
||||
|
||||
Users can use this instead of sending their SpendLogs to their Postgres database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import SpendLogsPayload
|
||||
else:
|
||||
SpendLogsPayload = Any
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
class GcsPubSubLogger(CustomBatchLogger):
|
||||
def __init__(
|
||||
self,
|
||||
project_id: Optional[str] = None,
|
||||
topic_id: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize Google Cloud Pub/Sub publisher
|
||||
|
||||
Args:
|
||||
project_id (str): Google Cloud project ID
|
||||
topic_id (str): Pub/Sub topic ID
|
||||
credentials_path (str, optional): Path to Google Cloud credentials JSON file
|
||||
"""
|
||||
from litellm.proxy.utils import _premium_user_check
|
||||
|
||||
_premium_user_check()
|
||||
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
|
||||
self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
|
||||
self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
|
||||
self.path_service_account_json = credentials_path or os.getenv(
|
||||
"GCS_PATH_SERVICE_ACCOUNT"
|
||||
)
|
||||
|
||||
if not self.project_id or not self.topic_id:
|
||||
raise ValueError("Both project_id and topic_id must be provided")
|
||||
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.log_queue: List[Union[SpendLogsPayload, StandardLoggingPayload]] = []
|
||||
|
||||
async def construct_request_headers(self) -> Dict[str, str]:
|
||||
"""Construct authorization headers using Vertex AI auth"""
|
||||
from litellm import vertex_chat_completion
|
||||
|
||||
(
|
||||
_auth_header,
|
||||
vertex_project,
|
||||
) = await vertex_chat_completion._ensure_access_token_async(
|
||||
credentials=self.path_service_account_json,
|
||||
project_id=self.project_id,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
||||
model="pub-sub",
|
||||
auth_header=_auth_header,
|
||||
vertex_credentials=self.path_service_account_json,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=None,
|
||||
gemini_api_key=None,
|
||||
stream=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
return headers
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Async Log success events to GCS PubSub Topic
|
||||
|
||||
- Creates a SpendLogsPayload
|
||||
- Adds to batch queue
|
||||
- Flushes based on CustomBatchLogger settings
|
||||
|
||||
        Raises:
            Nothing is raised to the caller; errors are logged via a non-blocking verbose_logger.exception.
|
||||
"""
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
get_logging_payload,
|
||||
)
|
||||
from litellm.proxy.utils import _premium_user_check
|
||||
|
||||
_premium_user_check()
|
||||
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"PubSub: Logging - Enters logging function for model %s", kwargs
|
||||
)
|
||||
standard_logging_payload = kwargs.get("standard_logging_object", None)
|
||||
|
||||
# Backwards compatibility with old logging payload
|
||||
if litellm.gcs_pub_sub_use_v1 is True:
|
||||
spend_logs_payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
self.log_queue.append(spend_logs_payload)
|
||||
else:
|
||||
# New logging payload, StandardLoggingPayload
|
||||
self.log_queue.append(standard_logging_payload)
|
||||
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.async_send_batch()
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Sends the batch of messages to Pub/Sub
|
||||
"""
|
||||
try:
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
verbose_logger.debug(
|
||||
f"PubSub - about to flush {len(self.log_queue)} events"
|
||||
)
|
||||
|
||||
for message in self.log_queue:
|
||||
await self.publish_message(message)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
finally:
|
||||
self.log_queue.clear()
|
||||
|
||||
async def publish_message(
|
||||
self, message: Union[SpendLogsPayload, StandardLoggingPayload]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Publish message to Google Cloud Pub/Sub using REST API
|
||||
|
||||
Args:
|
||||
message: Message to publish (dict or string)
|
||||
|
||||
Returns:
|
||||
dict: Published message response
|
||||
"""
|
||||
try:
|
||||
headers = await self.construct_request_headers()
|
||||
|
||||
# Prepare message data
|
||||
if isinstance(message, str):
|
||||
message_data = message
|
||||
else:
|
||||
message_data = json.dumps(message, default=str)
|
||||
|
||||
# Base64 encode the message
|
||||
import base64
|
||||
|
||||
encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
# Construct request body
|
||||
request_body = {"messages": [{"data": encoded_message}]}
|
||||
|
||||
url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"
|
||||
|
||||
response = await self.async_httpx_client.post(
|
||||
url=url, headers=headers, json=request_body
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 202]:
|
||||
verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
|
||||
raise Exception(f"Failed to publish message: {response.text}")
|
||||
|
||||
verbose_logger.debug("Pub/Sub response: %s", response.text)
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("Pub/Sub publish error: %s", str(e))
|
||||
return None
|
||||
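A hedged sketch of the REST publish request shape that publish_message builds above; not part of the diff, and the project and topic names are made up:

import base64
import json

def build_pubsub_publish_request(project_id: str, topic_id: str, message: dict) -> tuple:
    """Return (url, body) in the shape the REST publish call above expects."""
    data = json.dumps(message, default=str).encode("utf-8")
    # Pub/Sub requires the message data to be base64-encoded
    body = {"messages": [{"data": base64.b64encode(data).decode("utf-8")}]}
    url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics/{topic_id}:publish"
    return url, body

url, body = build_pubsub_publish_request("my-project", "litellm-spend-logs", {"spend": 0.0012, "model": "gpt-4o"})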
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class GreenscaleLogger:
|
||||
def __init__(self):
|
||||
import os
|
||||
|
||||
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
||||
self.headers = {
|
||||
"api-key": self.greenscale_api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
try:
|
||||
response_json = response_obj.model_dump() if response_obj else {}
|
||||
data = {
|
||||
"modelId": kwargs.get("model"),
|
||||
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
||||
"outputTokenCount": response_json.get("usage", {}).get(
|
||||
"completion_tokens"
|
||||
),
|
||||
}
|
||||
data["timestamp"] = datetime.now(timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
|
||||
if type(end_time) is datetime and type(start_time) is datetime:
|
||||
data["invocationLatency"] = int(
|
||||
(end_time - start_time).total_seconds() * 1000
|
||||
)
|
||||
|
||||
# Add additional metadata keys to tags
|
||||
tags = []
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
for key, value in metadata.items():
|
||||
if key.startswith("greenscale"):
|
||||
if key == "greenscale_project":
|
||||
data["project"] = value
|
||||
elif key == "greenscale_application":
|
||||
data["application"] = value
|
||||
else:
|
||||
tags.append(
|
||||
{"key": key.replace("greenscale_", ""), "value": str(value)}
|
||||
)
|
||||
|
||||
data["tags"] = tags
|
||||
|
||||
if self.greenscale_logging_url is None:
|
||||
raise Exception("Greenscale Logger Error - No logging URL found")
|
||||
|
||||
response = litellm.module_level_client.post(
|
||||
self.greenscale_logging_url,
|
||||
headers=self.headers,
|
||||
data=json.dumps(data, default=str),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print_verbose(
|
||||
f"Greenscale Logger Error - {response.text}, {response.status_code}"
|
||||
)
|
||||
else:
|
||||
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
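A standalone sketch of the greenscale_* metadata mapping used in log_event above; not part of the diff, and the metadata values are illustrative:

def greenscale_fields_from_metadata(metadata: dict) -> dict:
    """Split greenscale_* metadata keys into project/application fields and generic tags,
    mirroring the loop in log_event above."""
    data: dict = {"tags": []}
    for key, value in metadata.items():
        if not key.startswith("greenscale"):
            continue
        if key == "greenscale_project":
            data["project"] = value
        elif key == "greenscale_application":
            data["application"] = value
        else:
            data["tags"].append({"key": key.replace("greenscale_", ""), "value": str(value)})
    return data

print(greenscale_fields_from_metadata({"greenscale_project": "acme", "greenscale_region": "eu"}))
# {'tags': [{'key': 'region', 'value': 'eu'}], 'project': 'acme'}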
@@ -0,0 +1,188 @@
|
||||
#### What this does ####
|
||||
# On success, logs events to Helicone
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class HeliconeLogger:
|
||||
# Class variables or attributes
|
||||
helicone_model_list = [
|
||||
"gpt",
|
||||
"claude",
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
"command-light",
|
||||
"command-medium",
|
||||
"command-medium-beta",
|
||||
"command-xlarge-nightly",
|
||||
"command-nightly",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
# Instance variables
|
||||
self.provider_url = "https://api.openai.com/v1"
|
||||
self.key = os.getenv("HELICONE_API_KEY")
|
||||
|
||||
def claude_mapping(self, model, messages, response_obj):
|
||||
from anthropic import AI_PROMPT, HUMAN_PROMPT
|
||||
|
||||
prompt = f"{HUMAN_PROMPT}"
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += f"{HUMAN_PROMPT}{message['content']}"
|
||||
else:
|
||||
prompt += f"{AI_PROMPT}{message['content']}"
|
||||
else:
|
||||
prompt += f"{HUMAN_PROMPT}{message['content']}"
|
||||
prompt += f"{AI_PROMPT}"
|
||||
|
||||
choice = response_obj["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
content = []
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
for tool_call in message["tool_calls"]:
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_call["id"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": tool_call["function"]["arguments"],
|
||||
}
|
||||
)
|
||||
elif "content" in message and message["content"]:
|
||||
content = [{"type": "text", "text": message["content"]}]
|
||||
|
||||
claude_response_obj = {
|
||||
"id": response_obj["id"],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content,
|
||||
"stop_reason": choice["finish_reason"],
|
||||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": response_obj["usage"]["prompt_tokens"],
|
||||
"output_tokens": response_obj["usage"]["completion_tokens"],
|
||||
},
|
||||
}
|
||||
|
||||
return claude_response_obj
|
||||
|
||||
@staticmethod
|
||||
def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
|
||||
"""
|
||||
Adds metadata from proxy request headers to Helicone logging if keys start with "helicone_"
|
||||
and overwrites litellm_params.metadata if already included.
|
||||
|
||||
        For example, if you want to add a custom property to your request, send
        `headers: { ..., helicone_property_something: 1234 }` via the proxy request.
|
||||
"""
|
||||
if litellm_params is None:
|
||||
return metadata
|
||||
|
||||
if litellm_params.get("proxy_server_request") is None:
|
||||
return metadata
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
proxy_headers = (
|
||||
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
|
||||
)
|
||||
|
||||
for header_key in proxy_headers:
|
||||
if header_key.startswith("helicone_"):
|
||||
metadata[header_key] = proxy_headers.get(header_key)
|
||||
|
||||
return metadata
|
||||
|
||||
def log_success(
|
||||
self, model, messages, response_obj, start_time, end_time, print_verbose, kwargs
|
||||
):
|
||||
# Method definition
|
||||
try:
|
||||
print_verbose(
|
||||
f"Helicone Logging - Enters logging function for model {model}"
|
||||
)
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
kwargs.get("litellm_call_id", None)
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
model = (
|
||||
model
|
||||
if any(
|
||||
accepted_model in model
|
||||
for accepted_model in self.helicone_model_list
|
||||
)
|
||||
else "gpt-3.5-turbo"
|
||||
)
|
||||
provider_request = {"model": model, "messages": messages}
|
||||
if isinstance(response_obj, litellm.EmbeddingResponse) or isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
response_obj = response_obj.json()
|
||||
|
||||
if "claude" in model:
|
||||
response_obj = self.claude_mapping(
|
||||
model=model, messages=messages, response_obj=response_obj
|
||||
)
|
||||
|
||||
providerResponse = {
|
||||
"json": response_obj,
|
||||
"headers": {"openai-version": "2020-10-01"},
|
||||
"status": 200,
|
||||
}
|
||||
|
||||
# Code to be executed
|
||||
provider_url = self.provider_url
|
||||
url = "https://api.hconeai.com/oai/v1/log"
|
||||
if "claude" in model:
|
||||
url = "https://api.hconeai.com/anthropic/v1/log"
|
||||
provider_url = "https://api.anthropic.com/v1/messages"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
start_time_seconds = int(start_time.timestamp())
|
||||
start_time_milliseconds = int(
|
||||
(start_time.timestamp() - start_time_seconds) * 1000
|
||||
)
|
||||
end_time_seconds = int(end_time.timestamp())
|
||||
end_time_milliseconds = int(
|
||||
(end_time.timestamp() - end_time_seconds) * 1000
|
||||
)
|
||||
meta = {"Helicone-Auth": f"Bearer {self.key}"}
|
||||
meta.update(metadata)
|
||||
data = {
|
||||
"providerRequest": {
|
||||
"url": provider_url,
|
||||
"json": provider_request,
|
||||
"meta": meta,
|
||||
},
|
||||
"providerResponse": providerResponse,
|
||||
"timing": {
|
||||
"startTime": {
|
||||
"seconds": start_time_seconds,
|
||||
"milliseconds": start_time_milliseconds,
|
||||
},
|
||||
"endTime": {
|
||||
"seconds": end_time_seconds,
|
||||
"milliseconds": end_time_milliseconds,
|
||||
},
|
||||
}, # {"seconds": .., "milliseconds": ..}
|
||||
}
|
||||
response = litellm.module_level_client.post(url, headers=headers, json=data)
|
||||
if response.status_code == 200:
|
||||
print_verbose("Helicone Logging - Success!")
|
||||
else:
|
||||
print_verbose(
|
||||
f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
|
||||
)
|
||||
print_verbose(f"Helicone Logging - Error {response.text}")
|
||||
except Exception:
|
||||
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
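For reference, a minimal sketch of the seconds/milliseconds timestamp split used for the Helicone timing object above; not part of the diff:

from datetime import datetime, timezone

def to_helicone_timestamp(moment: datetime) -> dict:
    """Split a timestamp into whole seconds and leftover milliseconds,
    matching the startTime/endTime shape built above."""
    seconds = int(moment.timestamp())
    milliseconds = int((moment.timestamp() - seconds) * 1000)
    return {"seconds": seconds, "milliseconds": milliseconds}

print(to_helicone_timestamp(datetime(2024, 6, 1, 12, 0, 0, 250000, tzinfo=timezone.utc)))
# {'seconds': 1717243200, 'milliseconds': 250}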
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Humanloop integration
|
||||
|
||||
https://humanloop.com/
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
from .custom_logger import CustomLogger
|
||||
|
||||
|
||||
class PromptManagementClient(TypedDict):
|
||||
prompt_id: str
|
||||
prompt_template: List[AllMessageValues]
|
||||
model: Optional[str]
|
||||
optional_params: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
class HumanLoopPromptManager(DualCache):
|
||||
@property
|
||||
def integration_name(self):
|
||||
return "humanloop"
|
||||
|
||||
def _get_prompt_from_id_cache(
|
||||
self, humanloop_prompt_id: str
|
||||
) -> Optional[PromptManagementClient]:
|
||||
return cast(
|
||||
Optional[PromptManagementClient], self.get_cache(key=humanloop_prompt_id)
|
||||
)
|
||||
|
||||
def _compile_prompt_helper(
|
||||
self, prompt_template: List[AllMessageValues], prompt_variables: Dict[str, Any]
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Helper function to compile the prompt by substituting variables in the template.
|
||||
|
||||
Args:
|
||||
prompt_template: List[AllMessageValues]
|
||||
prompt_variables (dict): A dictionary of variables to substitute into the prompt template.
|
||||
|
||||
Returns:
|
||||
list: A list of dictionaries with variables substituted.
|
||||
"""
|
||||
compiled_prompts: List[AllMessageValues] = []
|
||||
|
||||
for template in prompt_template:
|
||||
tc = template.get("content")
|
||||
if tc and isinstance(tc, str):
|
||||
formatted_template = tc.replace("{{", "{").replace("}}", "}")
|
||||
compiled_content = formatted_template.format(**prompt_variables)
|
||||
template["content"] = compiled_content
|
||||
compiled_prompts.append(template)
|
||||
|
||||
return compiled_prompts
|
||||
|
||||
def _get_prompt_from_id_api(
|
||||
self, humanloop_prompt_id: str, humanloop_api_key: str
|
||||
) -> PromptManagementClient:
|
||||
client = _get_httpx_client()
|
||||
|
||||
base_url = "https://api.humanloop.com/v5/prompts/{}".format(humanloop_prompt_id)
|
||||
|
||||
response = client.get(
|
||||
url=base_url,
|
||||
headers={
|
||||
"X-Api-Key": humanloop_api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"Error getting prompt from Humanloop: {e.response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
template_message = json_response["template"]
|
||||
if isinstance(template_message, dict):
|
||||
template_messages = [template_message]
|
||||
elif isinstance(template_message, list):
|
||||
template_messages = template_message
|
||||
else:
|
||||
raise ValueError(f"Invalid template message type: {type(template_message)}")
|
||||
template_model = json_response["model"]
|
||||
optional_params = {}
|
||||
for k, v in json_response.items():
|
||||
if k in litellm.OPENAI_CHAT_COMPLETION_PARAMS:
|
||||
optional_params[k] = v
|
||||
return PromptManagementClient(
|
||||
prompt_id=humanloop_prompt_id,
|
||||
prompt_template=cast(List[AllMessageValues], template_messages),
|
||||
model=template_model,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
def _get_prompt_from_id(
|
||||
self, humanloop_prompt_id: str, humanloop_api_key: str
|
||||
) -> PromptManagementClient:
|
||||
prompt = self._get_prompt_from_id_cache(humanloop_prompt_id)
|
||||
if prompt is None:
|
||||
prompt = self._get_prompt_from_id_api(
|
||||
humanloop_prompt_id, humanloop_api_key
|
||||
)
|
||||
self.set_cache(
|
||||
key=humanloop_prompt_id,
|
||||
value=prompt,
|
||||
ttl=litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
|
||||
)
|
||||
return prompt
|
||||
|
||||
def compile_prompt(
|
||||
self,
|
||||
prompt_template: List[AllMessageValues],
|
||||
prompt_variables: Optional[dict],
|
||||
) -> List[AllMessageValues]:
|
||||
compiled_prompt: Optional[Union[str, list]] = None
|
||||
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
compiled_prompt = self._compile_prompt_helper(
|
||||
prompt_template=prompt_template,
|
||||
prompt_variables=prompt_variables,
|
||||
)
|
||||
|
||||
return compiled_prompt
|
||||
|
||||
def _get_model_from_prompt(
|
||||
self, prompt_management_client: PromptManagementClient, model: str
|
||||
) -> str:
|
||||
if prompt_management_client["model"] is not None:
|
||||
return prompt_management_client["model"]
|
||||
else:
|
||||
return model.replace("{}/".format(self.integration_name), "")
|
||||
|
||||
|
||||
prompt_manager = HumanLoopPromptManager()
|
||||
|
||||
|
||||
class HumanloopLogger(CustomLogger):
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[
|
||||
str,
|
||||
List[AllMessageValues],
|
||||
dict,
|
||||
]:
|
||||
humanloop_api_key = dynamic_callback_params.get(
|
||||
"humanloop_api_key"
|
||||
) or get_secret_str("HUMANLOOP_API_KEY")
|
||||
|
||||
if prompt_id is None:
|
||||
raise ValueError("prompt_id is required for Humanloop integration")
|
||||
|
||||
if humanloop_api_key is None:
|
||||
return super().get_chat_completion_prompt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
non_default_params=non_default_params,
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=dynamic_callback_params,
|
||||
)
|
||||
|
||||
prompt_template = prompt_manager._get_prompt_from_id(
|
||||
humanloop_prompt_id=prompt_id, humanloop_api_key=humanloop_api_key
|
||||
)
|
||||
|
||||
updated_messages = prompt_manager.compile_prompt(
|
||||
prompt_template=prompt_template["prompt_template"],
|
||||
prompt_variables=prompt_variables,
|
||||
)
|
||||
|
||||
prompt_template_optional_params = prompt_template["optional_params"] or {}
|
||||
|
||||
updated_non_default_params = {
|
||||
**non_default_params,
|
||||
**prompt_template_optional_params,
|
||||
}
|
||||
|
||||
model = prompt_manager._get_model_from_prompt(
|
||||
prompt_management_client=prompt_template, model=model
|
||||
)
|
||||
|
||||
return model, updated_messages, updated_non_default_params
|
||||
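A small standalone sketch of the {{variable}} substitution performed by _compile_prompt_helper above; not part of the diff, and the message and variable values are illustrative:

from typing import Any, Dict, List

def compile_template(messages: List[Dict[str, Any]], variables: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Substitute {{variable}} placeholders in message content, as _compile_prompt_helper does above."""
    compiled = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, str):
            # convert {{name}} to {name}, then use str.format for the substitution
            rendered = content.replace("{{", "{").replace("}}", "}").format(**variables)
            message = {**message, "content": rendered}
        compiled.append(message)
    return compiled

print(compile_template(
    [{"role": "system", "content": "You help {{company}} customers."}],
    {"company": "Acme"},
))
# [{'role': 'system', 'content': 'You help Acme customers.'}]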
202
.venv/lib/python3.10/site-packages/litellm/integrations/lago.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# What is this?
|
||||
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
def get_utc_datetime():
|
||||
import datetime as dt
|
||||
from datetime import datetime
|
||||
|
||||
if hasattr(dt, "UTC"):
|
||||
return datetime.now(dt.UTC) # type: ignore
|
||||
else:
|
||||
return datetime.utcnow() # type: ignore
|
||||
|
||||
|
||||
class LagoLogger(CustomLogger):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.validate_environment()
|
||||
self.async_http_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
self.sync_http_handler = HTTPHandler()
|
||||
|
||||
def validate_environment(self):
|
||||
"""
|
||||
Expects
|
||||
LAGO_API_BASE,
|
||||
LAGO_API_KEY,
|
||||
LAGO_API_EVENT_CODE,
|
||||
|
||||
Optional:
|
||||
LAGO_API_CHARGE_BY
|
||||
|
||||
in the environment
|
||||
"""
|
||||
missing_keys = []
|
||||
if os.getenv("LAGO_API_KEY", None) is None:
|
||||
missing_keys.append("LAGO_API_KEY")
|
||||
|
||||
if os.getenv("LAGO_API_BASE", None) is None:
|
||||
missing_keys.append("LAGO_API_BASE")
|
||||
|
||||
if os.getenv("LAGO_API_EVENT_CODE", None) is None:
|
||||
missing_keys.append("LAGO_API_EVENT_CODE")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||
|
||||
def _common_logic(self, kwargs: dict, response_obj) -> dict:
|
||||
response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
get_utc_datetime().isoformat()
|
||||
cost = kwargs.get("response_cost", None)
|
||||
model = kwargs.get("model")
|
||||
usage = {}
|
||||
|
||||
if (
|
||||
isinstance(response_obj, litellm.ModelResponse)
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
) and hasattr(response_obj, "usage"):
|
||||
usage = {
|
||||
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
|
||||
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
|
||||
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||
}
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
|
||||
litellm_params["metadata"].get("user_api_key_org_id", None)
|
||||
|
||||
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
|
||||
external_customer_id: Optional[str] = None
|
||||
|
||||
if os.getenv("LAGO_API_CHARGE_BY", None) is not None and isinstance(
|
||||
os.environ["LAGO_API_CHARGE_BY"], str
|
||||
):
|
||||
if os.environ["LAGO_API_CHARGE_BY"] in [
|
||||
"end_user_id",
|
||||
"user_id",
|
||||
"team_id",
|
||||
]:
|
||||
charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore
|
||||
else:
|
||||
raise Exception("invalid LAGO_API_CHARGE_BY set")
|
||||
|
||||
if charge_by == "end_user_id":
|
||||
external_customer_id = end_user_id
|
||||
elif charge_by == "team_id":
|
||||
external_customer_id = team_id
|
||||
elif charge_by == "user_id":
|
||||
external_customer_id = user_id
|
||||
|
||||
if external_customer_id is None:
|
||||
raise Exception(
|
||||
"External Customer ID is not set. Charge_by={}. User_id={}. End_user_id={}. Team_id={}".format(
|
||||
charge_by, user_id, end_user_id, team_id
|
||||
)
|
||||
)
|
||||
|
||||
returned_val = {
|
||||
"event": {
|
||||
"transaction_id": str(uuid.uuid4()),
|
||||
"external_subscription_id": external_customer_id,
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {"model": model, "response_cost": cost, **usage},
|
||||
}
|
||||
}
|
||||
|
||||
verbose_logger.debug(
|
||||
"\033[91mLogged Lago Object:\n{}\033[0m\n".format(returned_val)
|
||||
)
|
||||
return returned_val
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
_url = os.getenv("LAGO_API_BASE")
|
||||
assert _url is not None and isinstance(
|
||||
_url, str
|
||||
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = os.getenv("LAGO_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.sync_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_response is not None and hasattr(error_response, "text"):
|
||||
verbose_logger.debug(f"\nError Message: {error_response.text}")
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
verbose_logger.debug("ENTERS LAGO CALLBACK")
|
||||
_url = os.getenv("LAGO_API_BASE")
|
||||
assert _url is not None and isinstance(
|
||||
_url, str
|
||||
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
|
||||
_url
|
||||
)
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = os.getenv("LAGO_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
response: Optional[httpx.Response] = None
|
||||
try:
|
||||
response = await self.async_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
verbose_logger.debug(f"Logged Lago Object: {response.text}")
|
||||
except Exception as e:
|
||||
if response is not None and hasattr(response, "text"):
|
||||
verbose_logger.debug(f"\nError Message: {response.text}")
|
||||
raise e
|
||||
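A hedged sketch of how _common_logic above selects the external customer id and shapes the Lago event body; not part of the diff, and the ids, cost figures, and the fallback event code are illustrative assumptions:

import os
import uuid
from typing import Optional

def resolve_external_customer_id(end_user_id: Optional[str], user_id: Optional[str], team_id: Optional[str]) -> Optional[str]:
    """Pick the Lago subscription id based on LAGO_API_CHARGE_BY, defaulting to end_user_id,
    mirroring _common_logic above."""
    charge_by = os.getenv("LAGO_API_CHARGE_BY", "end_user_id")
    if charge_by not in ("end_user_id", "user_id", "team_id"):
        raise ValueError(f"invalid LAGO_API_CHARGE_BY: {charge_by}")
    return {"end_user_id": end_user_id, "user_id": user_id, "team_id": team_id}[charge_by]

event = {
    "event": {
        "transaction_id": str(uuid.uuid4()),
        "external_subscription_id": resolve_external_customer_id("cust-1", "user-1", "team-1"),
        "code": os.getenv("LAGO_API_EVENT_CODE", "litellm_usage"),  # fallback value is illustrative only
        "properties": {"model": "gpt-4o", "response_cost": 0.002, "prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }
}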
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,955 @@
|
||||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import copy
|
||||
import os
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.integrations.langfuse import *
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
RerankResponse,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingPromptManagementMetadata,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
|
||||
else:
|
||||
DynamicLoggingCache = Any
|
||||
|
||||
|
||||
class LangFuseLogger:
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
langfuse_public_key=None,
|
||||
langfuse_secret=None,
|
||||
langfuse_host=None,
|
||||
flush_interval=1,
|
||||
):
|
||||
try:
|
||||
import langfuse
|
||||
from langfuse import Langfuse
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
|
||||
)
|
||||
# Instance variables
|
||||
self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
|
||||
self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
self.langfuse_host = langfuse_host or os.getenv(
|
||||
"LANGFUSE_HOST", "https://cloud.langfuse.com"
|
||||
)
|
||||
if not (
|
||||
self.langfuse_host.startswith("http://")
|
||||
or self.langfuse_host.startswith("https://")
|
||||
):
|
||||
# add http:// if unset, assume communicating over private network - e.g. render
|
||||
self.langfuse_host = "http://" + self.langfuse_host
|
||||
self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
|
||||
self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
|
||||
self.langfuse_flush_interval = LangFuseLogger._get_langfuse_flush_interval(
|
||||
flush_interval
|
||||
)
|
||||
http_client = _get_httpx_client()
|
||||
self.langfuse_client = http_client.client
|
||||
|
||||
parameters = {
|
||||
"public_key": self.public_key,
|
||||
"secret_key": self.secret_key,
|
||||
"host": self.langfuse_host,
|
||||
"release": self.langfuse_release,
|
||||
"debug": self.langfuse_debug,
|
||||
"flush_interval": self.langfuse_flush_interval, # flush interval in seconds
|
||||
"httpx_client": self.langfuse_client,
|
||||
}
|
||||
self.langfuse_sdk_version: str = langfuse.version.__version__
|
||||
|
||||
if Version(self.langfuse_sdk_version) >= Version("2.6.0"):
|
||||
parameters["sdk_integration"] = "litellm"
|
||||
|
||||
self.Langfuse = Langfuse(**parameters)
|
||||
|
||||
# set the current langfuse project id in the environ
|
||||
# this is used by Alerting to link to the correct project
|
||||
try:
|
||||
project_id = self.Langfuse.client.projects.get().data[0].id
|
||||
os.environ["LANGFUSE_PROJECT_ID"] = project_id
|
||||
except Exception:
|
||||
project_id = None
|
||||
|
||||
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
|
||||
            # read the env var directly; self.upstream_langfuse_debug is only assigned further below
            upstream_langfuse_debug = (
                str_to_bool(os.getenv("UPSTREAM_LANGFUSE_DEBUG"))
                if os.getenv("UPSTREAM_LANGFUSE_DEBUG") is not None
                else None
            )
|
||||
self.upstream_langfuse_secret_key = os.getenv(
|
||||
"UPSTREAM_LANGFUSE_SECRET_KEY"
|
||||
)
|
||||
self.upstream_langfuse_public_key = os.getenv(
|
||||
"UPSTREAM_LANGFUSE_PUBLIC_KEY"
|
||||
)
|
||||
self.upstream_langfuse_host = os.getenv("UPSTREAM_LANGFUSE_HOST")
|
||||
self.upstream_langfuse_release = os.getenv("UPSTREAM_LANGFUSE_RELEASE")
|
||||
self.upstream_langfuse_debug = os.getenv("UPSTREAM_LANGFUSE_DEBUG")
|
||||
self.upstream_langfuse = Langfuse(
|
||||
public_key=self.upstream_langfuse_public_key,
|
||||
secret_key=self.upstream_langfuse_secret_key,
|
||||
host=self.upstream_langfuse_host,
|
||||
release=self.upstream_langfuse_release,
|
||||
debug=(
|
||||
upstream_langfuse_debug
|
||||
if upstream_langfuse_debug is not None
|
||||
else False
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.upstream_langfuse = None
|
||||
|
||||
@staticmethod
|
||||
def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
|
||||
"""
|
||||
Adds metadata from proxy request headers to Langfuse logging if keys start with "langfuse_"
|
||||
and overwrites litellm_params.metadata if already included.
|
||||
|
||||
For example if you want to append your trace to an existing `trace_id` via header, send
|
||||
`headers: { ..., langfuse_existing_trace_id: your-existing-trace-id }` via proxy request.
|
||||
"""
|
||||
if litellm_params is None:
|
||||
return metadata
|
||||
|
||||
if litellm_params.get("proxy_server_request") is None:
|
||||
return metadata
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
proxy_headers = (
|
||||
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
|
||||
)
|
||||
|
||||
for metadata_param_key in proxy_headers:
|
||||
if metadata_param_key.startswith("langfuse_"):
|
||||
trace_param_key = metadata_param_key.replace("langfuse_", "", 1)
|
||||
if trace_param_key in metadata:
|
||||
verbose_logger.warning(
|
||||
f"Overwriting Langfuse `{trace_param_key}` from request header"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Found Langfuse `{trace_param_key}` in request header"
|
||||
)
|
||||
metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
|
||||
|
||||
return metadata
|
||||
|
||||
def log_event_on_langfuse(
|
||||
self,
|
||||
kwargs: dict,
|
||||
response_obj: Union[
|
||||
None,
|
||||
dict,
|
||||
EmbeddingResponse,
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
ImageResponse,
|
||||
TranscriptionResponse,
|
||||
RerankResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
],
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
user_id: Optional[str] = None,
|
||||
level: str = "DEFAULT",
|
||||
status_message: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Logs a success or error event on Langfuse
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"Langfuse Logging - Enters logging function for model {kwargs}"
|
||||
)
|
||||
|
||||
# set default values for input/output for langfuse logging
|
||||
input = None
|
||||
output = None
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
optional_params = copy.deepcopy(kwargs.get("optional_params", {}))
|
||||
|
||||
prompt = {"messages": kwargs.get("messages")}
|
||||
|
||||
functions = optional_params.pop("functions", None)
|
||||
tools = optional_params.pop("tools", None)
|
||||
if functions is not None:
|
||||
prompt["functions"] = functions
|
||||
if tools is not None:
|
||||
prompt["tools"] = tools
|
||||
|
||||
# langfuse only accepts str, int, bool, float for logging
|
||||
for param, value in optional_params.items():
|
||||
if not isinstance(value, (str, int, bool, float)):
|
||||
try:
|
||||
optional_params[param] = str(value)
|
||||
except Exception:
|
||||
# if casting value to str fails don't block logging
|
||||
pass
|
||||
|
||||
input, output = self._get_langfuse_input_output_content(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
prompt=prompt,
|
||||
level=level,
|
||||
status_message=status_message,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}"
|
||||
)
|
||||
trace_id = None
|
||||
generation_id = None
|
||||
if self._is_langfuse_v2():
|
||||
trace_id, generation_id = self._log_langfuse_v2(
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
litellm_params=litellm_params,
|
||||
output=output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
kwargs=kwargs,
|
||||
optional_params=optional_params,
|
||||
input=input,
|
||||
response_obj=response_obj,
|
||||
level=level,
|
||||
litellm_call_id=litellm_call_id,
|
||||
)
|
||||
elif response_obj is not None:
|
||||
self._log_langfuse_v1(
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
output=output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
kwargs=kwargs,
|
||||
optional_params=optional_params,
|
||||
input=input,
|
||||
response_obj=response_obj,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
verbose_logger.info("Langfuse Layer Logging - logging success")
|
||||
|
||||
return {"trace_id": trace_id, "generation_id": generation_id}
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"Langfuse Layer Error(): Exception occured - {}".format(str(e))
|
||||
)
|
||||
return {"trace_id": None, "generation_id": None}
|
||||
|
||||
def _get_langfuse_input_output_content(
|
||||
self,
|
||||
kwargs: dict,
|
||||
response_obj: Union[
|
||||
None,
|
||||
dict,
|
||||
EmbeddingResponse,
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
ImageResponse,
|
||||
TranscriptionResponse,
|
||||
RerankResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
],
|
||||
prompt: dict,
|
||||
level: str,
|
||||
status_message: Optional[str],
|
||||
) -> Tuple[Optional[dict], Optional[Union[str, dict, list]]]:
|
||||
"""
|
||||
Get the input and output content for Langfuse logging
|
||||
|
||||
Args:
|
||||
kwargs: The keyword arguments passed to the function
|
||||
response_obj: The response object returned by the function
|
||||
prompt: The prompt used to generate the response
|
||||
level: The level of the log message
|
||||
status_message: The status message of the log message
|
||||
|
||||
Returns:
|
||||
input: The input content for Langfuse logging
|
||||
output: The output content for Langfuse logging
|
||||
"""
|
||||
input = None
|
||||
output: Optional[Union[str, dict, List[Any]]] = None
|
||||
if (
|
||||
level == "ERROR"
|
||||
and status_message is not None
|
||||
and isinstance(status_message, str)
|
||||
):
|
||||
input = prompt
|
||||
output = status_message
|
||||
elif response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
input = prompt
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
input = prompt
|
||||
output = self._get_chat_content_for_langfuse(response_obj)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.HttpxBinaryResponseContent
|
||||
):
|
||||
input = prompt
|
||||
output = "speech-output"
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = self._get_text_completion_content_for_langfuse(response_obj)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("data", None)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TranscriptionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("text", None)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.RerankResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.results
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "_arealtime"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, list)
|
||||
):
|
||||
input = kwargs.get("input")
|
||||
output = response_obj
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "pass_through_endpoint"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, dict)
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("response", "")
|
||||
return input, output
|
||||
|
||||
async def _async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, user_id
|
||||
):
|
||||
"""
|
||||
Langfuse SDK uses a background thread to log events
|
||||
|
||||
This approach does not impact latency and runs in the background
|
||||
"""
|
||||
|
||||
def _is_langfuse_v2(self):
|
||||
import langfuse
|
||||
|
||||
return Version(langfuse.version.__version__) >= Version("2.0.0")
|
||||
|
||||
def _log_langfuse_v1(
|
||||
self,
|
||||
user_id,
|
||||
metadata,
|
||||
output,
|
||||
start_time,
|
||||
end_time,
|
||||
kwargs,
|
||||
optional_params,
|
||||
input,
|
||||
response_obj,
|
||||
):
|
||||
from langfuse.model import CreateGeneration, CreateTrace # type: ignore
|
||||
|
||||
verbose_logger.warning(
|
||||
"Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
|
||||
)
|
||||
|
||||
trace = self.Langfuse.trace( # type: ignore
|
||||
CreateTrace( # type: ignore
|
||||
name=metadata.get("generation_name", "litellm-completion"),
|
||||
input=input,
|
||||
output=output,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
trace.generation(
|
||||
CreateGeneration(
|
||||
name=metadata.get("generation_name", "litellm-completion"),
|
||||
startTime=start_time,
|
||||
endTime=end_time,
|
||||
model=kwargs["model"],
|
||||
modelParameters=optional_params,
|
||||
prompt=input,
|
||||
completion=output,
|
||||
usage={
|
||||
"prompt_tokens": response_obj.usage.prompt_tokens,
|
||||
"completion_tokens": response_obj.usage.completion_tokens,
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
def _log_langfuse_v2( # noqa: PLR0915
|
||||
self,
|
||||
user_id: Optional[str],
|
||||
metadata: dict,
|
||||
litellm_params: dict,
|
||||
output: Optional[Union[str, dict, list]],
|
||||
start_time: Optional[datetime],
|
||||
end_time: Optional[datetime],
|
||||
kwargs: dict,
|
||||
optional_params: dict,
|
||||
input: Optional[dict],
|
||||
response_obj,
|
||||
level: str,
|
||||
litellm_call_id: Optional[str],
|
||||
) -> tuple:
|
||||
verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2")
|
||||
|
||||
try:
|
||||
metadata = metadata or {}
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = cast(
|
||||
Optional[StandardLoggingPayload],
|
||||
kwargs.get("standard_logging_object", None),
|
||||
)
|
||||
tags = (
|
||||
self._get_langfuse_tags(standard_logging_object=standard_logging_object)
|
||||
if self._supports_tags()
|
||||
else []
|
||||
)
|
||||
|
||||
if standard_logging_object is None:
|
||||
end_user_id = None
|
||||
prompt_management_metadata: Optional[
|
||||
StandardLoggingPromptManagementMetadata
|
||||
] = None
|
||||
else:
|
||||
end_user_id = standard_logging_object["metadata"].get(
|
||||
"user_api_key_end_user_id", None
|
||||
)
|
||||
|
||||
prompt_management_metadata = cast(
|
||||
Optional[StandardLoggingPromptManagementMetadata],
|
||||
standard_logging_object["metadata"].get(
|
||||
"prompt_management_metadata", None
|
||||
),
|
||||
)
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
# we clean out all extra litellm metadata params before logging
|
||||
clean_metadata: Dict[str, Any] = {}
|
||||
if prompt_management_metadata is not None:
|
||||
clean_metadata[
|
||||
"prompt_management_metadata"
|
||||
] = prompt_management_metadata
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
|
||||
if (
|
||||
litellm.langfuse_default_tags is not None
|
||||
and isinstance(litellm.langfuse_default_tags, list)
|
||||
and key in litellm.langfuse_default_tags
|
||||
):
|
||||
tags.append(f"{key}:{value}")
|
||||
|
||||
# clean litellm metadata before logging
|
||||
if key in [
|
||||
"headers",
|
||||
"endpoint",
|
||||
"caching_groups",
|
||||
"previous_models",
|
||||
]:
|
||||
continue
|
||||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
# Add default langfuse tags
|
||||
tags = self.add_default_langfuse_tags(
|
||||
tags=tags, kwargs=kwargs, metadata=metadata
|
||||
)
|
||||
|
||||
session_id = clean_metadata.pop("session_id", None)
|
||||
trace_name = cast(Optional[str], clean_metadata.pop("trace_name", None))
|
||||
trace_id = clean_metadata.pop("trace_id", litellm_call_id)
|
||||
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||
update_trace_keys = cast(list, clean_metadata.pop("update_trace_keys", []))
|
||||
debug = clean_metadata.pop("debug_langfuse", None)
|
||||
mask_input = clean_metadata.pop("mask_input", False)
|
||||
mask_output = clean_metadata.pop("mask_output", False)
|
||||
|
||||
clean_metadata = redact_user_api_key_info(metadata=clean_metadata)
|
||||
|
||||
if trace_name is None and existing_trace_id is None:
|
||||
# just log `litellm-{call_type}` as the trace name
|
||||
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
||||
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||
|
||||
if existing_trace_id is not None:
|
||||
trace_params: Dict[str, Any] = {"id": existing_trace_id}
|
||||
|
||||
# Update the following keys for this trace
|
||||
for metadata_param_key in update_trace_keys:
|
||||
trace_param_key = metadata_param_key.replace("trace_", "")
|
||||
if trace_param_key not in trace_params:
|
||||
updated_trace_value = clean_metadata.pop(
|
||||
metadata_param_key, None
|
||||
)
|
||||
if updated_trace_value is not None:
|
||||
trace_params[trace_param_key] = updated_trace_value
|
||||
|
||||
# Pop the trace specific keys that would have been popped if there were a new trace
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
):
|
||||
clean_metadata.pop(key, None)
|
||||
|
||||
# Special keys that are found in the function arguments and not the metadata
|
||||
if "input" in update_trace_keys:
|
||||
trace_params["input"] = (
|
||||
input if not mask_input else "redacted-by-litellm"
|
||||
)
|
||||
if "output" in update_trace_keys:
|
||||
trace_params["output"] = (
|
||||
output if not mask_output else "redacted-by-litellm"
|
||||
)
|
||||
else: # don't overwrite an existing trace
|
||||
trace_params = {
|
||||
"id": trace_id,
|
||||
"name": trace_name,
|
||||
"session_id": session_id,
|
||||
"input": input if not mask_input else "redacted-by-litellm",
|
||||
"version": clean_metadata.pop(
|
||||
"trace_version", clean_metadata.get("version", None)
|
||||
), # if only `version` is provided, it is applied to the trace as well; a `trace_version` takes precedence
|
||||
"user_id": end_user_id,
|
||||
}
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
):
|
||||
trace_params[key.replace("trace_", "")] = clean_metadata.pop(
|
||||
key, None
|
||||
)
|
||||
|
||||
if level == "ERROR":
|
||||
trace_params["status_message"] = output
|
||||
else:
|
||||
trace_params["output"] = (
|
||||
output if not mask_output else "redacted-by-litellm"
|
||||
)
|
||||
|
||||
if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||
if "metadata" in trace_params:
|
||||
# log the raw_metadata in the trace
|
||||
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
|
||||
else:
|
||||
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
verbose_logger.debug(f"trace: {cost}")
|
||||
|
||||
clean_metadata["litellm_response_cost"] = cost
|
||||
if standard_logging_object is not None:
|
||||
clean_metadata["hidden_params"] = standard_logging_object[
|
||||
"hidden_params"
|
||||
]
|
||||
|
||||
if (
|
||||
litellm.langfuse_default_tags is not None
|
||||
and isinstance(litellm.langfuse_default_tags, list)
|
||||
and "proxy_base_url" in litellm.langfuse_default_tags
|
||||
):
|
||||
proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
|
||||
if proxy_base_url is not None:
|
||||
tags.append(f"proxy_base_url:{proxy_base_url}")
|
||||
|
||||
api_base = litellm_params.get("api_base", None)
|
||||
if api_base:
|
||||
clean_metadata["api_base"] = api_base
|
||||
|
||||
vertex_location = kwargs.get("vertex_location", None)
|
||||
if vertex_location:
|
||||
clean_metadata["vertex_location"] = vertex_location
|
||||
|
||||
aws_region_name = kwargs.get("aws_region_name", None)
|
||||
if aws_region_name:
|
||||
clean_metadata["aws_region_name"] = aws_region_name
|
||||
|
||||
if self._supports_tags():
|
||||
if "cache_hit" in kwargs:
|
||||
if kwargs["cache_hit"] is None:
|
||||
kwargs["cache_hit"] = False
|
||||
clean_metadata["cache_hit"] = kwargs["cache_hit"]
|
||||
if existing_trace_id is None:
|
||||
trace_params.update({"tags": tags})
|
||||
|
||||
proxy_server_request = litellm_params.get("proxy_server_request", None)
|
||||
if proxy_server_request:
|
||||
proxy_server_request.get("method", None)
|
||||
proxy_server_request.get("url", None)
|
||||
headers = proxy_server_request.get("headers", None)
|
||||
clean_headers = {}
|
||||
if headers:
|
||||
for key, value in headers.items():
|
||||
# these headers can leak our API keys and/or JWT tokens
|
||||
if key.lower() not in ["authorization", "cookie", "referer"]:
|
||||
clean_headers[key] = value
|
||||
|
||||
# clean_metadata["request"] = {
|
||||
# "method": method,
|
||||
# "url": url,
|
||||
# "headers": clean_headers,
|
||||
# }
|
||||
trace = self.Langfuse.trace(**trace_params)
|
||||
|
||||
# Log provider specific information as a span
|
||||
log_provider_specific_information_as_span(trace, clean_metadata)
|
||||
|
||||
generation_id = None
|
||||
usage = None
|
||||
if response_obj is not None:
|
||||
if (
|
||||
hasattr(response_obj, "id")
|
||||
and response_obj.get("id", None) is not None
|
||||
):
|
||||
generation_id = litellm.utils.get_logging_id(
|
||||
start_time, response_obj
|
||||
)
|
||||
_usage_obj = getattr(response_obj, "usage", None)
|
||||
|
||||
if _usage_obj:
|
||||
usage = {
|
||||
"prompt_tokens": _usage_obj.prompt_tokens,
|
||||
"completion_tokens": _usage_obj.completion_tokens,
|
||||
"total_cost": cost if self._supports_costs() else None,
|
||||
}
|
||||
generation_name = clean_metadata.pop("generation_name", None)
|
||||
if generation_name is None:
|
||||
# if `generation_name` is None, use sensible default values
|
||||
# If using the litellm proxy, use the key's `key_alias` if it is not None
|
||||
# If `key_alias` is None, just log `litellm-{call_type}` as the generation name
|
||||
_user_api_key_alias = cast(
|
||||
Optional[str], clean_metadata.get("user_api_key_alias", None)
|
||||
)
|
||||
generation_name = (
|
||||
f"litellm-{cast(str, kwargs.get('call_type', 'completion'))}"
|
||||
)
|
||||
if _user_api_key_alias is not None:
|
||||
generation_name = f"litellm:{_user_api_key_alias}"
|
||||
|
||||
if response_obj is not None:
|
||||
system_fingerprint = getattr(response_obj, "system_fingerprint", None)
|
||||
else:
|
||||
system_fingerprint = None
|
||||
|
||||
if system_fingerprint is not None:
|
||||
optional_params["system_fingerprint"] = system_fingerprint
|
||||
|
||||
generation_params = {
|
||||
"name": generation_name,
|
||||
"id": clean_metadata.pop("generation_id", generation_id),
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"model": kwargs["model"],
|
||||
"model_parameters": optional_params,
|
||||
"input": input if not mask_input else "redacted-by-litellm",
|
||||
"output": output if not mask_output else "redacted-by-litellm",
|
||||
"usage": usage,
|
||||
"metadata": log_requester_metadata(clean_metadata),
|
||||
"level": level,
|
||||
"version": clean_metadata.pop("version", None),
|
||||
}
|
||||
|
||||
parent_observation_id = metadata.get("parent_observation_id", None)
|
||||
if parent_observation_id is not None:
|
||||
generation_params["parent_observation_id"] = parent_observation_id
|
||||
|
||||
if self._supports_prompt():
|
||||
generation_params = _add_prompt_to_generation_params(
|
||||
generation_params=generation_params,
|
||||
clean_metadata=clean_metadata,
|
||||
prompt_management_metadata=prompt_management_metadata,
|
||||
langfuse_client=self.Langfuse,
|
||||
)
|
||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||
generation_params["status_message"] = output
|
||||
|
||||
if self._supports_completion_start_time():
|
||||
generation_params["completion_start_time"] = kwargs.get(
|
||||
"completion_start_time", None
|
||||
)
|
||||
|
||||
generation_client = trace.generation(**generation_params)
|
||||
|
||||
return generation_client.trace_id, generation_id
|
||||
except Exception:
|
||||
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_content_for_langfuse(
|
||||
response_obj: ModelResponse,
|
||||
):
|
||||
"""
|
||||
Get the chat content for Langfuse logging
|
||||
"""
|
||||
if response_obj.choices and len(response_obj.choices) > 0:
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
return output
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_text_completion_content_for_langfuse(
|
||||
response_obj: TextCompletionResponse,
|
||||
):
|
||||
"""
|
||||
Get the text completion content for Langfuse logging
|
||||
"""
|
||||
if response_obj.choices and len(response_obj.choices) > 0:
|
||||
return response_obj.choices[0].text
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_langfuse_tags(
|
||||
standard_logging_object: Optional[StandardLoggingPayload],
|
||||
) -> List[str]:
|
||||
if standard_logging_object is None:
|
||||
return []
|
||||
return standard_logging_object.get("request_tags", []) or []
|
||||
|
||||
def add_default_langfuse_tags(self, tags, kwargs, metadata):
|
||||
"""
|
||||
Helper function to add litellm default langfuse tags
|
||||
|
||||
- Special LiteLLM tags:
|
||||
- cache_hit
|
||||
- cache_key
|
||||
|
||||
"""
|
||||
if litellm.langfuse_default_tags is not None and isinstance(
|
||||
litellm.langfuse_default_tags, list
|
||||
):
|
||||
if "cache_hit" in litellm.langfuse_default_tags:
|
||||
_cache_hit_value = kwargs.get("cache_hit", False)
|
||||
tags.append(f"cache_hit:{_cache_hit_value}")
|
||||
if "cache_key" in litellm.langfuse_default_tags:
|
||||
_hidden_params = metadata.get("hidden_params", {}) or {}
|
||||
_cache_key = _hidden_params.get("cache_key", None)
|
||||
if _cache_key is None and litellm.cache is not None:
|
||||
# fallback to using "preset_cache_key"
|
||||
_preset_cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
|
||||
**kwargs
|
||||
)
|
||||
_cache_key = _preset_cache_key
|
||||
tags.append(f"cache_key:{_cache_key}")
|
||||
return tags
|
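# Illustrative behaviour (not part of this commit): with
#   litellm.langfuse_default_tags = ["cache_hit", "cache_key"]
# a request served from cache would extend `tags` roughly like
#   ["cache_hit:True", "cache_key:<preset cache key>"]
# (the cache key shown here is a hypothetical placeholder).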
||||
|
||||
def _supports_tags(self):
|
||||
"""Check if current langfuse version supports tags"""
|
||||
return Version(self.langfuse_sdk_version) >= Version("2.6.3")
|
||||
|
||||
def _supports_prompt(self):
|
||||
"""Check if current langfuse version supports prompt"""
|
||||
return Version(self.langfuse_sdk_version) >= Version("2.7.3")
|
||||
|
||||
def _supports_costs(self):
|
||||
"""Check if current langfuse version supports costs"""
|
||||
return Version(self.langfuse_sdk_version) >= Version("2.7.3")
|
||||
|
||||
def _supports_completion_start_time(self):
|
||||
"""Check if current langfuse version supports completion start time"""
|
||||
return Version(self.langfuse_sdk_version) >= Version("2.7.3")
|
||||
|
||||
@staticmethod
|
||||
def _get_langfuse_flush_interval(flush_interval: int) -> int:
|
||||
"""
|
||||
Get the langfuse flush interval to initialize the Langfuse client
|
||||
|
||||
Reads `LANGFUSE_FLUSH_INTERVAL` from the environment variable.
|
||||
If not set, uses the flush interval passed in as an argument.
|
||||
|
||||
Args:
|
||||
flush_interval: The flush interval to use if LANGFUSE_FLUSH_INTERVAL is not set
|
||||
|
||||
Returns:
|
||||
[int] The flush interval to use to initialize the Langfuse client
|
||||
"""
|
||||
return int(os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval)
|
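# Illustrative sketch (not part of this commit): the env var takes precedence
# over the argument passed by the caller. Assuming only `os` and the
# `LangFuseLogger` class above are in scope:
#
#   os.environ["LANGFUSE_FLUSH_INTERVAL"] = "5"
#   assert LangFuseLogger._get_langfuse_flush_interval(flush_interval=1) == 5
#   del os.environ["LANGFUSE_FLUSH_INTERVAL"]
#   assert LangFuseLogger._get_langfuse_flush_interval(flush_interval=1) == 1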
||||
|
||||
|
||||
def _add_prompt_to_generation_params(
|
||||
generation_params: dict,
|
||||
clean_metadata: dict,
|
||||
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata],
|
||||
langfuse_client: Any,
|
||||
) -> dict:
|
||||
from langfuse import Langfuse
|
||||
from langfuse.model import (
|
||||
ChatPromptClient,
|
||||
Prompt_Chat,
|
||||
Prompt_Text,
|
||||
TextPromptClient,
|
||||
)
|
||||
|
||||
langfuse_client = cast(Langfuse, langfuse_client)
|
||||
|
||||
user_prompt = clean_metadata.pop("prompt", None)
|
||||
if user_prompt is None and prompt_management_metadata is None:
|
||||
pass
|
||||
elif isinstance(user_prompt, dict):
|
||||
if user_prompt.get("type", "") == "chat":
|
||||
_prompt_chat = Prompt_Chat(**user_prompt)
|
||||
generation_params["prompt"] = ChatPromptClient(prompt=_prompt_chat)
|
||||
elif user_prompt.get("type", "") == "text":
|
||||
_prompt_text = Prompt_Text(**user_prompt)
|
||||
generation_params["prompt"] = TextPromptClient(prompt=_prompt_text)
|
||||
elif "version" in user_prompt and "prompt" in user_prompt:
|
||||
# prompt dict without a "type" field - identified by its "prompt" + "version" keys
|
||||
if isinstance(user_prompt["prompt"], str):
|
||||
prompt_text_params = getattr(
|
||||
Prompt_Text, "model_fields", Prompt_Text.__fields__
|
||||
)
|
||||
_data = {
|
||||
"name": user_prompt["name"],
|
||||
"prompt": user_prompt["prompt"],
|
||||
"version": user_prompt["version"],
|
||||
"config": user_prompt.get("config", None),
|
||||
}
|
||||
if "labels" in prompt_text_params and "tags" in prompt_text_params:
|
||||
_data["labels"] = user_prompt.get("labels", []) or []
|
||||
_data["tags"] = user_prompt.get("tags", []) or []
|
||||
_prompt_obj = Prompt_Text(**_data) # type: ignore
|
||||
generation_params["prompt"] = TextPromptClient(prompt=_prompt_obj)
|
||||
|
||||
elif isinstance(user_prompt["prompt"], list):
|
||||
prompt_chat_params = getattr(
|
||||
Prompt_Chat, "model_fields", Prompt_Chat.__fields__
|
||||
)
|
||||
_data = {
|
||||
"name": user_prompt["name"],
|
||||
"prompt": user_prompt["prompt"],
|
||||
"version": user_prompt["version"],
|
||||
"config": user_prompt.get("config", None),
|
||||
}
|
||||
if "labels" in prompt_chat_params and "tags" in prompt_chat_params:
|
||||
_data["labels"] = user_prompt.get("labels", []) or []
|
||||
_data["tags"] = user_prompt.get("tags", []) or []
|
||||
|
||||
_prompt_obj = Prompt_Chat(**_data) # type: ignore
|
||||
|
||||
generation_params["prompt"] = ChatPromptClient(prompt=_prompt_obj)
|
||||
else:
|
||||
verbose_logger.error(
|
||||
"[Non-blocking] Langfuse Logger: Invalid prompt format"
|
||||
)
|
||||
else:
|
||||
verbose_logger.error(
|
||||
"[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
|
||||
)
|
||||
elif (
|
||||
prompt_management_metadata is not None
|
||||
and prompt_management_metadata["prompt_integration"] == "langfuse"
|
||||
):
|
||||
try:
|
||||
generation_params["prompt"] = langfuse_client.get_prompt(
|
||||
prompt_management_metadata["prompt_id"]
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}"
|
||||
)
|
||||
pass
|
||||
|
||||
else:
|
||||
generation_params["prompt"] = user_prompt
|
||||
|
||||
return generation_params
|
||||
|
||||
|
||||
def log_provider_specific_information_as_span(
|
||||
trace,
|
||||
clean_metadata,
|
||||
):
|
||||
"""
|
||||
Logs provider-specific information as spans.
|
||||
|
||||
Parameters:
|
||||
trace: The tracing object used to log spans.
|
||||
clean_metadata: A dictionary containing metadata to be logged.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
_hidden_params = clean_metadata.get("hidden_params", None)
|
||||
if _hidden_params is None:
|
||||
return
|
||||
|
||||
vertex_ai_grounding_metadata = _hidden_params.get(
|
||||
"vertex_ai_grounding_metadata", None
|
||||
)
|
||||
|
||||
if vertex_ai_grounding_metadata is not None:
|
||||
if isinstance(vertex_ai_grounding_metadata, list):
|
||||
for elem in vertex_ai_grounding_metadata:
|
||||
if isinstance(elem, dict):
|
||||
for key, value in elem.items():
|
||||
trace.span(
|
||||
name=key,
|
||||
input=value,
|
||||
)
|
||||
else:
|
||||
trace.span(
|
||||
name="vertex_ai_grounding_metadata",
|
||||
input=elem,
|
||||
)
|
||||
else:
|
||||
trace.span(
|
||||
name="vertex_ai_grounding_metadata",
|
||||
input=vertex_ai_grounding_metadata,
|
||||
)
|
||||
|
||||
|
||||
def log_requester_metadata(clean_metadata: dict):
|
||||
returned_metadata = {}
|
||||
requester_metadata = clean_metadata.get("requester_metadata") or {}
|
||||
for k, v in clean_metadata.items():
|
||||
if k not in requester_metadata:
|
||||
returned_metadata[k] = v
|
||||
|
||||
returned_metadata.update({"requester_metadata": requester_metadata})
|
||||
|
||||
return returned_metadata
|
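# Illustrative behaviour of the helper above (not part of this commit):
#
#   log_requester_metadata({"team": "y", "requester_metadata": {"team": "x"}})
#   -> {"requester_metadata": {"team": "x"}}
#
# top-level keys that also appear inside `requester_metadata` are dropped, so
# requester-supplied values are logged exactly once under `requester_metadata`.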
||||
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
This file contains the LangFuseHandler class
|
||||
|
||||
Used to get the LangFuseLogger for a given request
|
||||
|
||||
Handles Key/Team Based Langfuse Logging
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
|
||||
|
||||
from .langfuse import LangFuseLogger, LangfuseLoggingConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
|
||||
else:
|
||||
DynamicLoggingCache = Any
|
||||
|
||||
|
||||
class LangFuseHandler:
|
||||
@staticmethod
|
||||
def get_langfuse_logger_for_request(
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
in_memory_dynamic_logger_cache: DynamicLoggingCache,
|
||||
globalLangfuseLogger: Optional[LangFuseLogger] = None,
|
||||
) -> LangFuseLogger:
|
||||
"""
|
||||
This function is used to get the LangFuseLogger for a given request
|
||||
|
||||
1. If dynamic credentials are passed
|
||||
- check if a LangFuseLogger is cached for the dynamic credentials
|
||||
- if cached LangFuseLogger is not found, create a new LangFuseLogger and cache it
|
||||
|
||||
2. If dynamic credentials are not passed return the globalLangfuseLogger
|
||||
|
||||
"""
|
||||
temp_langfuse_logger: Optional[LangFuseLogger] = globalLangfuseLogger
|
||||
if (
|
||||
LangFuseHandler._dynamic_langfuse_credentials_are_passed(
|
||||
standard_callback_dynamic_params
|
||||
)
|
||||
is False
|
||||
):
|
||||
return LangFuseHandler._return_global_langfuse_logger(
|
||||
globalLangfuseLogger=globalLangfuseLogger,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
|
||||
# get langfuse logging config to use for this request, based on standard_callback_dynamic_params
|
||||
_credentials = LangFuseHandler.get_dynamic_langfuse_logging_config(
|
||||
globalLangfuseLogger=globalLangfuseLogger,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
)
|
||||
credentials_dict = dict(_credentials)
|
||||
|
||||
# check if langfuse logger is already cached
|
||||
temp_langfuse_logger = in_memory_dynamic_logger_cache.get_cache(
|
||||
credentials=credentials_dict, service_name="langfuse"
|
||||
)
|
||||
|
||||
# if not cached, create a new langfuse logger and cache it
|
||||
if temp_langfuse_logger is None:
|
||||
temp_langfuse_logger = (
|
||||
LangFuseHandler._create_langfuse_logger_from_credentials(
|
||||
credentials=credentials_dict,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
)
|
||||
|
||||
return temp_langfuse_logger
|
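# Illustrative call sketch (not part of this commit). `params` and `cache` are
# hypothetical placeholders for objects built by litellm's logging layer:
#
#   params = StandardCallbackDynamicParams(
#       langfuse_public_key="pk-...", langfuse_secret="sk-..."
#   )
#   logger = LangFuseHandler.get_langfuse_logger_for_request(
#       standard_callback_dynamic_params=params,
#       in_memory_dynamic_logger_cache=cache,  # DynamicLoggingCache instance
#       globalLangfuseLogger=None,
#   )
#   # a later request with identical credentials returns the cached
#   # LangFuseLogger instead of constructing a new client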
||||
|
||||
@staticmethod
|
||||
def _return_global_langfuse_logger(
|
||||
globalLangfuseLogger: Optional[LangFuseLogger],
|
||||
in_memory_dynamic_logger_cache: DynamicLoggingCache,
|
||||
) -> LangFuseLogger:
|
||||
"""
|
||||
Returns the Global LangfuseLogger set on litellm
|
||||
|
||||
(this is the default langfuse logger - used when no dynamic credentials are passed)
|
||||
|
||||
If no Global LangfuseLogger is set, it will check in_memory_dynamic_logger_cache for a cached LangFuseLogger
|
||||
If neither exists, a new LangFuseLogger is created from environment variables and cached.
|
||||
"""
|
||||
if globalLangfuseLogger is not None:
|
||||
return globalLangfuseLogger
|
||||
|
||||
credentials_dict: Dict[str, Any] = {}  # the global langfuse logger uses environment variables; there are no dynamic credentials
|
||||
globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(
|
||||
credentials=credentials_dict,
|
||||
service_name="langfuse",
|
||||
)
|
||||
if globalLangfuseLogger is None:
|
||||
globalLangfuseLogger = (
|
||||
LangFuseHandler._create_langfuse_logger_from_credentials(
|
||||
credentials=credentials_dict,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
)
|
||||
return globalLangfuseLogger
|
||||
|
||||
@staticmethod
|
||||
def _create_langfuse_logger_from_credentials(
|
||||
credentials: Dict,
|
||||
in_memory_dynamic_logger_cache: DynamicLoggingCache,
|
||||
) -> LangFuseLogger:
|
||||
"""
|
||||
This function is used to
|
||||
1. create a LangFuseLogger from the credentials
|
||||
2. cache the LangFuseLogger to prevent re-creating it for the same credentials
|
||||
"""
|
||||
|
||||
langfuse_logger = LangFuseLogger(
|
||||
langfuse_public_key=credentials.get("langfuse_public_key"),
|
||||
langfuse_secret=credentials.get("langfuse_secret"),
|
||||
langfuse_host=credentials.get("langfuse_host"),
|
||||
)
|
||||
in_memory_dynamic_logger_cache.set_cache(
|
||||
credentials=credentials,
|
||||
service_name="langfuse",
|
||||
logging_obj=langfuse_logger,
|
||||
)
|
||||
return langfuse_logger
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_langfuse_logging_config(
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
globalLangfuseLogger: Optional[LangFuseLogger] = None,
|
||||
) -> LangfuseLoggingConfig:
|
||||
"""
|
||||
This function is used to get the Langfuse logging config to use for a given request.
|
||||
|
||||
It checks if the dynamic parameters are provided in the standard_callback_dynamic_params and uses them to get the Langfuse logging config.
|
||||
|
||||
If no dynamic parameters are provided, it uses the `globalLangfuseLogger` values
|
||||
"""
|
||||
# only use dynamic params if langfuse credentials are passed dynamically
|
||||
return LangfuseLoggingConfig(
|
||||
langfuse_secret=standard_callback_dynamic_params.get("langfuse_secret")
|
||||
or standard_callback_dynamic_params.get("langfuse_secret_key"),
|
||||
langfuse_public_key=standard_callback_dynamic_params.get(
|
||||
"langfuse_public_key"
|
||||
),
|
||||
langfuse_host=standard_callback_dynamic_params.get("langfuse_host"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _dynamic_langfuse_credentials_are_passed(
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
"""
|
||||
This function is used to check if the dynamic langfuse credentials are passed in standard_callback_dynamic_params
|
||||
|
||||
Returns:
|
||||
bool: True if the dynamic langfuse credentials are passed, False otherwise
|
||||
"""
|
||||
|
||||
if (
|
||||
standard_callback_dynamic_params.get("langfuse_host") is not None
|
||||
or standard_callback_dynamic_params.get("langfuse_public_key") is not None
|
||||
or standard_callback_dynamic_params.get("langfuse_secret") is not None
|
||||
or standard_callback_dynamic_params.get("langfuse_secret_key") is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Call Hook for LiteLLM Proxy which allows Langfuse prompt management.
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.prompt_management_base import PromptManagementClient
|
||||
from litellm.litellm_core_utils.asyncify import run_async_function
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionSystemMessage
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
|
||||
from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import (
|
||||
DynamicLoggingCache,
|
||||
)
|
||||
from ..prompt_management_base import PromptManagementBase
|
||||
from .langfuse import LangFuseLogger
|
||||
from .langfuse_handler import LangFuseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langfuse import Langfuse
|
||||
from langfuse.client import ChatPromptClient, TextPromptClient
|
||||
|
||||
LangfuseClass: TypeAlias = Langfuse
|
||||
|
||||
PROMPT_CLIENT = Union[TextPromptClient, ChatPromptClient]
|
||||
else:
|
||||
PROMPT_CLIENT = Any
|
||||
LangfuseClass = Any
|
||||
|
||||
in_memory_dynamic_logger_cache = DynamicLoggingCache()
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def langfuse_client_init(
|
||||
langfuse_public_key=None,
|
||||
langfuse_secret=None,
|
||||
langfuse_secret_key=None,
|
||||
langfuse_host=None,
|
||||
flush_interval=1,
|
||||
) -> LangfuseClass:
|
||||
"""
|
||||
Initialize Langfuse client with caching to prevent multiple initializations.
|
||||
|
||||
Args:
|
||||
langfuse_public_key (str, optional): Public key for Langfuse. Defaults to None.
|
||||
langfuse_secret (str, optional): Secret key for Langfuse. Defaults to None.
|
||||
langfuse_host (str, optional): Host URL for Langfuse. Defaults to None.
|
||||
flush_interval (int, optional): Flush interval in seconds. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Langfuse: Initialized Langfuse client instance
|
||||
|
||||
Raises:
|
||||
Exception: If langfuse package is not installed
|
||||
"""
|
||||
try:
|
||||
import langfuse
|
||||
from langfuse import Langfuse
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n\033[0m"
|
||||
)
|
||||
|
||||
# resolve credentials and host from arguments or environment variables
|
||||
|
||||
secret_key = (
|
||||
langfuse_secret or langfuse_secret_key or os.getenv("LANGFUSE_SECRET_KEY")
|
||||
)
|
||||
public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
langfuse_host = langfuse_host or os.getenv(
|
||||
"LANGFUSE_HOST", "https://cloud.langfuse.com"
|
||||
)
|
||||
|
||||
if not (
|
||||
langfuse_host.startswith("http://") or langfuse_host.startswith("https://")
|
||||
):
|
||||
# add http:// if no scheme is set; assume communication over a private network - e.g. render
|
||||
langfuse_host = "http://" + langfuse_host
|
||||
|
||||
langfuse_release = os.getenv("LANGFUSE_RELEASE")
|
||||
langfuse_debug = os.getenv("LANGFUSE_DEBUG")
|
||||
|
||||
parameters = {
|
||||
"public_key": public_key,
|
||||
"secret_key": secret_key,
|
||||
"host": langfuse_host,
|
||||
"release": langfuse_release,
|
||||
"debug": langfuse_debug,
|
||||
"flush_interval": LangFuseLogger._get_langfuse_flush_interval(
|
||||
flush_interval
|
||||
), # flush interval in seconds
|
||||
}
|
||||
|
||||
if Version(langfuse.version.__version__) >= Version("2.6.0"):
|
||||
parameters["sdk_integration"] = "litellm"
|
||||
|
||||
client = Langfuse(**parameters)
|
||||
|
||||
return client
|
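# Illustrative sketch (not part of this commit): because of @lru_cache above,
# repeated calls with identical keyword arguments reuse the same client object,
# so per-request prompt fetches do not re-initialize Langfuse. The key values
# below are hypothetical placeholders:
#
#   client_a = langfuse_client_init(langfuse_public_key="pk-x", langfuse_secret="sk-x")
#   client_b = langfuse_client_init(langfuse_public_key="pk-x", langfuse_secret="sk-x")
#   assert client_a is client_b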
||||
|
||||
|
||||
class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
langfuse_public_key=None,
|
||||
langfuse_secret=None,
|
||||
langfuse_host=None,
|
||||
flush_interval=1,
|
||||
):
|
||||
import langfuse
|
||||
|
||||
self.langfuse_sdk_version = langfuse.version.__version__
|
||||
self.Langfuse = langfuse_client_init(
|
||||
langfuse_public_key=langfuse_public_key,
|
||||
langfuse_secret=langfuse_secret,
|
||||
langfuse_host=langfuse_host,
|
||||
flush_interval=flush_interval,
|
||||
)
|
||||
|
||||
@property
|
||||
def integration_name(self):
|
||||
return "langfuse"
|
||||
|
||||
def _get_prompt_from_id(
|
||||
self, langfuse_prompt_id: str, langfuse_client: LangfuseClass
|
||||
) -> PROMPT_CLIENT:
|
||||
return langfuse_client.get_prompt(langfuse_prompt_id)
|
||||
|
||||
def _compile_prompt(
|
||||
self,
|
||||
langfuse_prompt_client: PROMPT_CLIENT,
|
||||
langfuse_prompt_variables: Optional[dict],
|
||||
call_type: Union[Literal["completion"], Literal["text_completion"]],
|
||||
) -> List[AllMessageValues]:
|
||||
compiled_prompt: Optional[Union[str, list]] = None
|
||||
|
||||
if langfuse_prompt_variables is None:
|
||||
langfuse_prompt_variables = {}
|
||||
|
||||
compiled_prompt = langfuse_prompt_client.compile(**langfuse_prompt_variables)
|
||||
|
||||
if isinstance(compiled_prompt, str):
|
||||
compiled_prompt = [
|
||||
ChatCompletionSystemMessage(role="system", content=compiled_prompt)
|
||||
]
|
||||
else:
|
||||
compiled_prompt = cast(List[AllMessageValues], compiled_prompt)
|
||||
|
||||
return compiled_prompt
|
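# Illustrative sketch (not part of this commit). `prompt_client` is a
# hypothetical Langfuse TextPromptClient whose template is
# "You are a {{persona}} assistant.":
#
#   messages = self._compile_prompt(
#       langfuse_prompt_client=prompt_client,
#       langfuse_prompt_variables={"persona": "helpful"},
#       call_type="completion",
#   )
#   # -> [{"role": "system", "content": "You are a helpful assistant."}]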
||||
|
||||
def _get_optional_params_from_langfuse(
|
||||
self, langfuse_prompt_client: PROMPT_CLIENT
|
||||
) -> dict:
|
||||
config = langfuse_prompt_client.config
|
||||
optional_params = {}
|
||||
for k, v in config.items():
|
||||
if k != "model":
|
||||
optional_params[k] = v
|
||||
return optional_params
|
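# Illustrative behaviour (not part of this commit), with a hypothetical
# prompt config stored in Langfuse:
#
#   config = {"model": "gpt-4o", "temperature": 0.2, "max_tokens": 256}
#   -> optional_params = {"temperature": 0.2, "max_tokens": 256}
#
# "model" is stripped here because it is handled separately as the prompt's
# template model.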
||||
|
||||
async def async_get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[
|
||||
str,
|
||||
List[AllMessageValues],
|
||||
dict,
|
||||
]:
|
||||
return self.get_chat_completion_prompt(
|
||||
model,
|
||||
messages,
|
||||
non_default_params,
|
||||
prompt_id,
|
||||
prompt_variables,
|
||||
dynamic_callback_params,
|
||||
)
|
||||
|
||||
def should_run_prompt_management(
|
||||
self,
|
||||
prompt_id: str,
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
langfuse_client = langfuse_client_init(
|
||||
langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
|
||||
langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
|
||||
langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
|
||||
langfuse_host=dynamic_callback_params.get("langfuse_host"),
|
||||
)
|
||||
langfuse_prompt_client = self._get_prompt_from_id(
|
||||
langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client
|
||||
)
|
||||
return langfuse_prompt_client is not None
|
||||
|
||||
def _compile_prompt_helper(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> PromptManagementClient:
|
||||
langfuse_client = langfuse_client_init(
|
||||
langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
|
||||
langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
|
||||
langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
|
||||
langfuse_host=dynamic_callback_params.get("langfuse_host"),
|
||||
)
|
||||
langfuse_prompt_client = self._get_prompt_from_id(
|
||||
langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client
|
||||
)
|
||||
|
||||
## SET PROMPT
|
||||
compiled_prompt = self._compile_prompt(
|
||||
langfuse_prompt_client=langfuse_prompt_client,
|
||||
langfuse_prompt_variables=prompt_variables,
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
template_model = langfuse_prompt_client.config.get("model")
|
||||
|
||||
template_optional_params = self._get_optional_params_from_langfuse(
|
||||
langfuse_prompt_client
|
||||
)
|
||||
|
||||
return PromptManagementClient(
|
||||
prompt_id=prompt_id,
|
||||
prompt_template=compiled_prompt,
|
||||
prompt_template_model=template_model,
|
||||
prompt_template_optional_params=template_optional_params,
|
||||
completed_messages=None,
|
||||
)
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
return run_async_function(
|
||||
self.async_log_success_event, kwargs, response_obj, start_time, end_time
|
||||
)
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
standard_callback_dynamic_params = kwargs.get(
|
||||
"standard_callback_dynamic_params"
|
||||
)
|
||||
langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
|
||||
globalLangfuseLogger=self,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
langfuse_logger_to_use.log_event_on_langfuse(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
user_id=kwargs.get("user", None),
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
standard_callback_dynamic_params = kwargs.get(
|
||||
"standard_callback_dynamic_params"
|
||||
)
|
||||
langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
|
||||
globalLangfuseLogger=self,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
standard_logging_object = cast(
|
||||
Optional[StandardLoggingPayload],
|
||||
kwargs.get("standard_logging_object", None),
|
||||
)
|
||||
if standard_logging_object is None:
|
||||
return
|
||||
langfuse_logger_to_use.log_event_on_langfuse(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
response_obj=None,
|
||||
user_id=kwargs.get("user", None),
|
||||
status_message=standard_logging_object["error_str"],
|
||||
level="ERROR",
|
||||
kwargs=kwargs,
|
||||
)
|
||||
@@ -0,0 +1,498 @@
|
||||
#### What this does ####
|
||||
# On success, logs events to Langsmith
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
import types
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.langsmith import *
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
non_serializable_types = (
|
||||
types.CoroutineType,
|
||||
types.FunctionType,
|
||||
types.GeneratorType,
|
||||
BaseModel,
|
||||
)
|
||||
return not isinstance(value, non_serializable_types)
|
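# Illustrative behaviour (not part of this commit):
#
#   is_serializable({"a": 1})        -> True
#   is_serializable("some text")     -> True
#   is_serializable(lambda x: x)     -> False   (functions are skipped)
#
# pydantic BaseModel instances, coroutines and generators are likewise
# treated as non-serializable and skipped.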
||||
|
||||
|
||||
class LangsmithLogger(CustomBatchLogger):
|
||||
def __init__(
|
||||
self,
|
||||
langsmith_api_key: Optional[str] = None,
|
||||
langsmith_project: Optional[str] = None,
|
||||
langsmith_base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.default_credentials = self.get_credentials_from_env(
|
||||
langsmith_api_key=langsmith_api_key,
|
||||
langsmith_project=langsmith_project,
|
||||
langsmith_base_url=langsmith_base_url,
|
||||
)
|
||||
self.sampling_rate: float = (
|
||||
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
|
||||
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
|
||||
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
|
||||
else 1.0
|
||||
)
|
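# NOTE: .isdigit() only accepts whole-number strings, so a value like
# LANGSMITH_SAMPLING_RATE="0.5" falls back to the default of 1.0 here.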
||||
self.langsmith_default_run_name = os.getenv(
|
||||
"LANGSMITH_DEFAULT_RUN_NAME", "LLMRun"
|
||||
)
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
_batch_size = (
|
||||
os.getenv("LANGSMITH_BATCH_SIZE", None) or litellm.langsmith_batch_size
|
||||
)
|
||||
if _batch_size:
|
||||
self.batch_size = int(_batch_size)
|
||||
self.log_queue: List[LangsmithQueueObject] = []
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
|
||||
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
||||
|
||||
def get_credentials_from_env(
|
||||
self,
|
||||
langsmith_api_key: Optional[str] = None,
|
||||
langsmith_project: Optional[str] = None,
|
||||
langsmith_base_url: Optional[str] = None,
|
||||
) -> LangsmithCredentialsObject:
|
||||
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
|
||||
if _credentials_api_key is None:
|
||||
raise Exception(
|
||||
"Invalid Langsmith API Key given. _credentials_api_key=None."
|
||||
)
|
||||
_credentials_project = (
|
||||
langsmith_project or os.getenv("LANGSMITH_PROJECT") or "litellm-completion"
|
||||
)
|
||||
if _credentials_project is None:
|
||||
raise Exception(
|
||||
"Invalid Langsmith API Key given. _credentials_project=None."
|
||||
)
|
||||
_credentials_base_url = (
|
||||
langsmith_base_url
|
||||
or os.getenv("LANGSMITH_BASE_URL")
|
||||
or "https://api.smith.langchain.com"
|
||||
)
|
||||
if _credentials_base_url is None:
|
||||
raise Exception(
|
||||
"Invalid Langsmith API Key given. _credentials_base_url=None."
|
||||
)
|
||||
|
||||
return LangsmithCredentialsObject(
|
||||
LANGSMITH_API_KEY=_credentials_api_key,
|
||||
LANGSMITH_BASE_URL=_credentials_base_url,
|
||||
LANGSMITH_PROJECT=_credentials_project,
|
||||
)
|
||||
|
||||
def _prepare_log_data(
|
||||
self,
|
||||
kwargs,
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
credentials: LangsmithCredentialsObject,
|
||||
):
|
||||
try:
|
||||
_litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
metadata = _litellm_params.get("metadata", {}) or {}
|
||||
project_name = metadata.get(
|
||||
"project_name", credentials["LANGSMITH_PROJECT"]
|
||||
)
|
||||
run_name = metadata.get("run_name", self.langsmith_default_run_name)
|
||||
run_id = metadata.get("id", metadata.get("run_id", None))
|
||||
parent_run_id = metadata.get("parent_run_id", None)
|
||||
trace_id = metadata.get("trace_id", None)
|
||||
session_id = metadata.get("session_id", None)
|
||||
dotted_order = metadata.get("dotted_order", None)
|
||||
verbose_logger.debug(
|
||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||
)
|
||||
|
||||
# use the standard logging payload built by litellm as the run payload
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if payload is None:
|
||||
raise Exception("Error logging request payload. Payload=none.")
|
||||
|
||||
metadata = payload[
|
||||
"metadata"
|
||||
] # ensure logged metadata is json serializable
|
||||
|
||||
data = {
|
||||
"name": run_name,
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
"inputs": payload,
|
||||
"outputs": payload["response"],
|
||||
"session_name": project_name,
|
||||
"start_time": payload["startTime"],
|
||||
"end_time": payload["endTime"],
|
||||
"tags": payload["request_tags"],
|
||||
"extra": metadata,
|
||||
}
|
||||
|
||||
if payload["error_str"] is not None and payload["status"] == "failure":
|
||||
data["error"] = payload["error_str"]
|
||||
|
||||
if run_id:
|
||||
data["id"] = run_id
|
||||
|
||||
if parent_run_id:
|
||||
data["parent_run_id"] = parent_run_id
|
||||
|
||||
if trace_id:
|
||||
data["trace_id"] = trace_id
|
||||
|
||||
if session_id:
|
||||
data["session_id"] = session_id
|
||||
|
||||
if dotted_order:
|
||||
data["dotted_order"] = dotted_order
|
||||
|
||||
run_id: Optional[str] = data.get("id") # type: ignore
|
||||
if "id" not in data or data["id"] is None:
|
||||
"""
|
||||
for /batch langsmith requires id, trace_id and dotted_order passed as params
|
||||
"""
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
data["id"] = run_id
|
||||
|
||||
if (
|
||||
"trace_id" not in data
|
||||
or data["trace_id"] is None
|
||||
and (run_id is not None and isinstance(run_id, str))
|
||||
):
|
||||
data["trace_id"] = run_id
|
||||
|
||||
if (
|
||||
"dotted_order" not in data
|
||||
or data["dotted_order"] is None
|
||||
and (run_id is not None and isinstance(run_id, str))
|
||||
):
|
||||
data["dotted_order"] = self.make_dot_order(run_id=run_id) # type: ignore
|
||||
|
||||
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
|
||||
|
||||
return data
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = (
|
||||
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
|
||||
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
|
||||
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
|
||||
else 1.0
|
||||
)
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.debug(
|
||||
"Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
|
||||
data = self._prepare_log_data(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
credentials=credentials,
|
||||
)
|
||||
self.log_queue.append(
|
||||
LangsmithQueueObject(
|
||||
data=data,
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
|
||||
)
|
||||
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
self._send_batch()
|
||||
|
||||
except Exception:
|
||||
verbose_logger.exception("Langsmith Layer Error - log_success_event error")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = self.sampling_rate
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.debug(
|
||||
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
|
||||
data = self._prepare_log_data(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
credentials=credentials,
|
||||
)
|
||||
self.log_queue.append(
|
||||
LangsmithQueueObject(
|
||||
data=data,
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Langsmith Layer Error - error logging async success event."
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
sampling_rate = self.sampling_rate
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.info("Langsmith Failure Event Logging!")
|
||||
try:
|
||||
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
|
||||
data = self._prepare_log_data(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
credentials=credentials,
|
||||
)
|
||||
self.log_queue.append(
|
||||
LangsmithQueueObject(
|
||||
data=data,
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Langsmith Layer Error - error logging async failure event."
|
||||
)
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Handles sending batches of runs to Langsmith
|
||||
|
||||
self.log_queue contains LangsmithQueueObjects
|
||||
Each LangsmithQueueObject has the following:
|
||||
- "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url)
|
||||
- "data" - data to log on to langsmith for the request
|
||||
|
||||
|
||||
This function
|
||||
- groups the queue objects by credentials
|
||||
- loops through each unique credentials and sends batches to Langsmith
|
||||
|
||||
|
||||
This was added to support key/team based logging on langsmith
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
batch_groups = self._group_batches_by_credentials()
|
||||
for batch_group in batch_groups.values():
|
||||
await self._log_batch_on_langsmith(
|
||||
credentials=batch_group.credentials,
|
||||
queue_objects=batch_group.queue_objects,
|
||||
)
|
||||
|
||||
def _add_endpoint_to_url(
|
||||
self, url: str, endpoint: str, api_version: str = "/api/v1"
|
||||
) -> str:
|
||||
if api_version not in url:
|
||||
url = f"{url.rstrip('/')}{api_version}"
|
||||
|
||||
if url.endswith("/"):
|
||||
return f"{url}{endpoint}"
|
||||
return f"{url}/{endpoint}"
|
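# Illustrative expected behaviour (not part of this commit), assuming an
# already-constructed LangsmithLogger instance `logger`:
#
#   logger._add_endpoint_to_url("https://api.smith.langchain.com", "runs/batch")
#   -> "https://api.smith.langchain.com/api/v1/runs/batch"
#
#   logger._add_endpoint_to_url("https://langsmith.example.com/api/v1/", "runs/batch")
#   -> "https://langsmith.example.com/api/v1/runs/batch"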
||||
|
||||
async def _log_batch_on_langsmith(
|
||||
self,
|
||||
credentials: LangsmithCredentialsObject,
|
||||
queue_objects: List[LangsmithQueueObject],
|
||||
):
|
||||
"""
|
||||
Logs a batch of runs to Langsmith
|
||||
sends runs to /batch endpoint for the given credentials
|
||||
|
||||
Args:
|
||||
credentials: LangsmithCredentialsObject
|
||||
queue_objects: List[LangsmithQueueObject]
|
||||
|
||||
Returns: None
|
||||
|
||||
Raises: Nothing - exceptions are caught and logged via verbose_logger.exception()
|
||||
"""
|
||||
langsmith_api_base = credentials["LANGSMITH_BASE_URL"]
|
||||
langsmith_api_key = credentials["LANGSMITH_API_KEY"]
|
||||
url = self._add_endpoint_to_url(langsmith_api_base, "runs/batch")
|
||||
headers = {"x-api-key": langsmith_api_key}
|
||||
elements_to_log = [queue_object["data"] for queue_object in queue_objects]
|
||||
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"Sending batch of %s runs to Langsmith", len(elements_to_log)
|
||||
)
|
||||
response = await self.async_httpx_client.post(
|
||||
url=url,
|
||||
json={"post": elements_to_log},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code >= 300:
|
||||
verbose_logger.error(
|
||||
f"Langsmith Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Batch of {len(self.log_queue)} runs successfully created"
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
verbose_logger.exception(
|
||||
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
f"Langsmith Layer Error - {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]:
|
||||
"""Groups queue objects by credentials using a proper key structure"""
|
||||
log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {}
|
||||
|
||||
for queue_object in self.log_queue:
|
||||
credentials = queue_object["credentials"]
|
||||
key = CredentialsKey(
|
||||
api_key=credentials["LANGSMITH_API_KEY"],
|
||||
project=credentials["LANGSMITH_PROJECT"],
|
||||
base_url=credentials["LANGSMITH_BASE_URL"],
|
||||
)
|
||||
|
||||
if key not in log_queue_by_credentials:
|
||||
log_queue_by_credentials[key] = BatchGroup(
|
||||
credentials=credentials, queue_objects=[]
|
||||
)
|
||||
|
||||
log_queue_by_credentials[key].queue_objects.append(queue_object)
|
||||
|
||||
return log_queue_by_credentials
|
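# Illustrative grouping behaviour (not part of this commit). `creds_team_a`
# and `creds_team_b` are hypothetical LangsmithCredentialsObject values:
#
#   self.log_queue = [
#       {"data": run_1, "credentials": creds_team_a},
#       {"data": run_2, "credentials": creds_team_a},
#       {"data": run_3, "credentials": creds_team_b},
#   ]
#   # -> two BatchGroups: creds_team_a with [run_1, run_2], creds_team_b with [run_3]
#   # each group is then flushed with its own /batch request in async_send_batch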
||||
|
||||
def _get_credentials_to_use_for_request(
|
||||
self, kwargs: Dict[str, Any]
|
||||
) -> LangsmithCredentialsObject:
|
||||
"""
|
||||
Handles key/team based logging
|
||||
|
||||
If standard_callback_dynamic_params are provided, use those credentials.
|
||||
|
||||
Otherwise, use the default credentials.
|
||||
"""
|
||||
standard_callback_dynamic_params: Optional[
|
||||
StandardCallbackDynamicParams
|
||||
] = kwargs.get("standard_callback_dynamic_params", None)
|
||||
if standard_callback_dynamic_params is not None:
|
||||
credentials = self.get_credentials_from_env(
|
||||
langsmith_api_key=standard_callback_dynamic_params.get(
|
||||
"langsmith_api_key", None
|
||||
),
|
||||
langsmith_project=standard_callback_dynamic_params.get(
|
||||
"langsmith_project", None
|
||||
),
|
||||
langsmith_base_url=standard_callback_dynamic_params.get(
|
||||
"langsmith_base_url", None
|
||||
),
|
||||
)
|
||||
else:
|
||||
credentials = self.default_credentials
|
||||
return credentials
|
||||
|
||||
def _send_batch(self):
|
||||
"""Calls async_send_batch in an event loop"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
try:
|
||||
# Try to get the existing event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If we're already in an event loop, create a task
|
||||
asyncio.create_task(self.async_send_batch())
|
||||
else:
|
||||
# If no event loop is running, run the coroutine directly
|
||||
loop.run_until_complete(self.async_send_batch())
|
||||
except RuntimeError:
|
||||
# If we can't get an event loop, create a new one
|
||||
asyncio.run(self.async_send_batch())
|
||||
|
||||
def get_run_by_id(self, run_id):
|
||||
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
|
||||
|
||||
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
|
||||
|
||||
url = f"{langsmith_api_base}/runs/{run_id}"
|
||||
response = litellm.module_level_client.get(
|
||||
url=url,
|
||||
headers={"x-api-key": langsmith_api_key},
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
def make_dot_order(self, run_id: str):
|
||||
st = datetime.now(timezone.utc)
|
||||
id_ = run_id
|
||||
return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_)
|
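# Illustrative output (not part of this commit), for a hypothetical run id:
#
#   make_dot_order(run_id="a1b2c3d4")
#   -> "20250101T120000000000Za1b2c3d4"   (UTC timestamp + run id)
#
# Langsmith uses the dotted order to order runs within a trace.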
||||
@@ -0,0 +1,106 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from litellm.proxy._types import SpanAttributes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class LangtraceAttributes:
|
||||
"""
|
||||
This class is used to save trace attributes to Langtrace's spans
|
||||
"""
|
||||
|
||||
def set_langtrace_attributes(self, span: Span, kwargs, response_obj):
|
||||
"""
|
||||
This function is used to log the event to Langtrace
|
||||
"""
|
||||
|
||||
vendor = kwargs.get("litellm_params").get("custom_llm_provider")
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
options = {**kwargs, **optional_params}
|
||||
self.set_request_attributes(span, options, vendor)
|
||||
self.set_response_attributes(span, response_obj)
|
||||
self.set_usage_attributes(span, response_obj)
|
||||
|
||||
def set_request_attributes(self, span: Span, kwargs, vendor):
|
||||
"""
|
||||
This function is used to get span attributes for the LLM request
|
||||
"""
|
||||
span_attributes = {
|
||||
"gen_ai.operation.name": "chat",
|
||||
"langtrace.service.name": vendor,
|
||||
SpanAttributes.LLM_REQUEST_MODEL.value: kwargs.get("model"),
|
||||
SpanAttributes.LLM_IS_STREAMING.value: kwargs.get("stream"),
|
||||
SpanAttributes.LLM_REQUEST_TEMPERATURE.value: kwargs.get("temperature"),
|
||||
SpanAttributes.LLM_TOP_K.value: kwargs.get("top_k"),
|
||||
SpanAttributes.LLM_REQUEST_TOP_P.value: kwargs.get("top_p"),
|
||||
SpanAttributes.LLM_USER.value: kwargs.get("user"),
|
||||
SpanAttributes.LLM_REQUEST_MAX_TOKENS.value: kwargs.get("max_tokens"),
|
||||
SpanAttributes.LLM_RESPONSE_STOP_REASON.value: kwargs.get("stop"),
|
||||
SpanAttributes.LLM_FREQUENCY_PENALTY.value: kwargs.get("frequency_penalty"),
|
||||
SpanAttributes.LLM_PRESENCE_PENALTY.value: kwargs.get("presence_penalty"),
|
||||
}
|
||||
|
||||
prompts = kwargs.get("messages")
|
||||
|
||||
if prompts:
|
||||
span.add_event(
|
||||
name="gen_ai.content.prompt",
|
||||
attributes={SpanAttributes.LLM_PROMPTS.value: json.dumps(prompts)},
|
||||
)
|
||||
|
||||
self.set_span_attributes(span, span_attributes)
|
||||
|
||||
def set_response_attributes(self, span: Span, response_obj):
|
||||
"""
|
||||
This function is used to get span attributes for the LLM response
|
||||
"""
|
||||
response_attributes = {
|
||||
"gen_ai.response_id": response_obj.get("id"),
|
||||
"gen_ai.system_fingerprint": response_obj.get("system_fingerprint"),
|
||||
SpanAttributes.LLM_RESPONSE_MODEL.value: response_obj.get("model"),
|
||||
}
|
||||
completions = []
|
||||
for choice in response_obj.get("choices", []):
|
||||
role = choice.get("message").get("role")
|
||||
content = choice.get("message").get("content")
|
||||
completions.append({"role": role, "content": content})
|
||||
|
||||
span.add_event(
|
||||
name="gen_ai.content.completion",
|
||||
attributes={SpanAttributes.LLM_COMPLETIONS: json.dumps(completions)},
|
||||
)
|
||||
|
||||
self.set_span_attributes(span, response_attributes)
|
||||
|
||||
def set_usage_attributes(self, span: Span, response_obj):
|
||||
"""
|
||||
This function is used to get span attributes for the LLM usage
|
||||
"""
|
||||
usage = response_obj.get("usage")
|
||||
if usage:
|
||||
usage_attributes = {
|
||||
SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value: usage.get(
|
||||
"prompt_tokens"
|
||||
),
|
||||
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value: usage.get(
|
||||
"completion_tokens"
|
||||
),
|
||||
SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value: usage.get("total_tokens"),
|
||||
}
|
||||
self.set_span_attributes(span, usage_attributes)
|
||||
|
||||
def set_span_attributes(self, span: Span, attributes):
|
||||
"""
|
||||
This function is used to set span attributes
|
||||
"""
|
||||
for key, value in attributes.items():
|
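# skip falsy values; note this also drops legitimate 0 / False attribute values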
||||
if not value:
|
||||
continue
|
||||
span.set_attribute(key, value)
|
||||
@@ -0,0 +1,317 @@
#### What this does ####
# This file contains the LiteralAILogger class which is used to log steps to the LiteralAI observability platform.
import asyncio
import os
import uuid
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload


class LiteralAILogger(CustomBatchLogger):
    def __init__(
        self,
        literalai_api_key=None,
        literalai_api_url="https://cloud.getliteral.ai",
        env=None,
        **kwargs,
    ):
        self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
        self.headers = {
            "Content-Type": "application/json",
            "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
            "x-client-name": "litellm",
        }
        if env:
            self.headers["x-env"] = env
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()
        batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
        self.flush_lock = asyncio.Lock()
        super().__init__(
            **kwargs,
            flush_lock=self.flush_lock,
            batch_size=int(batch_size) if batch_size else None,
        )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            verbose_logger.debug(
                "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging success event."
            )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging failure event."
            )

    def _send_batch(self):
        if not self.log_queue:
            return

        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)
        try:
            response = self.sync_http_handler.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            verbose_logger.debug(
                "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
        if not self.log_queue:
            return

        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)

        try:
            response = await self.async_httpx_client.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
            )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        clean_metadata = logging_payload["metadata"]
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})

        settings = logging_payload["model_parameters"]
        messages = logging_payload["messages"]
        response = logging_payload["response"]
        choices: List = []
        if isinstance(response, dict) and "choices" in response:
            choices = response["choices"]
        message_completion = choices[0]["message"] if choices else None
        prompt_id = None
        variables = None

        if messages and isinstance(messages, list) and isinstance(messages[0], dict):
            for message in messages:
                if literal_prompt := getattr(message, "__literal_prompt__", None):
                    prompt_id = literal_prompt.get("prompt_id")
                    variables = literal_prompt.get("variables")
                    message["uuid"] = literal_prompt.get("uuid")
                    message["templated"] = True

        tools = settings.pop("tools", None)

        step = {
            "id": metadata.get("step_id", str(uuid.uuid4())),
            "error": logging_payload["error_str"],
            "name": kwargs.get("model", ""),
            "threadId": metadata.get("literalai_thread_id", None),
            "parentId": metadata.get("literalai_parent_id", None),
            "rootRunId": metadata.get("literalai_root_run_id", None),
            "input": None,
            "output": None,
            "type": "llm",
            "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
            "startTime": str(start_time),
            "endTime": str(end_time),
            "metadata": clean_metadata,
            "generation": {
                "inputTokenCount": logging_payload["prompt_tokens"],
                "outputTokenCount": logging_payload["completion_tokens"],
                "tokenCount": logging_payload["total_tokens"],
                "promptId": prompt_id,
                "variables": variables,
                "provider": kwargs.get("custom_llm_provider", "litellm"),
                "model": kwargs.get("model", ""),
                "duration": (end_time - start_time).total_seconds(),
                "settings": settings,
                "messages": messages,
                "messageCompletion": message_completion,
                "tools": tools,
            },
        }
        return step

    def _steps_query_variables_builder(self, steps):
        generated = ""
        for id in range(len(steps)):
            generated += f"""$id_{id}: String!
            $threadId_{id}: String
            $rootRunId_{id}: String
            $type_{id}: StepType
            $startTime_{id}: DateTime
            $endTime_{id}: DateTime
            $error_{id}: String
            $input_{id}: Json
            $output_{id}: Json
            $metadata_{id}: Json
            $parentId_{id}: String
            $name_{id}: String
            $tags_{id}: [String!]
            $generation_{id}: GenerationPayloadInput
            $scores_{id}: [ScorePayloadInput!]
            $attachments_{id}: [AttachmentPayloadInput!]
            """
        return generated

    def _steps_ingest_steps_builder(self, steps):
        generated = ""
        for id in range(len(steps)):
            generated += f"""
            step{id}: ingestStep(
                id: $id_{id}
                threadId: $threadId_{id}
                rootRunId: $rootRunId_{id}
                startTime: $startTime_{id}
                endTime: $endTime_{id}
                type: $type_{id}
                error: $error_{id}
                input: $input_{id}
                output: $output_{id}
                metadata: $metadata_{id}
                parentId: $parentId_{id}
                name: $name_{id}
                tags: $tags_{id}
                generation: $generation_{id}
                scores: $scores_{id}
                attachments: $attachments_{id}
            ) {{
                ok
                message
            }}
            """
        return generated

    def _steps_query_builder(self, steps):
        return f"""
        mutation AddStep({self._steps_query_variables_builder(steps)}) {{
            {self._steps_ingest_steps_builder(steps)}
        }}
        """

    def _steps_variables_builder(self, steps):
        def serialize_step(event, id):
            result = {}

            for key, value in event.items():
                # Only keep the keys that are not None to avoid overriding existing values
                if value is not None:
                    result[f"{key}_{id}"] = value

            return result

        variables = {}
        for i in range(len(steps)):
            step = steps[i]
            variables.update(serialize_step(step, i))
        return variables
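# A standalone sketch (not part of the diff) of how a batch of two steps is
# turned into a single GraphQL mutation, mirroring the
# _steps_query_builder/_steps_variables_builder logic above without importing
# litellm. The step fields shown are placeholders; the real builders emit the
# full set of ingestStep arguments.
steps = [
    {"id": "step-0", "name": "gpt-4o", "type": "llm"},
    {"id": "step-1", "name": "claude-3-5-sonnet", "type": "llm"},
]

variable_defs = "".join(
    f"$id_{i}: String! $name_{i}: String $type_{i}: StepType\n"
    for i in range(len(steps))
)
ingest_calls = "".join(
    f"step{i}: ingestStep(id: $id_{i}, name: $name_{i}, type: $type_{i}) {{ ok message }}\n"
    for i in range(len(steps))
)
query = f"mutation AddStep({variable_defs}) {{\n{ingest_calls}}}"

# Each step's fields are flattened into suffixed variables: id_0, name_0, id_1, ...
variables = {f"{k}_{i}": v for i, step in enumerate(steps) for k, v in step.items()}

print(query)
print(variables)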
@@ -0,0 +1,179 @@
#### What this does ####
# On success + failure, log events to Logfire

import os
import traceback
import uuid
from enum import Enum
from typing import Any, Dict, NamedTuple

from typing_extensions import LiteralString

from litellm._logging import print_verbose, verbose_logger
from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info


class SpanConfig(NamedTuple):
    message_template: LiteralString
    span_data: Dict[str, Any]


class LogfireLevel(str, Enum):
    INFO = "info"
    ERROR = "error"


class LogfireLogger:
    # Class variables or attributes
    def __init__(self):
        try:
            verbose_logger.debug("in init logfire logger")
            import logfire

            # only setting up logfire if we are sending to logfire
            # in testing, we don't want to send to logfire
            if logfire.DEFAULT_LOGFIRE_INSTANCE.config.send_to_logfire:
                logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
        except Exception as e:
            print_verbose(f"Got exception on init logfire client {str(e)}")
            raise e

    def _get_span_config(self, payload) -> SpanConfig:
        if (
            payload["call_type"] == "completion"
            or payload["call_type"] == "acompletion"
        ):
            return SpanConfig(
                message_template="Chat Completion with {request_data[model]!r}",
                span_data={"request_data": payload},
            )
        elif (
            payload["call_type"] == "embedding" or payload["call_type"] == "aembedding"
        ):
            return SpanConfig(
                message_template="Embedding Creation with {request_data[model]!r}",
                span_data={"request_data": payload},
            )
        elif (
            payload["call_type"] == "image_generation"
            or payload["call_type"] == "aimage_generation"
        ):
            return SpanConfig(
                message_template="Image Generation with {request_data[model]!r}",
                span_data={"request_data": payload},
            )
        else:
            return SpanConfig(
                message_template="Litellm Call with {request_data[model]!r}",
                span_data={"request_data": payload},
            )

    async def _async_log_event(
        self,
        kwargs,
        response_obj,
        start_time,
        end_time,
        print_verbose,
        level: LogfireLevel,
    ):
        self.log_event(
            kwargs=kwargs,
            response_obj=response_obj,
            start_time=start_time,
            end_time=end_time,
            print_verbose=print_verbose,
            level=level,
        )

    def log_event(
        self,
        kwargs,
        start_time,
        end_time,
        print_verbose,
        level: LogfireLevel,
        response_obj,
    ):
        try:
            import logfire

            verbose_logger.debug(
                f"logfire Logging - Enters logging function for model {kwargs}"
            )

            if not response_obj:
                response_obj = {}
            litellm_params = kwargs.get("litellm_params", {})
            metadata = (
                litellm_params.get("metadata", {}) or {}
            )  # if litellm_params['metadata'] == None
            messages = kwargs.get("messages")
            optional_params = kwargs.get("optional_params", {})
            call_type = kwargs.get("call_type", "completion")
            cache_hit = kwargs.get("cache_hit", False)
            usage = response_obj.get("usage", {})
            id = response_obj.get("id", str(uuid.uuid4()))
            try:
                response_time = (end_time - start_time).total_seconds()
            except Exception:
                response_time = None

            # Clean Metadata before logging - never log raw metadata
            # the raw metadata can contain circular references which leads to infinite recursion
            # we clean out all extra litellm metadata params before logging
            clean_metadata = {}
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # clean litellm metadata before logging
                    if key in [
                        "endpoint",
                        "caching_groups",
                        "previous_models",
                    ]:
                        continue
                    else:
                        clean_metadata[key] = value

            clean_metadata = redact_user_api_key_info(metadata=clean_metadata)

            # Build the initial payload
            payload = {
                "id": id,
                "call_type": call_type,
                "cache_hit": cache_hit,
                "startTime": start_time,
                "endTime": end_time,
                "responseTime (seconds)": response_time,
                "model": kwargs.get("model", ""),
                "user": kwargs.get("user", ""),
                "modelParameters": optional_params,
                "spend": kwargs.get("response_cost", 0),
                "messages": messages,
                "response": response_obj,
                "usage": usage,
                "metadata": clean_metadata,
            }
            logfire_openai = logfire.with_settings(custom_scope_suffix="openai")
            message_template, span_data = self._get_span_config(payload)
            if level == LogfireLevel.INFO:
                logfire_openai.info(
                    message_template,
                    **span_data,
                )
            elif level == LogfireLevel.ERROR:
                logfire_openai.error(
                    message_template,
                    **span_data,
                    _exc_info=True,
                )
            print_verbose(f"\nLogfire Logger - Logging payload = {payload}")

            print_verbose(
                f"Logfire Layer Logging - final response object: {response_obj}"
            )
        except Exception as e:
            verbose_logger.debug(
                f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass
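# A minimal usage sketch (not part of the diff): wiring this logger up through
# litellm's string-named callbacks. Assumes LOGFIRE_TOKEN is exported and that
# your litellm version registers this integration under the callback name
# "logfire" -- check the litellm callback docs if the name differs.
import litellm

litellm.success_callback = ["logfire"]  # log successful calls
litellm.failure_callback = ["logfire"]  # log failed calls

response = litellm.completion(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Hello from the Logfire example"}],
)
print(response.choices[0].message.content)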
@@ -0,0 +1,179 @@
#### What this does ####
# On success + failure, log events to lunary.ai
import importlib.metadata
import traceback
from datetime import datetime, timezone

import packaging.version


# convert to {completion: xx, prompt: xx}
def parse_usage(usage):
    return {
        "completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
        "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
    }


def parse_tool_calls(tool_calls):
    if tool_calls is None:
        return None

    def clean_tool_call(tool_call):
        serialized = {
            "type": tool_call.type,
            "id": tool_call.id,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

        return serialized

    return [clean_tool_call(tool_call) for tool_call in tool_calls]


def parse_messages(input):
    if input is None:
        return None

    def clean_message(message):
        # if the message is a plain string, return it as is
        if isinstance(message, str):
            return message

        if "message" in message:
            return clean_message(message["message"])

        serialized = {
            "role": message.get("role"),
            "content": message.get("content"),
        }

        # Only add tool_calls to the result if they are set
        if message.get("tool_calls"):
            serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))

        return serialized

    if isinstance(input, list):
        if len(input) == 1:
            return clean_message(input[0])
        else:
            return [clean_message(msg) for msg in input]
    else:
        return clean_message(input)


class LunaryLogger:
    # Class variables or attributes
    def __init__(self):
        try:
            import lunary

            version = importlib.metadata.version("lunary")  # type: ignore
            # if version < 0.1.43 then raise ImportError
            if packaging.version.Version(version) < packaging.version.Version("0.1.43"):  # type: ignore
                print(  # noqa
                    "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
                )
                raise ImportError

            self.lunary_client = lunary
        except ImportError:
            print(  # noqa
                "Lunary not installed. Please install it using 'pip install lunary'"
            )
            raise ImportError

    def log_event(
        self,
        kwargs,
        type,
        event,
        run_id,
        model,
        print_verbose,
        extra={},
        input=None,
        user_id=None,
        response_obj=None,
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        error=None,
    ):
        try:
            print_verbose(f"Lunary Logging - Logging request for model {model}")

            template_id = None
            litellm_params = kwargs.get("litellm_params", {})
            optional_params = kwargs.get("optional_params", {})
            metadata = litellm_params.get("metadata", {}) or {}

            if optional_params:
                extra = {**extra, **optional_params}

            tags = metadata.get("tags", None)

            if extra:
                extra.pop("extra_body", None)
                extra.pop("user", None)
                template_id = extra.pop("extra_headers", {}).get("Template-Id", None)

            # keep only serializable types
            for param, value in extra.items():
                if not isinstance(value, (str, int, bool, float)) and param != "tools":
                    try:
                        extra[param] = str(value)
                    except Exception:
                        pass

            if response_obj:
                usage = (
                    parse_usage(response_obj["usage"])
                    if "usage" in response_obj
                    else None
                )

                output = response_obj["choices"] if "choices" in response_obj else None

            else:
                usage = None
                output = None

            if error:
                error_obj = {"stack": error}
            else:
                error_obj = None

            self.lunary_client.track_event(  # type: ignore
                type,
                "start",
                run_id,
                parent_run_id=metadata.get("parent_run_id", None),
                user_id=user_id,
                name=model,
                input=parse_messages(input),
                timestamp=start_time.astimezone(timezone.utc).isoformat(),
                template_id=template_id,
                metadata=metadata,
                runtime="litellm",
                tags=tags,
                params=extra,
            )

            self.lunary_client.track_event(  # type: ignore
                type,
                event,
                run_id,
                timestamp=end_time.astimezone(timezone.utc).isoformat(),
                runtime="litellm",
                error=error_obj,
                output=parse_messages(output),
                token_usage=usage,
            )

        except Exception:
            print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
            pass
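# Quick illustration (not part of the diff) of the parsing helpers above,
# assuming parse_usage and parse_messages are in scope; all data is made up.
example_usage = {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}
print(parse_usage(example_usage))
# -> {'completion': 30, 'prompt': 12}

example_messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Summarize litellm in one line."},
]
print(parse_messages(example_messages))
# -> [{'role': 'system', 'content': 'You are terse.'},
#     {'role': 'user', 'content': 'Summarize litellm in one line.'}]

# A single-element list is unwrapped to a single cleaned message:
print(parse_messages([{"role": "user", "content": "hi"}]))
# -> {'role': 'user', 'content': 'hi'}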
@@ -0,0 +1,272 @@
import json
import threading
from typing import Optional

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger


class MlflowLogger(CustomLogger):
    def __init__(self):
        from mlflow.tracking import MlflowClient

        self._client = MlflowClient()

        self._stream_id_to_span = {}
        self._lock = threading.Lock()  # lock for _stream_id_to_span

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Log the success event as an MLflow span.
        Note that this method is called asynchronously in the background thread.
        """
        from mlflow.entities import SpanStatusCode

        try:
            verbose_logger.debug("MLflow logging start for success event")

            if kwargs.get("stream"):
                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
            else:
                span = self._start_span_or_trace(kwargs, start_time)
                end_time_ns = int(end_time.timestamp() * 1e9)
                self._extract_and_set_chat_attributes(span, kwargs, response_obj)
                self._end_span_or_trace(
                    span=span,
                    outputs=response_obj,
                    status=SpanStatusCode.OK,
                    end_time_ns=end_time_ns,
                )
        except Exception:
            verbose_logger.debug("MLflow Logging Error", stack_info=True)

    def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
        try:
            from mlflow.tracing.utils import set_span_chat_messages  # type: ignore
            from mlflow.tracing.utils import set_span_chat_tools  # type: ignore
        except ImportError:
            return

        inputs = self._construct_input(kwargs)
        input_messages = inputs.get("messages", [])
        output_messages = [
            c.message.model_dump(exclude_none=True)
            for c in getattr(response_obj, "choices", [])
        ]
        if messages := [*input_messages, *output_messages]:
            set_span_chat_messages(span, messages)
        if tools := inputs.get("tools"):
            set_span_chat_tools(span, tools)

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Log the failure event as an MLflow span.
        Note that this method is called *synchronously* unlike the success handler.
        """
        from mlflow.entities import SpanEvent, SpanStatusCode

        try:
            span = self._start_span_or_trace(kwargs, start_time)

            end_time_ns = int(end_time.timestamp() * 1e9)

            # Record exception info as event
            if exception := kwargs.get("exception"):
                span.add_event(SpanEvent.from_exception(exception))  # type: ignore

            self._extract_and_set_chat_attributes(span, kwargs, response_obj)
            self._end_span_or_trace(
                span=span,
                outputs=response_obj,
                status=SpanStatusCode.ERROR,
                end_time_ns=end_time_ns,
            )

        except Exception as e:
            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)

    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
        """
        Handle the success event for a streaming response. For streaming calls,
        the log_success_event handler is triggered for every chunk of the stream.
        We create a single span for the entire stream request as follows:

        1. For the first chunk, start a new span and store it in the map.
        2. For subsequent chunks, add the chunk as an event to the span.
        3. For the final chunk, end the span and remove the span from the map.
        """
        from mlflow.entities import SpanStatusCode

        litellm_call_id = kwargs.get("litellm_call_id")

        if litellm_call_id not in self._stream_id_to_span:
            with self._lock:
                # Check again after acquiring lock
                if litellm_call_id not in self._stream_id_to_span:
                    # Start a new span for the first chunk of the stream
                    span = self._start_span_or_trace(kwargs, start_time)
                    self._stream_id_to_span[litellm_call_id] = span

        # Add chunk as event to the span
        span = self._stream_id_to_span[litellm_call_id]
        self._add_chunk_events(span, response_obj)

        # If this is the final chunk, end the span. The final chunk
        # has complete_streaming_response that gathers the full response.
        if final_response := kwargs.get("complete_streaming_response"):
            end_time_ns = int(end_time.timestamp() * 1e9)

            self._extract_and_set_chat_attributes(span, kwargs, final_response)
            self._end_span_or_trace(
                span=span,
                outputs=final_response,
                status=SpanStatusCode.OK,
                end_time_ns=end_time_ns,
            )

            # Remove the stream_id from the map
            with self._lock:
                self._stream_id_to_span.pop(litellm_call_id)

    def _add_chunk_events(self, span, response_obj):
        from mlflow.entities import SpanEvent

        try:
            for choice in response_obj.choices:
                span.add_event(
                    SpanEvent(
                        name="streaming_chunk",
                        attributes={"delta": json.dumps(choice.delta.model_dump())},
                    )
                )
        except Exception:
            verbose_logger.debug("Error adding chunk events to span", stack_info=True)

    def _construct_input(self, kwargs):
        """Construct span inputs with optional parameters"""
        inputs = {"messages": kwargs.get("messages")}
        if tools := kwargs.get("tools"):
            inputs["tools"] = tools

        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
            if value := kwargs.get("optional_params", {}).pop(key, None):
                inputs[key] = value
        return inputs

    def _extract_attributes(self, kwargs):
        """
        Extract span attributes from kwargs.

        With the latest version of litellm, the standard_logging_object contains
        canonical information for logging. If it is not present, we extract a
        subset of attributes from other kwargs.
        """
        attributes = {
            "litellm_call_id": kwargs.get("litellm_call_id"),
            "call_type": kwargs.get("call_type"),
            "model": kwargs.get("model"),
        }
        standard_obj = kwargs.get("standard_logging_object")
        if standard_obj:
            attributes.update(
                {
                    "api_base": standard_obj.get("api_base"),
                    "cache_hit": standard_obj.get("cache_hit"),
                    "usage": {
                        "completion_tokens": standard_obj.get("completion_tokens"),
                        "prompt_tokens": standard_obj.get("prompt_tokens"),
                        "total_tokens": standard_obj.get("total_tokens"),
                    },
                    "raw_llm_response": standard_obj.get("response"),
                    "response_cost": standard_obj.get("response_cost"),
                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
                }
            )
        else:
            litellm_params = kwargs.get("litellm_params", {})
            attributes.update(
                {
                    "model": kwargs.get("model"),
                    "cache_hit": kwargs.get("cache_hit"),
                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
                    "api_base": litellm_params.get("api_base"),
                    "response_cost": kwargs.get("response_cost"),
                }
            )
        return attributes

    def _get_span_type(self, call_type: Optional[str]) -> str:
        from mlflow.entities import SpanType

        if call_type in ["completion", "acompletion"]:
            return SpanType.LLM
        elif call_type == "embeddings":
            return SpanType.EMBEDDING
        else:
            return SpanType.LLM

    def _start_span_or_trace(self, kwargs, start_time):
        """
        Start an MLflow span or a trace.

        If there is an active span, we start a new span as a child of
        that span. Otherwise, we start a new trace.
        """
        import mlflow

        call_type = kwargs.get("call_type", "completion")
        span_name = f"litellm-{call_type}"
        span_type = self._get_span_type(call_type)
        start_time_ns = int(start_time.timestamp() * 1e9)

        inputs = self._construct_input(kwargs)
        attributes = self._extract_attributes(kwargs)

        if active_span := mlflow.get_current_active_span():  # type: ignore
            return self._client.start_span(
                name=span_name,
                request_id=active_span.request_id,
                parent_id=active_span.span_id,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
        else:
            return self._client.start_trace(
                name=span_name,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )

    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
        """End an MLflow span or a trace."""
        if span.parent_id is None:
            self._client.end_trace(
                request_id=span.request_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
        else:
            self._client.end_span(
                request_id=span.request_id,
                span_id=span.span_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
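# A usage sketch (not part of the diff) showing the parent-span behavior of
# _start_span_or_trace. Assumes an mlflow version with the tracing API
# (mlflow.start_span / get_current_active_span) and that your litellm version
# registers this logger under the callback name "mlflow"; adjust if it differs.
import litellm
import mlflow

litellm.success_callback = ["mlflow"]
litellm.failure_callback = ["mlflow"]

# Because _start_span_or_trace checks mlflow.get_current_active_span(),
# wrapping the call in an application-level span makes the litellm span a
# child of it; without the wrapper, a new standalone trace is started.
with mlflow.start_span(name="answer-user-question"):
    litellm.completion(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "What is MLflow tracing?"}],
    )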
@@ -0,0 +1,132 @@
# What is this?
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268

import json
import os

import httpx

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)


def get_utc_datetime():
    import datetime as dt
    from datetime import datetime

    if hasattr(dt, "UTC"):
        return datetime.now(dt.UTC)  # type: ignore
    else:
        return datetime.utcnow()  # type: ignore


class OpenMeterLogger(CustomLogger):
    def __init__(self) -> None:
        super().__init__()
        self.validate_environment()
        self.async_http_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()

    def validate_environment(self):
        """
        Expects OPENMETER_API_KEY in the environment.
        OPENMETER_API_ENDPOINT is optional and defaults to https://openmeter.cloud.
        """
        missing_keys = []
        if os.getenv("OPENMETER_API_KEY", None) is None:
            missing_keys.append("OPENMETER_API_KEY")

        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def _common_logic(self, kwargs: dict, response_obj):
        call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
        dt = get_utc_datetime().isoformat()
        cost = kwargs.get("response_cost", None)
        model = kwargs.get("model")
        usage = {}
        if (
            isinstance(response_obj, litellm.ModelResponse)
            or isinstance(response_obj, litellm.EmbeddingResponse)
        ) and hasattr(response_obj, "usage"):
            usage = {
                "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
                "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
                "total_tokens": response_obj["usage"].get("total_tokens"),
            }

        subject = kwargs.get("user", None)  # end-user passed in via 'user' param
        if not subject:
            raise Exception("OpenMeter: user is required")

        return {
            "specversion": "1.0",
            "type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
            "id": call_id,
            "time": dt,
            "subject": subject,
            "source": "litellm-proxy",
            "data": {"model": model, "cost": cost, **usage},
        }

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
        if _url.endswith("/"):
            _url += "api/v1/events"
        else:
            _url += "/api/v1/events"

        api_key = os.getenv("OPENMETER_API_KEY")

        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
        _headers = {
            "Content-Type": "application/cloudevents+json",
            "Authorization": "Bearer {}".format(api_key),
        }

        try:
            self.sync_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )
        except httpx.HTTPStatusError as e:
            raise Exception(f"OpenMeter logging error: {e.response.text}")
        except Exception as e:
            raise e

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
        if _url.endswith("/"):
            _url += "api/v1/events"
        else:
            _url += "/api/v1/events"

        api_key = os.getenv("OPENMETER_API_KEY")

        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
        _headers = {
            "Content-Type": "application/cloudevents+json",
            "Authorization": "Bearer {}".format(api_key),
        }

        try:
            await self.async_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )
        except httpx.HTTPStatusError as e:
            raise Exception(f"OpenMeter logging error: {e.response.text}")
        except Exception as e:
            raise e
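# Illustrative CloudEvents body as assembled by _common_logic above (not part
# of the diff); the id, subject, token counts, and cost are made-up values.
example_event = {
    "specversion": "1.0",
    "type": "litellm_tokens",            # overridable via OPENMETER_EVENT_TYPE
    "id": "chatcmpl-123",                # response id, falling back to litellm_call_id
    "time": "2025-01-01T00:00:00+00:00",
    "subject": "customer-42",            # the 'user' param on the request
    "source": "litellm-proxy",
    "data": {
        "model": "gpt-4o-mini",
        "cost": 0.00042,
        "prompt_tokens": 12,
        "completion_tokens": 30,
        "total_tokens": 42,
    },
}
# POSTed to {OPENMETER_API_ENDPOINT}/api/v1/events with
# Content-Type: application/cloudevents+json and a Bearer API key.
print(example_event)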
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff