structure saas with tools
@@ -0,0 +1,141 @@ _providers/__init__.py
from typing import Dict, Literal

from ._common import TaskProviderHelper
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cerebras import CerebrasConversationalTask
from .cohere import CohereConversationalTask
from .fal_ai import (
    FalAIAutomaticSpeechRecognitionTask,
    FalAITextToImageTask,
    FalAITextToSpeechTask,
    FalAITextToVideoTask,
)
from .fireworks_ai import FireworksAIConversationalTask
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .openai import OpenAIConversationalTask
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask


PROVIDER_T = Literal[
    "black-forest-labs",
    "cerebras",
    "cohere",
    "fal-ai",
    "fireworks-ai",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "openai",
    "replicate",
    "sambanova",
    "together",
]

PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
    "black-forest-labs": {
        "text-to-image": BlackForestLabsTextToImageTask(),
    },
    "cerebras": {
        "conversational": CerebrasConversationalTask(),
    },
    "cohere": {
        "conversational": CohereConversationalTask(),
    },
    "fal-ai": {
        "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
        "text-to-image": FalAITextToImageTask(),
        "text-to-speech": FalAITextToSpeechTask(),
        "text-to-video": FalAITextToVideoTask(),
    },
    "fireworks-ai": {
        "conversational": FireworksAIConversationalTask(),
    },
    "hf-inference": {
        "text-to-image": HFInferenceTask("text-to-image"),
        "conversational": HFInferenceConversational(),
        "text-generation": HFInferenceTask("text-generation"),
        "text-classification": HFInferenceTask("text-classification"),
        "question-answering": HFInferenceTask("question-answering"),
        "audio-classification": HFInferenceBinaryInputTask("audio-classification"),
        "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
        "fill-mask": HFInferenceTask("fill-mask"),
        "feature-extraction": HFInferenceTask("feature-extraction"),
        "image-classification": HFInferenceBinaryInputTask("image-classification"),
        "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
        "document-question-answering": HFInferenceTask("document-question-answering"),
        "image-to-text": HFInferenceBinaryInputTask("image-to-text"),
        "object-detection": HFInferenceBinaryInputTask("object-detection"),
        "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
        "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
        "zero-shot-classification": HFInferenceTask("zero-shot-classification"),
        "image-to-image": HFInferenceBinaryInputTask("image-to-image"),
        "sentence-similarity": HFInferenceTask("sentence-similarity"),
        "table-question-answering": HFInferenceTask("table-question-answering"),
        "tabular-classification": HFInferenceTask("tabular-classification"),
        "text-to-speech": HFInferenceTask("text-to-speech"),
        "token-classification": HFInferenceTask("token-classification"),
        "translation": HFInferenceTask("translation"),
        "summarization": HFInferenceTask("summarization"),
        "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
    },
    "hyperbolic": {
        "text-to-image": HyperbolicTextToImageTask(),
        "conversational": HyperbolicTextGenerationTask("conversational"),
        "text-generation": HyperbolicTextGenerationTask("text-generation"),
    },
    "nebius": {
        "text-to-image": NebiusTextToImageTask(),
        "conversational": NebiusConversationalTask(),
        "text-generation": NebiusTextGenerationTask(),
    },
    "novita": {
        "text-generation": NovitaTextGenerationTask(),
        "conversational": NovitaConversationalTask(),
        "text-to-video": NovitaTextToVideoTask(),
    },
    "openai": {
        "conversational": OpenAIConversationalTask(),
    },
    "replicate": {
        "text-to-image": ReplicateTask("text-to-image"),
        "text-to-speech": ReplicateTextToSpeechTask(),
        "text-to-video": ReplicateTask("text-to-video"),
    },
    "sambanova": {
        "conversational": SambanovaConversationalTask(),
    },
    "together": {
        "text-to-image": TogetherTextToImageTask(),
        "conversational": TogetherConversationalTask(),
        "text-generation": TogetherTextGenerationTask(),
    },
}


def get_provider_helper(provider: PROVIDER_T, task: str) -> TaskProviderHelper:
    """Get provider helper instance by name and task.

    Args:
        provider (str): Name of the provider
        task (str): Name of the task

    Returns:
        TaskProviderHelper: Helper instance for the specified provider and task

    Raises:
        ValueError: If provider or task is not supported
    """
    if provider not in PROVIDERS:
        raise ValueError(f"Provider '{provider}' not supported. Available providers: {list(PROVIDERS.keys())}")
    if task not in PROVIDERS[provider]:
        raise ValueError(
            f"Task '{task}' not supported for provider '{provider}'. "
            f"Available tasks: {list(PROVIDERS[provider].keys())}"
        )
    return PROVIDERS[provider][task]
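For context (not part of the commit): a minimal usage sketch of the registry above. get_provider_helper resolves the helper, and prepare_request (defined in _common.py below) builds the outgoing request. The model ID and token are placeholders, and the model-mapping step performs a Hub lookup at runtime.

from huggingface_hub.inference._providers import get_provider_helper

# Resolve the helper registered for Together's conversational task
helper = get_provider_helper("together", "conversational")

# Build the request; the model ID and token below are placeholders only
request = helper.prepare_request(
    inputs=[{"role": "user", "content": "Hello!"}],
    parameters={"max_tokens": 64},
    headers={},
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical HF model ID
    api_key="hf_xxx",                          # placeholder token
)
print(request.url, request.json)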
15 binary files not shown.
@@ -0,0 +1,245 @@ _providers/_common.py
from functools import lru_cache
from typing import Any, Dict, Optional, Union

from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters
from huggingface_hub.utils import build_hf_headers, get_token, logging


logger = logging.get_logger(__name__)


# Dev purposes only.
# If you want to try to run inference for a new model locally before it's registered on huggingface.co
# for a given Inference Provider, you can add it to the following dictionary.
HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
    # "HF model ID" => "Model ID on Inference Provider's side"
    #
    # Example:
    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
    "cerebras": {},
    "cohere": {},
    "fal-ai": {},
    "fireworks-ai": {},
    "hf-inference": {},
    "hyperbolic": {},
    "nebius": {},
    "replicate": {},
    "sambanova": {},
    "together": {},
}


def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
    return {k: v for k, v in d.items() if v is not None}


class TaskProviderHelper:
    """Base class for task-specific provider helpers."""

    def __init__(self, provider: str, base_url: str, task: str) -> None:
        self.provider = provider
        self.task = task
        self.base_url = base_url

    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters:
        """
        Prepare the request to be sent to the provider.

        Each step (api_key, model, headers, url, payload) can be customized in subclasses.
        """
        # api_key from user, or local token, or raise error
        api_key = self._prepare_api_key(api_key)

        # mapped model from HF model ID
        mapped_model = self._prepare_mapped_model(model)

        # default HF headers + user headers (to customize in subclasses)
        headers = self._prepare_headers(headers, api_key)

        # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
        url = self._prepare_url(api_key, mapped_model)

        # prepare payload (to customize in subclasses)
        payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
        if payload is not None:
            payload = recursive_merge(payload, extra_payload or {})

        # body data (to customize in subclasses)
        data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)

        # check if both payload and data are set and return
        if payload is not None and data is not None:
            raise ValueError("Both payload and data cannot be set in the same request.")
        if payload is None and data is None:
            raise ValueError("Either payload or data must be set in the request.")
        return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """
        Return the response in the expected format.

        Override this method in subclasses for customized response handling."""
        return response

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        """Return the API key to use for the request.

        Usually not overwritten in subclasses."""
        if api_key is None:
            api_key = get_token()
        if api_key is None:
            raise ValueError(
                f"You must provide an api_key to work with {self.provider} API or log in with `huggingface-cli login`."
            )
        return api_key

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        """Return the mapped model ID to use for the request.

        Usually not overwritten in subclasses."""
        if model is None:
            raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")

        # hardcoded mapping for local testing
        if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
            return HARDCODED_MODEL_ID_MAPPING[self.provider][model]

        provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
        if provider_mapping is None:
            raise ValueError(f"Model {model} is not supported by provider {self.provider}.")

        if provider_mapping.task != self.task:
            raise ValueError(
                f"Model {model} is not supported for task {self.task} and provider {self.provider}. "
                f"Supported task: {provider_mapping.task}."
            )

        if provider_mapping.status == "staging":
            logger.warning(
                f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
            )
        return provider_mapping.provider_id

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Return the headers to use for the request.

        Override this method in subclasses for customized headers.
        """
        return {**build_hf_headers(token=api_key), **headers}

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        """Return the URL to use for the request.

        Usually not overwritten in subclasses."""
        base_url = self._prepare_base_url(api_key)
        route = self._prepare_route(mapped_model, api_key)
        return f"{base_url.rstrip('/')}/{route.lstrip('/')}"

    def _prepare_base_url(self, api_key: str) -> str:
        """Return the base URL to use for the request.

        Usually not overwritten in subclasses."""
        # Route to the proxy if the api_key is a HF TOKEN
        if api_key.startswith("hf_"):
            logger.info(f"Calling '{self.provider}' provider through Hugging Face router.")
            return constants.INFERENCE_PROXY_TEMPLATE.format(provider=self.provider)
        else:
            logger.info(f"Calling '{self.provider}' provider directly.")
            return self.base_url

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the route to use for the request.

        Override this method in subclasses for customized routes.
        """
        return ""

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Return the payload to use for the request, as a dict.

        Override this method in subclasses for customized payloads.
        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
        return None

    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        """Return the body to use for the request, as bytes.

        Override this method in subclasses for customized body data.
        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
        return None


class BaseConversationalTask(TaskProviderHelper):
    """
    Base class for conversational (chat completion) tasks.
    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat
    """

    def __init__(self, provider: str, base_url: str):
        super().__init__(provider=provider, base_url=base_url, task="conversational")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/chat/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}


class BaseTextGenerationTask(TaskProviderHelper):
    """
    Base class for text-generation (completion) tasks.
    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions
    """

    def __init__(self, provider: str, base_url: str):
        super().__init__(provider=provider, base_url=base_url, task="text-generation")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}


@lru_cache(maxsize=None)
def _fetch_inference_provider_mapping(model: str) -> Dict:
    """
    Fetch provider mappings for a model from the Hub.
    """
    from huggingface_hub.hf_api import HfApi

    info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
    provider_mapping = info.inference_provider_mapping
    if provider_mapping is None:
        raise ValueError(f"No provider mapping found for model {model}")
    return provider_mapping


def recursive_merge(dict1: Dict, dict2: Dict) -> Dict:
    return {
        **dict1,
        **{
            key: recursive_merge(dict1[key], value)
            if (key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict))
            else value
            for key, value in dict2.items()
        },
    }
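For context (not part of the commit): as the prepare_request docstring notes, each step can be customized in subclasses. A minimal sketch of a hypothetical provider helper that only overrides the route and JSON payload; the provider name and base URL are made up for illustration.

from typing import Any, Dict, Optional

from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none


class ExampleTextToImageTask(TaskProviderHelper):
    """Hypothetical helper: only the route and payload are customized."""

    def __init__(self):
        super().__init__(provider="example", base_url="https://api.example.com", task="text-to-image")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/images/generations"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}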
@@ -0,0 +1,66 @@ _providers/black_forest_labs.py
import time
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import logging
from huggingface_hub.utils._http import get_session


logger = logging.get_logger(__name__)

MAX_POLLING_ATTEMPTS = 6
POLLING_INTERVAL = 1.0


class BlackForestLabsTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            _ = headers.pop("authorization")
            headers["X-Key"] = api_key
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return f"/v1/{mapped_model}"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "num_inference_steps" in parameters:
            parameters["steps"] = parameters.pop("num_inference_steps")
        if "guidance_scale" in parameters:
            parameters["guidance"] = parameters.pop("guidance_scale")

        return {"prompt": inputs, **parameters}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """
        Polling mechanism for Black Forest Labs since the API is asynchronous.
        """
        url = _as_dict(response).get("polling_url")
        session = get_session()
        for _ in range(MAX_POLLING_ATTEMPTS):
            time.sleep(POLLING_INTERVAL)

            response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
            response.raise_for_status()  # type: ignore
            response_json: Dict = response.json()  # type: ignore
            status = response_json.get("status")
            logger.info(
                f"Polling generation result from {url}. Current status: {status}. "
                f"Will retry after {POLLING_INTERVAL} seconds if not ready."
            )

            if (
                status == "Ready"
                and isinstance(response_json.get("result"), dict)
                and (sample_url := response_json["result"].get("sample"))
            ):
                image_resp = session.get(sample_url)
                image_resp.raise_for_status()
                return image_resp.content

        raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.")
@@ -0,0 +1,6 @@ _providers/cerebras.py
from huggingface_hub.inference._providers._common import BaseConversationalTask


class CerebrasConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="cerebras", base_url="https://api.cerebras.ai")
@@ -0,0 +1,15 @@ _providers/cohere.py
from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
)


_PROVIDER = "cohere"
_BASE_URL = "https://api.cohere.com"


class CohereConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/compatibility/v1/chat/completions"
@@ -0,0 +1,147 @@ _providers/fal_ai.py
import base64
import time
from abc import ABC
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
from huggingface_hub.utils.logging import get_logger


logger = get_logger(__name__)

# Arbitrary polling interval
_POLLING_INTERVAL = 0.5


class FalAITask(TaskProviderHelper, ABC):
    def __init__(self, task: str):
        super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            headers["authorization"] = f"Key {api_key}"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return f"/{mapped_model}"


class FalAIAutomaticSpeechRecognitionTask(FalAITask):
    def __init__(self):
        super().__init__("automatic-speech-recognition")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            # If input is a URL, pass it directly
            audio_url = inputs
        else:
            # If input is a file path, read it first
            if isinstance(inputs, str):
                with open(inputs, "rb") as f:
                    inputs = f.read()

            audio_b64 = base64.b64encode(inputs).decode()
            content_type = "audio/mpeg"
            audio_url = f"data:{content_type};base64,{audio_b64}"

        return {"audio_url": audio_url, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        text = _as_dict(response)["text"]
        if not isinstance(text, str):
            raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
        return text


class FalAITextToImageTask(FalAITask):
    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "width" in parameters and "height" in parameters:
            parameters["image_size"] = {
                "width": parameters.pop("width"),
                "height": parameters.pop("height"),
            }
        return {"prompt": inputs, **parameters}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        url = _as_dict(response)["images"][0]["url"]
        return get_session().get(url).content


class FalAITextToSpeechTask(FalAITask):
    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"lyrics": inputs, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        url = _as_dict(response)["audio"]["url"]
        return get_session().get(url).content


class FalAITextToVideoTask(FalAITask):
    def __init__(self):
        super().__init__("text-to-video")

    def _prepare_base_url(self, api_key: str) -> str:
        if api_key.startswith("hf_"):
            return super()._prepare_base_url(api_key)
        else:
            logger.info(f"Calling '{self.provider}' provider directly.")
            return "https://queue.fal.run"

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        if api_key.startswith("hf_"):
            # Use the queue subdomain for HF routing
            return f"/{mapped_model}?_subdomain=queue"
        return f"/{mapped_model}"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"prompt": inputs, **filter_none(parameters)}

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        response_dict = _as_dict(response)

        request_id = response_dict.get("request_id")
        if not request_id:
            raise ValueError("No request ID found in the response")
        if request_params is None:
            raise ValueError(
                "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
            )

        # extract the base url and query params
        parsed_url = urlparse(request_params.url)
        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
        query_param = f"?{parsed_url.query}" if parsed_url.query else ""

        # extracting the provider model id for status and result urls
        # from the response as it might be different from the mapped model in `request_params.url`
        model_id = urlparse(response_dict.get("response_url")).path
        status_url = f"{base_url}{str(model_id)}/status{query_param}"
        result_url = f"{base_url}{str(model_id)}{query_param}"

        status = response_dict.get("status")
        logger.info("Generating the video... this can take several minutes.")
        while status != "COMPLETED":
            time.sleep(_POLLING_INTERVAL)
            status_response = get_session().get(status_url, headers=request_params.headers)
            hf_raise_for_status(status_response)
            status = status_response.json().get("status")

        response = get_session().get(result_url, headers=request_params.headers).json()
        url = _as_dict(response)["video"]["url"]
        return get_session().get(url).content
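For context (not part of the commit): a sketch, with hypothetical URLs and IDs, of how the status and result URLs above are assembled when the request was routed through router.huggingface.co. The real request_id and response_url come from the API response.

from urllib.parse import urlparse

request_url = "https://router.huggingface.co/fal-ai/fal-ai/veo2?_subdomain=queue"  # hypothetical
response_url = "https://queue.fal.run/fal-ai/veo2/requests/abc-123"                # hypothetical

parsed = urlparse(request_url)
base_url = f"{parsed.scheme}://{parsed.netloc}{'/fal-ai' if parsed.netloc == 'router.huggingface.co' else ''}"
query_param = f"?{parsed.query}" if parsed.query else ""
model_id = urlparse(response_url).path

status_url = f"{base_url}{model_id}/status{query_param}"
result_url = f"{base_url}{model_id}{query_param}"
# status_url == "https://router.huggingface.co/fal-ai/fal-ai/veo2/requests/abc-123/status?_subdomain=queue"
# result_url == "https://router.huggingface.co/fal-ai/fal-ai/veo2/requests/abc-123?_subdomain=queue"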
@@ -0,0 +1,9 @@ _providers/fireworks_ai.py
from ._common import BaseConversationalTask


class FireworksAIConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/inference/v1/chat/completions"
@@ -0,0 +1,167 @@ _providers/hf_inference.py
import json
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

from huggingface_hub import constants
from huggingface_hub.inference._common import _b64_encode, _open_as_binary
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status


class HFInferenceTask(TaskProviderHelper):
    """Base class for HF Inference API tasks."""

    def __init__(self, task: str):
        super().__init__(
            provider="hf-inference",
            base_url=constants.INFERENCE_PROXY_TEMPLATE.format(provider="hf-inference"),
            task=task,
        )

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        # special case: for HF Inference we allow not providing an API key
        return api_key or get_token()  # type: ignore[return-value]

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        if model is not None and model.startswith(("http://", "https://")):
            return model
        model_id = model if model is not None else _fetch_recommended_models().get(self.task)
        if model_id is None:
            raise ValueError(
                f"Task {self.task} has no recommended model for HF Inference. Please specify a model"
                " explicitly. Visit https://huggingface.co/tasks for more info."
            )
        _check_supported_task(model_id, self.task)
        return model_id

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
        if mapped_model.startswith(("http://", "https://")):
            return mapped_model
        return (
            # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
            f"{self.base_url}/pipeline/{self.task}/{mapped_model}"
            if self.task in ("feature-extraction", "sentence-similarity")
            # Otherwise, we use the default endpoint
            else f"{self.base_url}/models/{mapped_model}"
        )

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        if isinstance(inputs, bytes):
            raise ValueError(f"Unexpected binary input for task {self.task}.")
        if isinstance(inputs, Path):
            raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
        return {"inputs": inputs, "parameters": filter_none(parameters)}


class HFInferenceBinaryInputTask(HFInferenceTask):
    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return None

    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        parameters = filter_none({k: v for k, v in parameters.items() if v is not None})
        extra_payload = extra_payload or {}
        has_parameters = len(parameters) > 0 or len(extra_payload) > 0

        # Raise if not a binary object or a local path or a URL.
        if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str):
            raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}")

        # Send inputs as raw content when no parameters are provided
        if not has_parameters:
            with _open_as_binary(inputs) as data:
                data_as_bytes = data if isinstance(data, bytes) else data.read()
                return data_as_bytes

        # Otherwise encode as b64
        return json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8")


class HFInferenceConversational(HFInferenceTask):
    def __init__(self):
        super().__init__("conversational")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        payload_model = parameters.get("model") or mapped_model

        if payload_model is None or payload_model.startswith(("http://", "https://")):
            payload_model = "dummy"

        return {**filter_none(parameters), "model": payload_model, "messages": inputs}

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        base_url = (
            mapped_model
            if mapped_model.startswith(("http://", "https://"))
            else f"{constants.INFERENCE_PROXY_TEMPLATE.format(provider='hf-inference')}/models/{mapped_model}"
        )
        return _build_chat_completion_url(base_url)


def _build_chat_completion_url(model_url: str) -> str:
    # Strip trailing /
    model_url = model_url.rstrip("/")

    # Append /chat/completions if not already present
    if model_url.endswith("/v1"):
        model_url += "/chat/completions"

    # Append /v1/chat/completions if not already present
    if not model_url.endswith("/chat/completions"):
        model_url += "/v1/chat/completions"

    return model_url


@lru_cache(maxsize=1)
def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers())
    hf_raise_for_status(response)
    return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()}


@lru_cache(maxsize=None)
def _check_supported_task(model: str, task: str) -> None:
    from huggingface_hub.hf_api import HfApi

    model_info = HfApi().model_info(model)
    pipeline_tag = model_info.pipeline_tag
    tags = model_info.tags or []
    is_conversational = "conversational" in tags
    if task in ("text-generation", "conversational"):
        if pipeline_tag == "text-generation":
            # text-generation + conversational tag -> both tasks allowed
            if is_conversational:
                return
            # text-generation without conversational tag -> only text-generation allowed
            if task == "text-generation":
                return
            raise ValueError(f"Model '{model}' doesn't support task '{task}'.")

        if pipeline_tag == "text2text-generation":
            if task == "text-generation":
                return
            raise ValueError(f"Model '{model}' doesn't support task '{task}'.")

        if pipeline_tag == "image-text-to-text":
            if is_conversational and task == "conversational":
                return  # Only conversational allowed if tagged as conversational
            raise ValueError("Non-conversational image-text-to-text task is not supported.")

    if (
        task in ("feature-extraction", "sentence-similarity")
        and pipeline_tag in ("feature-extraction", "sentence-similarity")
        and task in tags
    ):
        # feature-extraction and sentence-similarity are interchangeable for HF Inference
        return

    # For all other tasks, just check pipeline tag
    if pipeline_tag != task:
        raise ValueError(
            f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'"
        )
    return
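For context (not part of the commit): how _build_chat_completion_url normalizes endpoints, using a hypothetical host.

assert _build_chat_completion_url("https://my-endpoint.example/v1") == "https://my-endpoint.example/v1/chat/completions"
assert _build_chat_completion_url("https://my-endpoint.example/") == "https://my-endpoint.example/v1/chat/completions"
assert _build_chat_completion_url("https://my-endpoint.example/v1/chat/completions") == "https://my-endpoint.example/v1/chat/completions"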
@@ -0,0 +1,43 @@ _providers/hyperbolic.py
import base64
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none


class HyperbolicTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task="text-to-image")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/images/generations"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "num_inference_steps" in parameters:
            parameters["steps"] = parameters.pop("num_inference_steps")
        if "guidance_scale" in parameters:
            parameters["cfg_scale"] = parameters.pop("guidance_scale")
        # For Hyperbolic, the width and height are required parameters
        if "width" not in parameters:
            parameters["width"] = 512
        if "height" not in parameters:
            parameters["height"] = 512
        return {"prompt": inputs, "model_name": mapped_model, **parameters}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        response_dict = _as_dict(response)
        return base64.b64decode(response_dict["images"][0]["image"])


class HyperbolicTextGenerationTask(BaseConversationalTask):
    """
    Special case for Hyperbolic, where text-generation task is handled as a conversational task.
    """

    def __init__(self, task: str):
        super().__init__(
            provider="hyperbolic",
            base_url="https://api.hyperbolic.xyz",
        )
        self.task = task
@@ -0,0 +1,51 @@ _providers/nebius.py
import base64
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
    BaseTextGenerationTask,
    TaskProviderHelper,
    filter_none,
)


class NebiusTextGenerationTask(BaseTextGenerationTask):
    def __init__(self):
        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        output = _as_dict(response)["choices"][0]
        return {
            "generated_text": output["text"],
            "details": {
                "finish_reason": output.get("finish_reason"),
                "seed": output.get("seed"),
            },
        }


class NebiusConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")


class NebiusTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(task="text-to-image", provider="nebius", base_url="https://api.studio.nebius.ai")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/images/generations"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "guidance_scale" in parameters:
            parameters.pop("guidance_scale")
        if parameters.get("response_format") not in ("b64_json", "url"):
            parameters["response_format"] = "b64_json"

        return {"prompt": inputs, **parameters, "model": mapped_model}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        response_dict = _as_dict(response)
        return base64.b64decode(response_dict["data"][0]["b64_json"])
@@ -0,0 +1,66 @@ _providers/novita.py
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
    BaseTextGenerationTask,
    TaskProviderHelper,
    filter_none,
)
from huggingface_hub.utils import get_session


_PROVIDER = "novita"
_BASE_URL = "https://api.novita.ai"


class NovitaTextGenerationTask(BaseTextGenerationTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # there is no v1/ route for novita
        return "/v3/openai/completions"

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        output = _as_dict(response)["choices"][0]
        return {
            "generated_text": output["text"],
            "details": {
                "finish_reason": output.get("finish_reason"),
                "seed": output.get("seed"),
            },
        }


class NovitaConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # there is no v1/ route for novita
        return "/v3/openai/chat/completions"


class NovitaTextToVideoTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task="text-to-video")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return f"/v3/hf/{mapped_model}"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"prompt": inputs, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        response_dict = _as_dict(response)
        if not (
            isinstance(response_dict, dict)
            and "video" in response_dict
            and isinstance(response_dict["video"], dict)
            and "video_url" in response_dict["video"]
        ):
            raise ValueError("Expected response format: { 'video': { 'video_url': string } }")

        video_url = response_dict["video"]["video_url"]
        return get_session().get(video_url).content
@@ -0,0 +1,22 @@ _providers/openai.py
from typing import Optional

from huggingface_hub.inference._providers._common import BaseConversationalTask


class OpenAIConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="openai", base_url="https://api.openai.com")

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        if api_key is None:
            raise ValueError("You must provide an api_key to work with OpenAI API.")
        if api_key.startswith("hf_"):
            raise ValueError(
                "OpenAI provider is not available through Hugging Face routing, please use your own OpenAI API key."
            )
        return api_key

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        if model is None:
            raise ValueError("Please provide an OpenAI model ID, e.g. `gpt-4o` or `o1`.")
        return model
@@ -0,0 +1,53 @@ _providers/replicate.py
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session


_PROVIDER = "replicate"
_BASE_URL = "https://api.replicate.com"


class ReplicateTask(TaskProviderHelper):
    def __init__(self, task: str):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        headers["Prefer"] = "wait"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        if ":" in mapped_model:
            return "/v1/predictions"
        return f"/v1/models/{mapped_model}/predictions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}}
        if ":" in mapped_model:
            version = mapped_model.split(":", 1)[1]
            payload["version"] = version
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        response_dict = _as_dict(response)
        if response_dict.get("output") is None:
            raise TimeoutError(
                f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}. "
                "The model might be in cold state or starting up. Please try again later."
            )
        output_url = (
            response_dict["output"] if isinstance(response_dict["output"], str) else response_dict["output"][0]
        )
        return get_session().get(output_url).content


class ReplicateTextToSpeechTask(ReplicateTask):
    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, mapped_model)  # type: ignore[assignment]
        payload["input"]["text"] = payload["input"].pop("prompt")  # rename "prompt" to "text" for TTS
        return payload
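For context (not part of the commit): a sketch, with hypothetical model IDs, of how a versioned Replicate model ID ("owner/name:version-hash") changes the route and payload built above.

helper = ReplicateTask("text-to-image")

# Versioned model ID -> generic predictions route, version hash moved into the payload
helper._prepare_route("owner/model:0123abcd", api_key="placeholder")   # "/v1/predictions"
helper._prepare_payload_as_dict("a photo of a cat", {}, "owner/model:0123abcd")
# {"input": {"prompt": "a photo of a cat"}, "version": "0123abcd"}

# Un-versioned model ID -> model-specific predictions route, no version field
helper._prepare_route("owner/model", api_key="placeholder")            # "/v1/models/owner/model/predictions"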
@@ -0,0 +1,6 @@ _providers/sambanova.py
from huggingface_hub.inference._providers._common import BaseConversationalTask


class SambanovaConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="sambanova", base_url="https://api.sambanova.ai")
@@ -0,0 +1,69 @@ _providers/together.py
import base64
from abc import ABC
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
    BaseTextGenerationTask,
    TaskProviderHelper,
    filter_none,
)


_PROVIDER = "together"
_BASE_URL = "https://api.together.xyz"


class TogetherTask(TaskProviderHelper, ABC):
    """Base class for Together API tasks."""

    def __init__(self, task: str):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        if self.task == "text-to-image":
            return "/v1/images/generations"
        elif self.task == "conversational":
            return "/v1/chat/completions"
        elif self.task == "text-generation":
            return "/v1/completions"
        raise ValueError(f"Unsupported task '{self.task}' for Together API.")


class TogetherTextGenerationTask(BaseTextGenerationTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        output = _as_dict(response)["choices"][0]
        return {
            "generated_text": output["text"],
            "details": {
                "finish_reason": output.get("finish_reason"),
                "seed": output.get("seed"),
            },
        }


class TogetherConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)


class TogetherTextToImageTask(TogetherTask):
    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "num_inference_steps" in parameters:
            parameters["steps"] = parameters.pop("num_inference_steps")
        if "guidance_scale" in parameters:
            parameters["guidance"] = parameters.pop("guidance_scale")

        return {"prompt": inputs, "response_format": "base64", **parameters, "model": mapped_model}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        response_dict = _as_dict(response)
        return base64.b64decode(response_dict["data"][0]["b64_json"])