structure saas with tools
This commit is contained in:
@@ -0,0 +1,179 @@
|
||||
#### What this does ####
|
||||
# On success + failure, log events to lunary.ai
|
||||
import importlib
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import packaging
|
||||
|
||||
|
||||
# convert to {completion: xx, tokens: xx}
|
||||
def parse_usage(usage):
|
||||
return {
|
||||
"completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
|
||||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
}
|
||||
|
||||
|
||||
def parse_tool_calls(tool_calls):
|
||||
if tool_calls is None:
|
||||
return None
|
||||
|
||||
def clean_tool_call(tool_call):
|
||||
serialized = {
|
||||
"type": tool_call.type,
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
|
||||
return serialized
|
||||
|
||||
return [clean_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
def parse_messages(input):
|
||||
if input is None:
|
||||
return None
|
||||
|
||||
def clean_message(message):
|
||||
# if is string, return as is
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
if "message" in message:
|
||||
return clean_message(message["message"])
|
||||
|
||||
serialized = {
|
||||
"role": message.get("role"),
|
||||
"content": message.get("content"),
|
||||
}
|
||||
|
||||
# Only add tool_calls and function_call to res if they are set
|
||||
if message.get("tool_calls"):
|
||||
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
|
||||
|
||||
return serialized
|
||||
|
||||
if isinstance(input, list):
|
||||
if len(input) == 1:
|
||||
return clean_message(input[0])
|
||||
else:
|
||||
return [clean_message(msg) for msg in input]
|
||||
else:
|
||||
return clean_message(input)
|
||||
|
||||
|
||||
class LunaryLogger:
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
try:
|
||||
import lunary
|
||||
|
||||
version = importlib.metadata.version("lunary") # type: ignore
|
||||
# if version < 0.1.43 then raise ImportError
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore
|
||||
print( # noqa
|
||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||
)
|
||||
raise ImportError
|
||||
|
||||
self.lunary_client = lunary
|
||||
except ImportError:
|
||||
print( # noqa
|
||||
"Lunary not installed. Please install it using 'pip install lunary'"
|
||||
) # noqa
|
||||
raise ImportError
|
||||
|
||||
def log_event(
|
||||
self,
|
||||
kwargs,
|
||||
type,
|
||||
event,
|
||||
run_id,
|
||||
model,
|
||||
print_verbose,
|
||||
extra={},
|
||||
input=None,
|
||||
user_id=None,
|
||||
response_obj=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
error=None,
|
||||
):
|
||||
try:
|
||||
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
||||
|
||||
template_id = None
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
|
||||
if optional_params:
|
||||
extra = {**extra, **optional_params}
|
||||
|
||||
tags = metadata.get("tags", None)
|
||||
|
||||
if extra:
|
||||
extra.pop("extra_body", None)
|
||||
extra.pop("user", None)
|
||||
template_id = extra.pop("extra_headers", {}).get("Template-Id", None)
|
||||
|
||||
# keep only serializable types
|
||||
for param, value in extra.items():
|
||||
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||
try:
|
||||
extra[param] = str(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if response_obj:
|
||||
usage = (
|
||||
parse_usage(response_obj["usage"])
|
||||
if "usage" in response_obj
|
||||
else None
|
||||
)
|
||||
|
||||
output = response_obj["choices"] if "choices" in response_obj else None
|
||||
|
||||
else:
|
||||
usage = None
|
||||
output = None
|
||||
|
||||
if error:
|
||||
error_obj = {"stack": error}
|
||||
else:
|
||||
error_obj = None
|
||||
|
||||
self.lunary_client.track_event( # type: ignore
|
||||
type,
|
||||
"start",
|
||||
run_id,
|
||||
parent_run_id=metadata.get("parent_run_id", None),
|
||||
user_id=user_id,
|
||||
name=model,
|
||||
input=parse_messages(input),
|
||||
timestamp=start_time.astimezone(timezone.utc).isoformat(),
|
||||
template_id=template_id,
|
||||
metadata=metadata,
|
||||
runtime="litellm",
|
||||
tags=tags,
|
||||
params=extra,
|
||||
)
|
||||
|
||||
self.lunary_client.track_event( # type: ignore
|
||||
type,
|
||||
event,
|
||||
run_id,
|
||||
timestamp=end_time.astimezone(timezone.utc).isoformat(),
|
||||
runtime="litellm",
|
||||
error=error_obj,
|
||||
output=parse_messages(output),
|
||||
token_usage=usage,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
Reference in New Issue
Block a user