structure saas with tools
@@ -0,0 +1,422 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities used by both the sync and async inference clients."""

import base64
import io
import json
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterable,
    BinaryIO,
    ContextManager,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    NoReturn,
    Optional,
    Union,
    overload,
)

from requests import HTTPError

from huggingface_hub.errors import (
    GenerationError,
    IncompleteGenerationError,
    OverloadedError,
    TextGenerationError,
    UnknownError,
    ValidationError,
)

from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available
from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput


if TYPE_CHECKING:
    from aiohttp import ClientResponse, ClientSession
    from PIL.Image import Image

# TYPES
UrlT = str
PathT = Union[str, Path]
BinaryT = Union[bytes, BinaryIO]
ContentT = Union[BinaryT, PathT, UrlT]

# Used to set an `Accept: image/png` header
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

logger = logging.getLogger(__name__)


@dataclass
class RequestParameters:
    url: str
    task: str
    model: Optional[str]
    json: Optional[Union[str, Dict, List]]
    data: Optional[ContentT]
    headers: Dict[str, Any]


# Dataclass for ModelStatus. We use this dataclass in the `get_model_status` function.
@dataclass
class ModelStatus:
    """
    This dataclass represents the model status in the HF Inference API.

    Args:
        loaded (`bool`):
            Whether the model is currently loaded into HF's Inference API. Models
            are loaded on-demand, leading to the user's first request taking longer.
            If a model is loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`Dict`):
            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
            replicas.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
    """

    loaded: bool
    state: str
    compute_type: Dict
    framework: str


## IMPORT UTILS


def _import_aiohttp():
    # Make sure `aiohttp` is installed on the machine.
    if not is_aiohttp_available():
        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
    import aiohttp

    return aiohttp


def _import_numpy():
    """Make sure `numpy` is installed on the machine."""
    if not is_numpy_available():
        raise ImportError("Please install numpy to deal with embeddings (`pip install numpy`).")
    import numpy

    return numpy


def _import_pil_image():
    """Make sure `PIL` is installed on the machine."""
    if not is_pillow_available():
        raise ImportError(
            "Please install Pillow to deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    from PIL import Image

    return Image


## ENCODING / DECODING UTILS


@overload
def _open_as_binary(
    content: ContentT,
) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"


@overload
def _open_as_binary(
    content: Literal[None],
) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"


@contextmanager  # type: ignore
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.

    Does nothing if `content` is None.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
    # If content is a string => must be either a URL or a path
    if isinstance(content, str):
        if content.startswith("https://") or content.startswith("http://"):
            logger.debug(f"Downloading content from {content}")
            yield get_session().get(content).content  # TODO: retrieve as stream and pipe to post request?
            return
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # If content is a Path => open it
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        with content.open("rb") as f:
            yield f
    else:
        # Otherwise: already a file-like object or None
        yield content


def _b64_encode(content: ContentT) -> str:
    """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
    with _open_as_binary(content) as data:
        data_as_bytes = data if isinstance(data, bytes) else data.read()
        return base64.b64encode(data_as_bytes).decode()


def _b64_to_image(encoded_image: str) -> "Image":
    """Parse a base64-encoded string into a PIL Image."""
    Image = _import_pil_image()
    return Image.open(io.BytesIO(base64.b64decode(encoded_image)))


def _bytes_to_list(content: bytes) -> List:
    """Parse bytes from a Response object into a Python list.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
    dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_dict(content: bytes) -> Dict:
    """Parse bytes from a Response object into a Python dictionary.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
    list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_image(content: bytes) -> "Image":
    """Parse bytes from a Response object into a PIL Image.

    Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
    """
    Image = _import_pil_image()
    return Image.open(io.BytesIO(content))


def _as_dict(response: Union[bytes, Dict]) -> Dict:
    return json.loads(response) if isinstance(response, bytes) else response


## PAYLOAD UTILS


## STREAMING UTILS


def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`."""
    # Parse ServerSentEvents
    for byte_payload in bytes_output_as_lines:
        try:
            output = _format_text_generation_stream_output(byte_payload, details)
        except StopIteration:
            break
        if output is not None:
            yield output


async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`."""
    # Parse ServerSentEvents
    async for byte_payload in bytes_output_as_lines:
        try:
            output = _format_text_generation_stream_output(byte_payload, details)
        except StopIteration:
            break
        if output is not None:
            yield output


def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output


def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
    for item in bytes_lines:
        try:
            output = _format_chat_completion_stream_output(item)
        except StopIteration:
            break
        if output is not None:
            yield output


async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`."""
    async for item in bytes_lines:
        try:
            output = _format_chat_completion_stream_output(item)
        except StopIteration:
            break
        if output is not None:
            yield output


def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)


async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
    async for byte_payload in response.content:
        yield byte_payload.strip()
    await client.close()


# "TGI servers" are servers running with the `text-generation-inference` backend.
# This backend is the go-to solution to run large language models at scale. However,
# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
# solution is still in use.
#
# Both approaches have very similar APIs, but not exactly the same. What we do first in
# the `text_generation` method is to assume the model is served via TGI. If we realize
# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fall back to the
# default API with a warning message. When that's the case, we remember the unsupported
# attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
#
# In addition, TGI servers have a built-in API route for chat-completion, which is not
# available on the default API. We use this route to provide a more consistent behavior
# when available.
#
# For more details, see https://github.com/huggingface/text-generation-inference and
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.

_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {}


def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None:
    _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs)


def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
    return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])


# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle, as much as possible, the errors returned by the
# text-generation-inference project (https://github.com/huggingface/text-generation-inference).
# ----------------------


def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Try to parse the text-generation-inference error message and raise a more specific exception;
    re-raise the original HTTPError if the payload cannot be parsed.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.
    """
    # Try to parse a Text Generation Inference error
    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error = payload.get("error")
        error_type = payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # If error_type => more information than `hf_raise_for_status`
    if error_type is not None:
        exception = _parse_text_generation_error(error, error_type)
        raise exception from http_error

    # Otherwise, fall back to the default error
    raise http_error


def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    if error_type == "generation":
        return GenerationError(error)  # type: ignore
    if error_type == "incomplete_generation":
        return IncompleteGenerationError(error)  # type: ignore
    if error_type == "overloaded":
        return OverloadedError(error)  # type: ignore
    if error_type == "validation":
        return ValidationError(error)  # type: ignore
    return UnknownError(error)  # type: ignore
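A minimal usage sketch of the streaming helpers above, illustrative only and not part of the diff. It assumes the module context of this file; the two server-sent-event lines and token values are made up.

# Hypothetical SSE lines as a TGI server would send them (values are invented).
_example_lines = [
    b'data: {"index": 0, "token": {"id": 42, "text": "Hello", "logprob": -0.1, "special": false}, "generated_text": null, "details": null}',
    b"data: [DONE]",
]

# With `details=False`, only the token text is yielded; the "[DONE]" line stops the stream.
for _text in _stream_text_generation_response(_example_lines, details=False):
    print(_text)  # -> "Hello"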
@@ -0,0 +1,188 @@
|
||||
# This file is auto-generated by `utils/generate_inference_types.py`.
|
||||
# Do not modify it manually.
|
||||
#
|
||||
# ruff: noqa: F401
|
||||
|
||||
from .audio_classification import (
|
||||
AudioClassificationInput,
|
||||
AudioClassificationOutputElement,
|
||||
AudioClassificationOutputTransform,
|
||||
AudioClassificationParameters,
|
||||
)
|
||||
from .audio_to_audio import AudioToAudioInput, AudioToAudioOutputElement
|
||||
from .automatic_speech_recognition import (
|
||||
AutomaticSpeechRecognitionEarlyStoppingEnum,
|
||||
AutomaticSpeechRecognitionGenerationParameters,
|
||||
AutomaticSpeechRecognitionInput,
|
||||
AutomaticSpeechRecognitionOutput,
|
||||
AutomaticSpeechRecognitionOutputChunk,
|
||||
AutomaticSpeechRecognitionParameters,
|
||||
)
|
||||
from .base import BaseInferenceType
|
||||
from .chat_completion import (
|
||||
ChatCompletionInput,
|
||||
ChatCompletionInputFunctionDefinition,
|
||||
ChatCompletionInputFunctionName,
|
||||
ChatCompletionInputGrammarType,
|
||||
ChatCompletionInputGrammarTypeType,
|
||||
ChatCompletionInputMessage,
|
||||
ChatCompletionInputMessageChunk,
|
||||
ChatCompletionInputMessageChunkType,
|
||||
ChatCompletionInputStreamOptions,
|
||||
ChatCompletionInputTool,
|
||||
ChatCompletionInputToolCall,
|
||||
ChatCompletionInputToolChoiceClass,
|
||||
ChatCompletionInputToolChoiceEnum,
|
||||
ChatCompletionInputURL,
|
||||
ChatCompletionOutput,
|
||||
ChatCompletionOutputComplete,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
ChatCompletionOutputLogprob,
|
||||
ChatCompletionOutputLogprobs,
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputTopLogprob,
|
||||
ChatCompletionOutputUsage,
|
||||
ChatCompletionStreamOutput,
|
||||
ChatCompletionStreamOutputChoice,
|
||||
ChatCompletionStreamOutputDelta,
|
||||
ChatCompletionStreamOutputDeltaToolCall,
|
||||
ChatCompletionStreamOutputFunction,
|
||||
ChatCompletionStreamOutputLogprob,
|
||||
ChatCompletionStreamOutputLogprobs,
|
||||
ChatCompletionStreamOutputTopLogprob,
|
||||
ChatCompletionStreamOutputUsage,
|
||||
)
|
||||
from .depth_estimation import DepthEstimationInput, DepthEstimationOutput
|
||||
from .document_question_answering import (
|
||||
DocumentQuestionAnsweringInput,
|
||||
DocumentQuestionAnsweringInputData,
|
||||
DocumentQuestionAnsweringOutputElement,
|
||||
DocumentQuestionAnsweringParameters,
|
||||
)
|
||||
from .feature_extraction import FeatureExtractionInput, FeatureExtractionInputTruncationDirection
|
||||
from .fill_mask import FillMaskInput, FillMaskOutputElement, FillMaskParameters
|
||||
from .image_classification import (
|
||||
ImageClassificationInput,
|
||||
ImageClassificationOutputElement,
|
||||
ImageClassificationOutputTransform,
|
||||
ImageClassificationParameters,
|
||||
)
|
||||
from .image_segmentation import (
|
||||
ImageSegmentationInput,
|
||||
ImageSegmentationOutputElement,
|
||||
ImageSegmentationParameters,
|
||||
ImageSegmentationSubtask,
|
||||
)
|
||||
from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize
|
||||
from .image_to_text import (
|
||||
ImageToTextEarlyStoppingEnum,
|
||||
ImageToTextGenerationParameters,
|
||||
ImageToTextInput,
|
||||
ImageToTextOutput,
|
||||
ImageToTextParameters,
|
||||
)
|
||||
from .object_detection import (
|
||||
ObjectDetectionBoundingBox,
|
||||
ObjectDetectionInput,
|
||||
ObjectDetectionOutputElement,
|
||||
ObjectDetectionParameters,
|
||||
)
|
||||
from .question_answering import (
|
||||
QuestionAnsweringInput,
|
||||
QuestionAnsweringInputData,
|
||||
QuestionAnsweringOutputElement,
|
||||
QuestionAnsweringParameters,
|
||||
)
|
||||
from .sentence_similarity import SentenceSimilarityInput, SentenceSimilarityInputData
|
||||
from .summarization import (
|
||||
SummarizationInput,
|
||||
SummarizationOutput,
|
||||
SummarizationParameters,
|
||||
SummarizationTruncationStrategy,
|
||||
)
|
||||
from .table_question_answering import (
|
||||
Padding,
|
||||
TableQuestionAnsweringInput,
|
||||
TableQuestionAnsweringInputData,
|
||||
TableQuestionAnsweringOutputElement,
|
||||
TableQuestionAnsweringParameters,
|
||||
)
|
||||
from .text2text_generation import (
|
||||
Text2TextGenerationInput,
|
||||
Text2TextGenerationOutput,
|
||||
Text2TextGenerationParameters,
|
||||
Text2TextGenerationTruncationStrategy,
|
||||
)
|
||||
from .text_classification import (
|
||||
TextClassificationInput,
|
||||
TextClassificationOutputElement,
|
||||
TextClassificationOutputTransform,
|
||||
TextClassificationParameters,
|
||||
)
|
||||
from .text_generation import (
|
||||
TextGenerationInput,
|
||||
TextGenerationInputGenerateParameters,
|
||||
TextGenerationInputGrammarType,
|
||||
TextGenerationOutput,
|
||||
TextGenerationOutputBestOfSequence,
|
||||
TextGenerationOutputDetails,
|
||||
TextGenerationOutputFinishReason,
|
||||
TextGenerationOutputPrefillToken,
|
||||
TextGenerationOutputToken,
|
||||
TextGenerationStreamOutput,
|
||||
TextGenerationStreamOutputStreamDetails,
|
||||
TextGenerationStreamOutputToken,
|
||||
TypeEnum,
|
||||
)
|
||||
from .text_to_audio import (
|
||||
TextToAudioEarlyStoppingEnum,
|
||||
TextToAudioGenerationParameters,
|
||||
TextToAudioInput,
|
||||
TextToAudioOutput,
|
||||
TextToAudioParameters,
|
||||
)
|
||||
from .text_to_image import TextToImageInput, TextToImageOutput, TextToImageParameters
|
||||
from .text_to_speech import (
|
||||
TextToSpeechEarlyStoppingEnum,
|
||||
TextToSpeechGenerationParameters,
|
||||
TextToSpeechInput,
|
||||
TextToSpeechOutput,
|
||||
TextToSpeechParameters,
|
||||
)
|
||||
from .text_to_video import TextToVideoInput, TextToVideoOutput, TextToVideoParameters
|
||||
from .token_classification import (
|
||||
TokenClassificationAggregationStrategy,
|
||||
TokenClassificationInput,
|
||||
TokenClassificationOutputElement,
|
||||
TokenClassificationParameters,
|
||||
)
|
||||
from .translation import TranslationInput, TranslationOutput, TranslationParameters, TranslationTruncationStrategy
|
||||
from .video_classification import (
|
||||
VideoClassificationInput,
|
||||
VideoClassificationOutputElement,
|
||||
VideoClassificationOutputTransform,
|
||||
VideoClassificationParameters,
|
||||
)
|
||||
from .visual_question_answering import (
|
||||
VisualQuestionAnsweringInput,
|
||||
VisualQuestionAnsweringInputData,
|
||||
VisualQuestionAnsweringOutputElement,
|
||||
VisualQuestionAnsweringParameters,
|
||||
)
|
||||
from .zero_shot_classification import (
|
||||
ZeroShotClassificationInput,
|
||||
ZeroShotClassificationOutputElement,
|
||||
ZeroShotClassificationParameters,
|
||||
)
|
||||
from .zero_shot_image_classification import (
|
||||
ZeroShotImageClassificationInput,
|
||||
ZeroShotImageClassificationOutputElement,
|
||||
ZeroShotImageClassificationParameters,
|
||||
)
|
||||
from .zero_shot_object_detection import (
|
||||
ZeroShotObjectDetectionBoundingBox,
|
||||
ZeroShotObjectDetectionInput,
|
||||
ZeroShotObjectDetectionOutputElement,
|
||||
ZeroShotObjectDetectionParameters,
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
||||
#
|
||||
# See:
|
||||
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
||||
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .base import BaseInferenceType, dataclass_with_extra
|
||||
|
||||
|
||||
AudioClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AudioClassificationParameters(BaseInferenceType):
|
||||
"""Additional inference parameters for Audio Classification"""
|
||||
|
||||
function_to_apply: Optional["AudioClassificationOutputTransform"] = None
|
||||
"""The function to apply to the model outputs in order to retrieve the scores."""
|
||||
top_k: Optional[int] = None
|
||||
"""When specified, limits the output to the top K most probable classes."""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AudioClassificationInput(BaseInferenceType):
|
||||
"""Inputs for Audio Classification inference"""
|
||||
|
||||
inputs: str
|
||||
"""The input audio data as a base64-encoded string. If no `parameters` are provided, you can
|
||||
also provide the audio data as a raw bytes payload.
|
||||
"""
|
||||
parameters: Optional[AudioClassificationParameters] = None
|
||||
"""Additional inference parameters for Audio Classification"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AudioClassificationOutputElement(BaseInferenceType):
|
||||
"""Outputs for Audio Classification inference"""
|
||||
|
||||
label: str
|
||||
"""The predicted class label."""
|
||||
score: float
|
||||
"""The corresponding probability."""
|
||||
@@ -0,0 +1,30 @@
|
||||
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
||||
#
|
||||
# See:
|
||||
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
||||
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseInferenceType, dataclass_with_extra
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AudioToAudioInput(BaseInferenceType):
|
||||
"""Inputs for Audio to Audio inference"""
|
||||
|
||||
inputs: Any
|
||||
"""The input audio data"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AudioToAudioOutputElement(BaseInferenceType):
|
||||
"""Outputs of inference for the Audio To Audio task
|
||||
A generated audio file with its label.
|
||||
"""
|
||||
|
||||
blob: Any
|
||||
"""The generated audio file."""
|
||||
content_type: str
|
||||
"""The content type of audio file."""
|
||||
label: str
|
||||
"""The label of the audio file."""
|
||||
@@ -0,0 +1,114 @@
|
||||
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
||||
#
|
||||
# See:
|
||||
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
||||
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from .base import BaseInferenceType, dataclass_with_extra
|
||||
|
||||
|
||||
AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType):
|
||||
"""Parametrization of the text generation process"""
|
||||
|
||||
do_sample: Optional[bool] = None
|
||||
"""Whether to use sampling instead of greedy decoding when generating new tokens."""
|
||||
early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None
|
||||
"""Controls the stopping condition for beam-based methods."""
|
||||
epsilon_cutoff: Optional[float] = None
|
||||
"""If set to float strictly between 0 and 1, only tokens with a conditional probability
|
||||
greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
|
||||
3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
|
||||
Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
|
||||
"""
|
||||
eta_cutoff: Optional[float] = None
|
||||
"""Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
|
||||
float strictly between 0 and 1, a token is only considered if it is greater than either
|
||||
eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
|
||||
term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
|
||||
the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
|
||||
See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
|
||||
for more details.
|
||||
"""
|
||||
max_length: Optional[int] = None
|
||||
"""The maximum length (in tokens) of the generated text, including the input."""
|
||||
max_new_tokens: Optional[int] = None
|
||||
"""The maximum number of tokens to generate. Takes precedence over max_length."""
|
||||
min_length: Optional[int] = None
|
||||
"""The minimum length (in tokens) of the generated text, including the input."""
|
||||
min_new_tokens: Optional[int] = None
|
||||
"""The minimum number of tokens to generate. Takes precedence over min_length."""
|
||||
num_beam_groups: Optional[int] = None
|
||||
"""Number of groups to divide num_beams into in order to ensure diversity among different
|
||||
groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
|
||||
"""
|
||||
num_beams: Optional[int] = None
|
||||
"""Number of beams to use for beam search."""
|
||||
penalty_alpha: Optional[float] = None
|
||||
"""The value balances the model confidence and the degeneration penalty in contrastive
|
||||
search decoding.
|
||||
"""
|
||||
temperature: Optional[float] = None
|
||||
"""The value used to modulate the next token probabilities."""
|
||||
top_k: Optional[int] = None
|
||||
"""The number of highest probability vocabulary tokens to keep for top-k-filtering."""
|
||||
top_p: Optional[float] = None
|
||||
"""If set to float < 1, only the smallest set of most probable tokens with probabilities
|
||||
that add up to top_p or higher are kept for generation.
|
||||
"""
|
||||
typical_p: Optional[float] = None
|
||||
"""Local typicality measures how similar the conditional probability of predicting a target
|
||||
token next is to the expected conditional probability of predicting a random token next,
|
||||
given the partial text already generated. If set to float < 1, the smallest set of the
|
||||
most locally typical tokens with probabilities that add up to typical_p or higher are
|
||||
kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
|
||||
"""
|
||||
use_cache: Optional[bool] = None
|
||||
"""Whether the model should use the past last key/values attentions to speed up decoding"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AutomaticSpeechRecognitionParameters(BaseInferenceType):
|
||||
"""Additional inference parameters for Automatic Speech Recognition"""
|
||||
|
||||
return_timestamps: Optional[bool] = None
|
||||
"""Whether to output corresponding timestamps with the generated text"""
|
||||
# Will be deprecated in the future when the renaming to `generation_parameters` is implemented in transformers
|
||||
generate_kwargs: Optional[AutomaticSpeechRecognitionGenerationParameters] = None
|
||||
"""Parametrization of the text generation process"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AutomaticSpeechRecognitionInput(BaseInferenceType):
|
||||
"""Inputs for Automatic Speech Recognition inference"""
|
||||
|
||||
inputs: str
|
||||
"""The input audio data as a base64-encoded string. If no `parameters` are provided, you can
|
||||
also provide the audio data as a raw bytes payload.
|
||||
"""
|
||||
parameters: Optional[AutomaticSpeechRecognitionParameters] = None
|
||||
"""Additional inference parameters for Automatic Speech Recognition"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType):
|
||||
text: str
|
||||
"""A chunk of text identified by the model"""
|
||||
timestamp: List[float]
|
||||
"""The start and end timestamps corresponding with the text"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class AutomaticSpeechRecognitionOutput(BaseInferenceType):
|
||||
"""Outputs of inference for the Automatic Speech Recognition task"""
|
||||
|
||||
text: str
|
||||
"""The recognized text."""
|
||||
chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None
|
||||
"""When returnTimestamps is enabled, chunks contains a list of audio chunks identified by
|
||||
the model.
|
||||
"""
|
||||
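A brief illustrative sketch, not part of the diff, of how these generated types compose. It assumes the surrounding module is importable; the parameter values are arbitrary examples.

# Request timestamps and tune generation through `generate_kwargs`.
params = AutomaticSpeechRecognitionParameters(
    return_timestamps=True,
    generate_kwargs=AutomaticSpeechRecognitionGenerationParameters(
        num_beams=4,         # beam search with 4 beams
        max_new_tokens=256,  # cap the transcription length
    ),
)
# Every BaseInferenceType behaves both as a dataclass and as a dict:
assert params["return_timestamps"] is True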
@@ -0,0 +1,161 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a base class for all inference types."""

import inspect
import json
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Type, TypeVar, Union, get_args


T = TypeVar("T", bound="BaseInferenceType")


def _repr_with_extra(self):
    fields = list(self.__dataclass_fields__.keys())
    other_fields = list(k for k in self.__dict__ if k not in fields)
    return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})"


def dataclass_with_extra(cls: Type[T]) -> Type[T]:
    """Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones.

    This decorator only works with dataclasses that inherit from `BaseInferenceType`.
    """
    cls = dataclass(cls)
    cls.__repr__ = _repr_with_extra  # type: ignore[method-assign]
    return cls


@dataclass
class BaseInferenceType(dict):
    """Base class for all inference types.

    The object is both a dataclass and a dict for backward compatibility, but the plan is to remove the dict part in
    the future.

    Handles parsing from dict, list and JSON strings in a permissive way to ensure future compatibility (e.g. all
    fields are made optional, and unexpected fields are added as dict attributes).
    """

    @classmethod
    def parse_obj_as_list(cls: Type[T], data: Union[bytes, str, List, Dict]) -> List[T]:
        """Alias to parse server response and return a list of instances.

        See `parse_obj` for more details.
        """
        output = cls.parse_obj(data)
        if not isinstance(output, list):
            raise ValueError(f"Invalid input data for {cls}. Expected a list, but got {type(output)}.")
        return output

    @classmethod
    def parse_obj_as_instance(cls: Type[T], data: Union[bytes, str, List, Dict]) -> T:
        """Alias to parse server response and return a single instance.

        See `parse_obj` for more details.
        """
        output = cls.parse_obj(data)
        if isinstance(output, list):
            raise ValueError(f"Invalid input data for {cls}. Expected a single instance, but got a list.")
        return output

    @classmethod
    def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T], T]:
        """Parse server response as a dataclass or list of dataclasses.

        To enable future-compatibility, we want to handle cases where the server returns more fields than expected.
        In such cases, we don't want to raise an error but still create the dataclass object. Remaining fields are
        added as dict attributes.
        """
        # Parse server response (from bytes)
        if isinstance(data, bytes):
            data = data.decode()
        if isinstance(data, str):
            data = json.loads(data)

        # If a list, parse each item individually
        if isinstance(data, List):
            return [cls.parse_obj(d) for d in data]  # type: ignore [misc]

        # At this point, we expect a dict
        if not isinstance(data, dict):
            raise ValueError(f"Invalid data type: {type(data)}")

        init_values = {}
        other_values = {}
        for key, value in data.items():
            key = normalize_key(key)
            if key in cls.__dataclass_fields__ and cls.__dataclass_fields__[key].init:
                if isinstance(value, dict) or isinstance(value, list):
                    field_type = cls.__dataclass_fields__[key].type

                    # if `field_type` is a `BaseInferenceType`, parse it
                    if inspect.isclass(field_type) and issubclass(field_type, BaseInferenceType):
                        value = field_type.parse_obj(value)

                    # otherwise, recursively parse nested dataclasses (if possible)
                    # `get_args` handles Union and Optional for us
                    else:
                        expected_types = get_args(field_type)
                        for expected_type in expected_types:
                            if getattr(expected_type, "_name", None) == "List":
                                expected_type = get_args(expected_type)[0]  # assume same type for all items in the list
                            if inspect.isclass(expected_type) and issubclass(expected_type, BaseInferenceType):
                                value = expected_type.parse_obj(value)
                                break
                init_values[key] = value
            else:
                other_values[key] = value

        # Make all missing fields default to None
        # => ensure that dataclass initialization will never fail even if the server does not return all fields.
        for key in cls.__dataclass_fields__:
            if key not in init_values:
                init_values[key] = None

        # Initialize dataclass with expected values
        item = cls(**init_values)

        # Add remaining fields as dict attributes
        item.update(other_values)

        # Add remaining fields as extra dataclass fields.
        # They won't be part of the dataclass fields but will be accessible as attributes.
        # Use @dataclass_with_extra to show them in __repr__.
        item.__dict__.update(other_values)
        return item

    def __post_init__(self):
        self.update(asdict(self))

    def __setitem__(self, __key: Any, __value: Any) -> None:
        # Hacky way to keep dataclass values in sync when dict is updated
        super().__setitem__(__key, __value)
        if __key in self.__dataclass_fields__ and getattr(self, __key, None) != __value:
            self.__setattr__(__key, __value)
        return

    def __setattr__(self, __name: str, __value: Any) -> None:
        # Hacky way to keep dict values in sync when dataclass is updated
        super().__setattr__(__name, __value)
        if self.get(__name) != __value:
            self[__name] = __value
        return


def normalize_key(key: str) -> str:
    # e.g. "content-type" -> "content_type", "Accept" -> "accept"
    return key.replace("-", "_").replace(" ", "_").lower()
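A short sketch, illustrative and not part of the file, of the permissive parsing described in the docstring above: unexpected server fields do not break parsing and remain accessible. The `_ExampleOutput` subclass and the payload values are hypothetical.

@dataclass_with_extra
class _ExampleOutput(BaseInferenceType):  # hypothetical subclass for demonstration
    label: str
    score: float

_item = _ExampleOutput.parse_obj(b'{"label": "dog", "score": 0.9, "extra-field": 42}')
print(_item.label, _item.score)  # -> dog 0.9
print(_item["label"])            # dict-style access still works (backward compatibility)
print(_item.extra_field)         # unexpected fields are normalized and kept as attributes -> 42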
@@ -0,0 +1,311 @@
|
||||
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
||||
#
|
||||
# See:
|
||||
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
||||
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from .base import BaseInferenceType, dataclass_with_extra
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputURL(BaseInferenceType):
|
||||
url: str
|
||||
|
||||
|
||||
ChatCompletionInputMessageChunkType = Literal["text", "image_url"]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputMessageChunk(BaseInferenceType):
|
||||
type: "ChatCompletionInputMessageChunkType"
|
||||
image_url: Optional[ChatCompletionInputURL] = None
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputFunctionDefinition(BaseInferenceType):
|
||||
arguments: Any
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputToolCall(BaseInferenceType):
|
||||
function: ChatCompletionInputFunctionDefinition
|
||||
id: str
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputMessage(BaseInferenceType):
|
||||
role: str
|
||||
content: Optional[Union[List[ChatCompletionInputMessageChunk], str]] = None
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[ChatCompletionInputToolCall]] = None
|
||||
|
||||
|
||||
ChatCompletionInputGrammarTypeType = Literal["json", "regex"]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputGrammarType(BaseInferenceType):
|
||||
type: "ChatCompletionInputGrammarTypeType"
|
||||
value: Any
|
||||
"""A string that represents a [JSON Schema](https://json-schema.org/).
|
||||
JSON Schema is a declarative language that allows you to annotate JSON documents
|
||||
with types and descriptions.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputStreamOptions(BaseInferenceType):
|
||||
include_usage: Optional[bool] = None
|
||||
"""If set, an additional chunk will be streamed before the data: [DONE] message. The usage
|
||||
field on this chunk shows the token usage statistics for the entire request, and the
|
||||
choices field will always be an empty array. All other chunks will also include a usage
|
||||
field, but with a null value.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputFunctionName(BaseInferenceType):
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputToolChoiceClass(BaseInferenceType):
|
||||
function: ChatCompletionInputFunctionName
|
||||
|
||||
|
||||
ChatCompletionInputToolChoiceEnum = Literal["auto", "none", "required"]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInputTool(BaseInferenceType):
|
||||
function: ChatCompletionInputFunctionDefinition
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionInput(BaseInferenceType):
|
||||
"""Chat Completion Input.
|
||||
Auto-generated from TGI specs.
|
||||
For more details, check out
|
||||
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
|
||||
"""
|
||||
|
||||
messages: List[ChatCompletionInputMessage]
|
||||
"""A list of messages comprising the conversation so far."""
|
||||
frequency_penalty: Optional[float] = None
|
||||
"""Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
|
||||
frequency in the text so far,
|
||||
decreasing the model's likelihood to repeat the same line verbatim.
|
||||
"""
|
||||
logit_bias: Optional[List[float]] = None
|
||||
"""UNUSED
|
||||
Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON
|
||||
object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated bias value from -100 to
|
||||
100. Mathematically,
|
||||
the bias is added to the logits generated by the model prior to sampling. The exact
|
||||
effect will vary per model,
|
||||
but values between -1 and 1 should decrease or increase likelihood of selection; values
|
||||
like -100 or 100 should
|
||||
result in a ban or exclusive selection of the relevant token.
|
||||
"""
|
||||
logprobs: Optional[bool] = None
|
||||
"""Whether to return log probabilities of the output tokens or not. If true, returns the log
|
||||
probabilities of each
|
||||
output token returned in the content of message.
|
||||
"""
|
||||
max_tokens: Optional[int] = None
|
||||
"""The maximum number of tokens that can be generated in the chat completion."""
|
||||
model: Optional[str] = None
|
||||
"""[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
|
||||
on which models work with the Chat API.
|
||||
"""
|
||||
n: Optional[int] = None
|
||||
"""UNUSED
|
||||
How many chat completion choices to generate for each input message. Note that you will
|
||||
be charged based on the
|
||||
number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
||||
"""
|
||||
presence_penalty: Optional[float] = None
|
||||
"""Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
|
||||
appear in the text so far,
|
||||
increasing the model's likelihood to talk about new topics
|
||||
"""
|
||||
response_format: Optional[ChatCompletionInputGrammarType] = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[List[str]] = None
|
||||
"""Up to 4 sequences where the API will stop generating further tokens."""
|
||||
stream: Optional[bool] = None
|
||||
stream_options: Optional[ChatCompletionInputStreamOptions] = None
|
||||
temperature: Optional[float] = None
|
||||
"""What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the
|
||||
output more random, while
|
||||
lower values like 0.2 will make it more focused and deterministic.
|
||||
We generally recommend altering this or `top_p` but not both.
|
||||
"""
|
||||
tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
|
||||
tool_prompt: Optional[str] = None
|
||||
"""A prompt to be appended before the tools"""
|
||||
tools: Optional[List[ChatCompletionInputTool]] = None
|
||||
"""A list of tools the model may call. Currently, only functions are supported as a tool.
|
||||
Use this to provide a list of
|
||||
functions the model may generate JSON inputs for.
|
||||
"""
|
||||
top_logprobs: Optional[int] = None
|
||||
"""An integer between 0 and 5 specifying the number of most likely tokens to return at each
|
||||
token position, each with
|
||||
an associated log probability. logprobs must be set to true if this parameter is used.
|
||||
"""
|
||||
top_p: Optional[float] = None
|
||||
"""An alternative to sampling with temperature, called nucleus sampling, where the model
|
||||
considers the results of the
|
||||
tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
|
||||
probability mass are considered.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputTopLogprob(BaseInferenceType):
|
||||
logprob: float
|
||||
token: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputLogprob(BaseInferenceType):
|
||||
logprob: float
|
||||
token: str
|
||||
top_logprobs: List[ChatCompletionOutputTopLogprob]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputLogprobs(BaseInferenceType):
|
||||
content: List[ChatCompletionOutputLogprob]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputFunctionDefinition(BaseInferenceType):
|
||||
arguments: Any
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputToolCall(BaseInferenceType):
|
||||
function: ChatCompletionOutputFunctionDefinition
|
||||
id: str
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputMessage(BaseInferenceType):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputComplete(BaseInferenceType):
|
||||
finish_reason: str
|
||||
index: int
|
||||
message: ChatCompletionOutputMessage
|
||||
logprobs: Optional[ChatCompletionOutputLogprobs] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutputUsage(BaseInferenceType):
|
||||
completion_tokens: int
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionOutput(BaseInferenceType):
|
||||
"""Chat Completion Output.
|
||||
Auto-generated from TGI specs.
|
||||
For more details, check out
|
||||
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
|
||||
"""
|
||||
|
||||
choices: List[ChatCompletionOutputComplete]
|
||||
created: int
|
||||
id: str
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
usage: ChatCompletionOutputUsage
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputFunction(BaseInferenceType):
|
||||
arguments: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
|
||||
function: ChatCompletionStreamOutputFunction
|
||||
id: str
|
||||
index: int
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputDelta(BaseInferenceType):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputTopLogprob(BaseInferenceType):
|
||||
logprob: float
|
||||
token: str
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputLogprob(BaseInferenceType):
|
||||
logprob: float
|
||||
token: str
|
||||
top_logprobs: List[ChatCompletionStreamOutputTopLogprob]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputLogprobs(BaseInferenceType):
|
||||
content: List[ChatCompletionStreamOutputLogprob]
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputChoice(BaseInferenceType):
|
||||
delta: ChatCompletionStreamOutputDelta
|
||||
index: int
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[ChatCompletionStreamOutputLogprobs] = None
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutputUsage(BaseInferenceType):
|
||||
completion_tokens: int
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class ChatCompletionStreamOutput(BaseInferenceType):
|
||||
"""Chat Completion Stream Output.
|
||||
Auto-generated from TGI specs.
|
||||
For more details, check out
|
||||
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
|
||||
"""
|
||||
|
||||
choices: List[ChatCompletionStreamOutputChoice]
|
||||
created: int
|
||||
id: str
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
usage: Optional[ChatCompletionStreamOutputUsage] = None
|
||||
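An illustrative sketch, not part of the diff, of assembling a request payload with the input types above; the message texts and sampling values are placeholders.

request = ChatCompletionInput(
    messages=[
        ChatCompletionInputMessage(role="system", content="You are a helpful assistant."),
        ChatCompletionInputMessage(role="user", content="What is text-generation-inference?"),
    ],
    max_tokens=128,
    temperature=0.7,
    stream=False,
)
# Like every BaseInferenceType, the payload is also dict-like; json.dumps(request) would
# produce a JSON body (including null entries for unset fields) to send to the server.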
@@ -0,0 +1,28 @@
|
||||
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
||||
#
|
||||
# See:
|
||||
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
||||
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseInferenceType, dataclass_with_extra
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class DepthEstimationInput(BaseInferenceType):
|
||||
"""Inputs for Depth Estimation inference"""
|
||||
|
||||
inputs: Any
|
||||
"""The input image data"""
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
"""Additional inference parameters for Depth Estimation"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class DepthEstimationOutput(BaseInferenceType):
|
||||
"""Outputs of inference for the Depth Estimation task"""
|
||||
|
||||
depth: Any
|
||||
"""The predicted depth as an image"""
|
||||
predicted_depth: Any
|
||||
"""The predicted depth as a tensor"""
|
||||
@@ -0,0 +1,80 @@
|
||||
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
||||
#
|
||||
# See:
|
||||
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
||||
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from .base import BaseInferenceType, dataclass_with_extra
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class DocumentQuestionAnsweringInputData(BaseInferenceType):
|
||||
"""One (document, question) pair to answer"""
|
||||
|
||||
image: Any
|
||||
"""The image on which the question is asked"""
|
||||
question: str
|
||||
"""A question to ask of the document"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class DocumentQuestionAnsweringParameters(BaseInferenceType):
|
||||
"""Additional inference parameters for Document Question Answering"""
|
||||
|
||||
doc_stride: Optional[int] = None
|
||||
"""If the words in the document are too long to fit with the question for the model, it will
|
||||
be split in several chunks with some overlap. This argument controls the size of that
|
||||
overlap.
|
||||
"""
|
||||
handle_impossible_answer: Optional[bool] = None
|
||||
"""Whether to accept impossible as an answer"""
|
||||
lang: Optional[str] = None
|
||||
"""Language to use while running OCR. Defaults to english."""
|
||||
max_answer_len: Optional[int] = None
|
||||
"""The maximum length of predicted answers (e.g., only answers with a shorter length are
|
||||
considered).
|
||||
"""
|
||||
max_question_len: Optional[int] = None
|
||||
"""The maximum length of the question after tokenization. It will be truncated if needed."""
|
||||
max_seq_len: Optional[int] = None
|
||||
"""The maximum length of the total sentence (context + question) in tokens of each chunk
|
||||
passed to the model. The context will be split in several chunks (using doc_stride as
|
||||
overlap) if needed.
|
||||
"""
|
||||
top_k: Optional[int] = None
|
||||
"""The number of answers to return (will be chosen by order of likelihood). Can return less
|
||||
than top_k answers if there are not enough options available within the context.
|
||||
"""
|
||||
word_boxes: Optional[List[Union[List[float], str]]] = None
|
||||
"""A list of words and bounding boxes (normalized 0->1000). If provided, the inference will
|
||||
skip the OCR step and use the provided bounding boxes instead.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class DocumentQuestionAnsweringInput(BaseInferenceType):
|
||||
"""Inputs for Document Question Answering inference"""
|
||||
|
||||
inputs: DocumentQuestionAnsweringInputData
|
||||
"""One (document, question) pair to answer"""
|
||||
parameters: Optional[DocumentQuestionAnsweringParameters] = None
|
||||
"""Additional inference parameters for Document Question Answering"""
|
||||
|
||||
|
||||
@dataclass_with_extra
|
||||
class DocumentQuestionAnsweringOutputElement(BaseInferenceType):
|
||||
"""Outputs of inference for the Document Question Answering task"""
|
||||
|
||||
answer: str
|
||||
"""The answer to the question."""
|
||||
end: int
|
||||
"""The end word index of the answer (in the OCR’d version of the input or provided word
|
||||
boxes).
|
||||
"""
|
||||
score: float
|
||||
"""The probability associated to the answer."""
|
||||
start: int
|
||||
"""The start word index of the answer (in the OCR’d version of the input or provided word
|
||||
boxes).
|
||||
"""
|
||||
@@ -0,0 +1,36 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import List, Literal, Optional, Union

from .base import BaseInferenceType, dataclass_with_extra


FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]


@dataclass_with_extra
class FeatureExtractionInput(BaseInferenceType):
    """Feature Extraction Input.
    Auto-generated from TEI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
    """

    inputs: Union[List[str], str]
    """The text or list of texts to embed."""
    normalize: Optional[bool] = None
    prompt_name: Optional[str] = None
    """The name of the prompt that should be used by for encoding. If not set, no prompt
    will be applied.
    Must be a key in the `sentence-transformers` configuration `prompts` dictionary.
    For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",
    ...},
    then the sentence "What is the capital of France?" will be encoded as
    "query: What is the capital of France?" because the prompt text will be prepended before
    any text to encode.
    """
    truncate: Optional[bool] = None
    truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None
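A short sketch of how FeatureExtractionInput might be filled in (illustrative only; the prompt handling described in the docstring happens on the model side, and "query" is assumed to exist in the model's `prompts` configuration):

request = FeatureExtractionInput(
    inputs="What is the capital of France?",
    prompt_name="query",  # would be encoded as "query: What is the capital of France?"
    normalize=True,
    truncate=True,
    truncation_direction="Right",
)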
@@ -0,0 +1,47 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, List, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class FillMaskParameters(BaseInferenceType):
    """Additional inference parameters for Fill Mask"""

    targets: Optional[List[str]] = None
    """When passed, the model will limit the scores to the passed targets instead of looking up
    in the whole vocabulary. If the provided targets are not in the model vocab, they will be
    tokenized and the first resulting token will be used (with a warning, and that might be
    slower).
    """
    top_k: Optional[int] = None
    """When passed, overrides the number of predictions to return."""


@dataclass_with_extra
class FillMaskInput(BaseInferenceType):
    """Inputs for Fill Mask inference"""

    inputs: str
    """The text with masked tokens"""
    parameters: Optional[FillMaskParameters] = None
    """Additional inference parameters for Fill Mask"""


@dataclass_with_extra
class FillMaskOutputElement(BaseInferenceType):
    """Outputs of inference for the Fill Mask task"""

    score: float
    """The corresponding probability"""
    sequence: str
    """The corresponding input with the mask token prediction."""
    token: int
    """The predicted token id (to replace the masked one)."""
    token_str: Any
    fill_mask_output_token_str: Optional[str] = None
    """The predicted token (to replace the masked one)."""
@@ -0,0 +1,43 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


ImageClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]


@dataclass_with_extra
class ImageClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Image Classification"""

    function_to_apply: Optional["ImageClassificationOutputTransform"] = None
    """The function to apply to the model outputs in order to retrieve the scores."""
    top_k: Optional[int] = None
    """When specified, limits the output to the top K most probable classes."""


@dataclass_with_extra
class ImageClassificationInput(BaseInferenceType):
    """Inputs for Image Classification inference"""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageClassificationParameters] = None
    """Additional inference parameters for Image Classification"""


@dataclass_with_extra
class ImageClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Image Classification task"""

    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@@ -0,0 +1,51 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


ImageSegmentationSubtask = Literal["instance", "panoptic", "semantic"]


@dataclass_with_extra
class ImageSegmentationParameters(BaseInferenceType):
    """Additional inference parameters for Image Segmentation"""

    mask_threshold: Optional[float] = None
    """Threshold to use when turning the predicted masks into binary values."""
    overlap_mask_area_threshold: Optional[float] = None
    """Mask overlap threshold to eliminate small, disconnected segments."""
    subtask: Optional["ImageSegmentationSubtask"] = None
    """Segmentation task to be performed, depending on model capabilities."""
    threshold: Optional[float] = None
    """Probability threshold to filter out predicted masks."""


@dataclass_with_extra
class ImageSegmentationInput(BaseInferenceType):
    """Inputs for Image Segmentation inference"""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageSegmentationParameters] = None
    """Additional inference parameters for Image Segmentation"""


@dataclass_with_extra
class ImageSegmentationOutputElement(BaseInferenceType):
    """Outputs of inference for the Image Segmentation task
    A predicted mask / segment
    """

    label: str
    """The label of the predicted segment."""
    mask: str
    """The corresponding mask as a black-and-white image (base64-encoded)."""
    score: Optional[float] = None
    """The score or confidence degree the model has."""
@@ -0,0 +1,56 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class ImageToImageTargetSize(BaseInferenceType):
    """The size in pixel of the output image."""

    height: int
    width: int


@dataclass_with_extra
class ImageToImageParameters(BaseInferenceType):
    """Additional inference parameters for Image To Image"""

    guidance_scale: Optional[float] = None
    """For diffusion models. A higher guidance scale value encourages the model to generate
    images closely linked to the text prompt at the expense of lower image quality.
    """
    negative_prompt: Optional[str] = None
    """One prompt to guide what NOT to include in image generation."""
    num_inference_steps: Optional[int] = None
    """For diffusion models. The number of denoising steps. More denoising steps usually lead to
    a higher quality image at the expense of slower inference.
    """
    prompt: Optional[str] = None
    """The text prompt to guide the image generation."""
    target_size: Optional[ImageToImageTargetSize] = None
    """The size in pixel of the output image."""


@dataclass_with_extra
class ImageToImageInput(BaseInferenceType):
    """Inputs for Image To Image inference"""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageToImageParameters] = None
    """Additional inference parameters for Image To Image"""


@dataclass_with_extra
class ImageToImageOutput(BaseInferenceType):
    """Outputs of inference for the Image To Image task"""

    image: Any
    """The output image returned as raw bytes in the payload."""
@@ -0,0 +1,101 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union

from .base import BaseInferenceType, dataclass_with_extra


ImageToTextEarlyStoppingEnum = Literal["never"]


@dataclass_with_extra
class ImageToTextGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process"""

    do_sample: Optional[bool] = None
    """Whether to use sampling instead of greedy decoding when generating new tokens."""
    early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None
    """Controls the stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """If set to float strictly between 0 and 1, only tokens with a conditional probability
    greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """The maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """The maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """The minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """The minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups to divide num_beams into in order to ensure diversity among different
    groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
    """
    num_beams: Optional[int] = None
    """Number of beams to use for beam search."""
    penalty_alpha: Optional[float] = None
    """The value balances the model confidence and the degeneration penalty in contrastive
    search decoding.
    """
    temperature: Optional[float] = None
    """The value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """If set to float < 1, only the smallest set of most probable tokens with probabilities
    that add up to top_p or higher are kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token next,
    given the partial text already generated. If set to float < 1, the smallest set of the
    most locally typical tokens with probabilities that add up to typical_p or higher are
    kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding"""


@dataclass_with_extra
class ImageToTextParameters(BaseInferenceType):
    """Additional inference parameters for Image To Text"""

    max_new_tokens: Optional[int] = None
    """The amount of maximum tokens to generate."""
    # Will be deprecated in the future when the renaming to `generation_parameters` is implemented in transformers
    generate_kwargs: Optional[ImageToTextGenerationParameters] = None
    """Parametrization of the text generation process"""


@dataclass_with_extra
class ImageToTextInput(BaseInferenceType):
    """Inputs for Image To Text inference"""

    inputs: Any
    """The input image data"""
    parameters: Optional[ImageToTextParameters] = None
    """Additional inference parameters for Image To Text"""


@dataclass_with_extra
class ImageToTextOutput(BaseInferenceType):
    """Outputs of inference for the Image To Text task"""

    generated_text: Any
    image_to_text_output_generated_text: Optional[str] = None
    """The generated text."""
@@ -0,0 +1,58 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class ObjectDetectionParameters(BaseInferenceType):
    """Additional inference parameters for Object Detection"""

    threshold: Optional[float] = None
    """The probability necessary to make a prediction."""


@dataclass_with_extra
class ObjectDetectionInput(BaseInferenceType):
    """Inputs for Object Detection inference"""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ObjectDetectionParameters] = None
    """Additional inference parameters for Object Detection"""


@dataclass_with_extra
class ObjectDetectionBoundingBox(BaseInferenceType):
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """

    xmax: int
    """The x-coordinate of the bottom-right corner of the bounding box."""
    xmin: int
    """The x-coordinate of the top-left corner of the bounding box."""
    ymax: int
    """The y-coordinate of the bottom-right corner of the bounding box."""
    ymin: int
    """The y-coordinate of the top-left corner of the bounding box."""


@dataclass_with_extra
class ObjectDetectionOutputElement(BaseInferenceType):
    """Outputs of inference for the Object Detection task"""

    box: ObjectDetectionBoundingBox
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """
    label: str
    """The predicted label for the bounding box."""
    score: float
    """The associated score / probability."""
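An illustrative (made-up) object detection result built from the types above, showing how the bounding box coordinates relate to each other:

detection = ObjectDetectionOutputElement(
    box=ObjectDetectionBoundingBox(xmin=12, ymin=40, xmax=128, ymax=210),
    label="cat",
    score=0.97,
)
box_width = detection.box.xmax - detection.box.xmin  # pixels, measured from the top-left corner of the image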
@@ -0,0 +1,74 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class QuestionAnsweringInputData(BaseInferenceType):
    """One (context, question) pair to answer"""

    context: str
    """The context to be used for answering the question"""
    question: str
    """The question to be answered"""


@dataclass_with_extra
class QuestionAnsweringParameters(BaseInferenceType):
    """Additional inference parameters for Question Answering"""

    align_to_words: Optional[bool] = None
    """Attempts to align the answer to real words. Improves quality on space separated
    languages. Might hurt on non-space-separated languages (like Japanese or Chinese)
    """
    doc_stride: Optional[int] = None
    """If the context is too long to fit with the question for the model, it will be split in
    several chunks with some overlap. This argument controls the size of that overlap.
    """
    handle_impossible_answer: Optional[bool] = None
    """Whether to accept impossible as an answer."""
    max_answer_len: Optional[int] = None
    """The maximum length of predicted answers (e.g., only answers with a shorter length are
    considered).
    """
    max_question_len: Optional[int] = None
    """The maximum length of the question after tokenization. It will be truncated if needed."""
    max_seq_len: Optional[int] = None
    """The maximum length of the total sentence (context + question) in tokens of each chunk
    passed to the model. The context will be split in several chunks (using docStride as
    overlap) if needed.
    """
    top_k: Optional[int] = None
    """The number of answers to return (will be chosen by order of likelihood). Note that we
    return less than topk answers if there are not enough options available within the
    context.
    """


@dataclass_with_extra
class QuestionAnsweringInput(BaseInferenceType):
    """Inputs for Question Answering inference"""

    inputs: QuestionAnsweringInputData
    """One (context, question) pair to answer"""
    parameters: Optional[QuestionAnsweringParameters] = None
    """Additional inference parameters for Question Answering"""


@dataclass_with_extra
class QuestionAnsweringOutputElement(BaseInferenceType):
    """Outputs of inference for the Question Answering task"""

    answer: str
    """The answer to the question."""
    end: int
    """The character position in the input where the answer ends."""
    score: float
    """The probability associated to the answer."""
    start: int
    """The character position in the input where the answer begins."""
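A minimal request/response sketch using these question answering types (the context, offsets and score are invented for illustration):

qa_request = QuestionAnsweringInput(
    inputs=QuestionAnsweringInputData(context="The Eiffel Tower is in Paris.", question="Where is the Eiffel Tower?"),
    parameters=QuestionAnsweringParameters(top_k=1, max_answer_len=16),
)
qa_answer = QuestionAnsweringOutputElement(answer="Paris", start=23, end=28, score=0.98)  # character offsets of "Paris" in the context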
@@ -0,0 +1,27 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Dict, List, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class SentenceSimilarityInputData(BaseInferenceType):
    sentences: List[str]
    """A list of strings which will be compared against the source_sentence."""
    source_sentence: str
    """The string that you wish to compare the other strings with. This can be a phrase,
    sentence, or longer passage, depending on the model being used.
    """


@dataclass_with_extra
class SentenceSimilarityInput(BaseInferenceType):
    """Inputs for Sentence similarity inference"""

    inputs: SentenceSimilarityInputData
    parameters: Optional[Dict[str, Any]] = None
    """Additional inference parameters for Sentence Similarity"""
@@ -0,0 +1,41 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Dict, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


SummarizationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"]


@dataclass_with_extra
class SummarizationParameters(BaseInferenceType):
    """Additional inference parameters for summarization."""

    clean_up_tokenization_spaces: Optional[bool] = None
    """Whether to clean up the potential extra spaces in the text output."""
    generate_parameters: Optional[Dict[str, Any]] = None
    """Additional parametrization of the text generation algorithm."""
    truncation: Optional["SummarizationTruncationStrategy"] = None
    """The truncation strategy to use."""


@dataclass_with_extra
class SummarizationInput(BaseInferenceType):
    """Inputs for Summarization inference"""

    inputs: str
    """The input text to summarize."""
    parameters: Optional[SummarizationParameters] = None
    """Additional inference parameters for summarization."""


@dataclass_with_extra
class SummarizationOutput(BaseInferenceType):
    """Outputs of inference for the Summarization task"""

    summary_text: str
    """The summarized text."""
@@ -0,0 +1,62 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Dict, List, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class TableQuestionAnsweringInputData(BaseInferenceType):
    """One (table, question) pair to answer"""

    question: str
    """The question to be answered about the table"""
    table: Dict[str, List[str]]
    """The table to serve as context for the questions"""


Padding = Literal["do_not_pad", "longest", "max_length"]


@dataclass_with_extra
class TableQuestionAnsweringParameters(BaseInferenceType):
    """Additional inference parameters for Table Question Answering"""

    padding: Optional["Padding"] = None
    """Activates and controls padding."""
    sequential: Optional[bool] = None
    """Whether to do inference sequentially or as a batch. Batching is faster, but models like
    SQA require the inference to be done sequentially to extract relations within sequences,
    given their conversational nature.
    """
    truncation: Optional[bool] = None
    """Activates and controls truncation."""


@dataclass_with_extra
class TableQuestionAnsweringInput(BaseInferenceType):
    """Inputs for Table Question Answering inference"""

    inputs: TableQuestionAnsweringInputData
    """One (table, question) pair to answer"""
    parameters: Optional[TableQuestionAnsweringParameters] = None
    """Additional inference parameters for Table Question Answering"""


@dataclass_with_extra
class TableQuestionAnsweringOutputElement(BaseInferenceType):
    """Outputs of inference for the Table Question Answering task"""

    answer: str
    """The answer of the question given the table. If there is an aggregator, the answer will be
    preceded by `AGGREGATOR >`.
    """
    cells: List[str]
    """List of strings made up of the answer cell values."""
    coordinates: List[List[int]]
    """Coordinates of the cells of the answers."""
    aggregator: Optional[str] = None
    """If the model has an aggregator, this returns the aggregator."""
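Sketch of a (table, question) payload; the table dict maps column names to equal-length lists of cell values (example data invented):

tqa_request = TableQuestionAnsweringInput(
    inputs=TableQuestionAnsweringInputData(
        question="Which repository has the most stars?",
        table={"Repository": ["transformers", "datasets"], "Stars": ["120000", "18000"]},
    ),
    parameters=TableQuestionAnsweringParameters(padding="longest", truncation=True),
)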
@@ -0,0 +1,42 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Dict, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


Text2TextGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"]


@dataclass_with_extra
class Text2TextGenerationParameters(BaseInferenceType):
    """Additional inference parameters for Text2text Generation"""

    clean_up_tokenization_spaces: Optional[bool] = None
    """Whether to clean up the potential extra spaces in the text output."""
    generate_parameters: Optional[Dict[str, Any]] = None
    """Additional parametrization of the text generation algorithm"""
    truncation: Optional["Text2TextGenerationTruncationStrategy"] = None
    """The truncation strategy to use"""


@dataclass_with_extra
class Text2TextGenerationInput(BaseInferenceType):
    """Inputs for Text2text Generation inference"""

    inputs: str
    """The input text data"""
    parameters: Optional[Text2TextGenerationParameters] = None
    """Additional inference parameters for Text2text Generation"""


@dataclass_with_extra
class Text2TextGenerationOutput(BaseInferenceType):
    """Outputs of inference for the Text2text Generation task"""

    generated_text: Any
    text2_text_generation_output_generated_text: Optional[str] = None
    """The generated text."""
@@ -0,0 +1,41 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


TextClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]


@dataclass_with_extra
class TextClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Text Classification"""

    function_to_apply: Optional["TextClassificationOutputTransform"] = None
    """The function to apply to the model outputs in order to retrieve the scores."""
    top_k: Optional[int] = None
    """When specified, limits the output to the top K most probable classes."""


@dataclass_with_extra
class TextClassificationInput(BaseInferenceType):
    """Inputs for Text Classification inference"""

    inputs: str
    """The text to classify"""
    parameters: Optional[TextClassificationParameters] = None
    """Additional inference parameters for Text Classification"""


@dataclass_with_extra
class TextClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Text Classification task"""

    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@@ -0,0 +1,168 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, List, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


TypeEnum = Literal["json", "regex"]


@dataclass_with_extra
class TextGenerationInputGrammarType(BaseInferenceType):
    type: "TypeEnum"
    value: Any
    """A string that represents a [JSON Schema](https://json-schema.org/).
    JSON Schema is a declarative language that allows to annotate JSON documents
    with types and descriptions.
    """


@dataclass_with_extra
class TextGenerationInputGenerateParameters(BaseInferenceType):
    adapter_id: Optional[str] = None
    """Lora adapter id"""
    best_of: Optional[int] = None
    """Generate best_of sequences and return the one if the highest token logprobs."""
    decoder_input_details: Optional[bool] = None
    """Whether to return decoder input token logprobs and ids."""
    details: Optional[bool] = None
    """Whether to return generation details."""
    do_sample: Optional[bool] = None
    """Activate logits sampling."""
    frequency_penalty: Optional[float] = None
    """The parameter for frequency penalty. 1.0 means no penalty
    Penalize new tokens based on their existing frequency in the text so far,
    decreasing the model's likelihood to repeat the same line verbatim.
    """
    grammar: Optional[TextGenerationInputGrammarType] = None
    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    repetition_penalty: Optional[float] = None
    """The parameter for repetition penalty. 1.0 means no penalty.
    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """
    return_full_text: Optional[bool] = None
    """Whether to prepend the prompt to the generated text"""
    seed: Optional[int] = None
    """Random sampling seed."""
    stop: Optional[List[str]] = None
    """Stop generating tokens if a member of `stop` is generated."""
    temperature: Optional[float] = None
    """The value used to module the logits distribution."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_n_tokens: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-n-filtering."""
    top_p: Optional[float] = None
    """Top-p value for nucleus sampling."""
    truncate: Optional[int] = None
    """Truncate inputs tokens to the given size."""
    typical_p: Optional[float] = None
    """Typical Decoding mass
    See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
    for more information.
    """
    watermark: Optional[bool] = None
    """Watermarking with [A Watermark for Large Language
    Models](https://arxiv.org/abs/2301.10226).
    """


@dataclass_with_extra
class TextGenerationInput(BaseInferenceType):
    """Text Generation Input.
    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    inputs: str
    parameters: Optional[TextGenerationInputGenerateParameters] = None
    stream: Optional[bool] = None


TextGenerationOutputFinishReason = Literal["length", "eos_token", "stop_sequence"]


@dataclass_with_extra
class TextGenerationOutputPrefillToken(BaseInferenceType):
    id: int
    logprob: float
    text: str


@dataclass_with_extra
class TextGenerationOutputToken(BaseInferenceType):
    id: int
    logprob: float
    special: bool
    text: str


@dataclass_with_extra
class TextGenerationOutputBestOfSequence(BaseInferenceType):
    finish_reason: "TextGenerationOutputFinishReason"
    generated_text: str
    generated_tokens: int
    prefill: List[TextGenerationOutputPrefillToken]
    tokens: List[TextGenerationOutputToken]
    seed: Optional[int] = None
    top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None


@dataclass_with_extra
class TextGenerationOutputDetails(BaseInferenceType):
    finish_reason: "TextGenerationOutputFinishReason"
    generated_tokens: int
    prefill: List[TextGenerationOutputPrefillToken]
    tokens: List[TextGenerationOutputToken]
    best_of_sequences: Optional[List[TextGenerationOutputBestOfSequence]] = None
    seed: Optional[int] = None
    top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None


@dataclass_with_extra
class TextGenerationOutput(BaseInferenceType):
    """Text Generation Output.
    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    generated_text: str
    details: Optional[TextGenerationOutputDetails] = None


@dataclass_with_extra
class TextGenerationStreamOutputStreamDetails(BaseInferenceType):
    finish_reason: "TextGenerationOutputFinishReason"
    generated_tokens: int
    input_length: int
    seed: Optional[int] = None


@dataclass_with_extra
class TextGenerationStreamOutputToken(BaseInferenceType):
    id: int
    logprob: float
    special: bool
    text: str


@dataclass_with_extra
class TextGenerationStreamOutput(BaseInferenceType):
    """Text Generation Stream Output.
    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    index: int
    token: TextGenerationStreamOutputToken
    details: Optional[TextGenerationStreamOutputStreamDetails] = None
    generated_text: Optional[str] = None
    top_tokens: Optional[List[TextGenerationStreamOutputToken]] = None
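A minimal sketch of a TGI-style text generation request assembled from these types (prompt and parameter values are placeholders, not a recommended configuration):

tg_request = TextGenerationInput(
    inputs="Write a haiku about autumn.",
    parameters=TextGenerationInputGenerateParameters(
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,
        stop=["\n\n"],
        details=True,
    ),
    stream=False,
)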
@@ -0,0 +1,100 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union

from .base import BaseInferenceType, dataclass_with_extra


TextToAudioEarlyStoppingEnum = Literal["never"]


@dataclass_with_extra
class TextToAudioGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process"""

    do_sample: Optional[bool] = None
    """Whether to use sampling instead of greedy decoding when generating new tokens."""
    early_stopping: Optional[Union[bool, "TextToAudioEarlyStoppingEnum"]] = None
    """Controls the stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """If set to float strictly between 0 and 1, only tokens with a conditional probability
    greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """The maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """The maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """The minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """The minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups to divide num_beams into in order to ensure diversity among different
    groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
    """
    num_beams: Optional[int] = None
    """Number of beams to use for beam search."""
    penalty_alpha: Optional[float] = None
    """The value balances the model confidence and the degeneration penalty in contrastive
    search decoding.
    """
    temperature: Optional[float] = None
    """The value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """If set to float < 1, only the smallest set of most probable tokens with probabilities
    that add up to top_p or higher are kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token next,
    given the partial text already generated. If set to float < 1, the smallest set of the
    most locally typical tokens with probabilities that add up to typical_p or higher are
    kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding"""


@dataclass_with_extra
class TextToAudioParameters(BaseInferenceType):
    """Additional inference parameters for Text To Audio"""

    # Will be deprecated in the future when the renaming to `generation_parameters` is implemented in transformers
    generate_kwargs: Optional[TextToAudioGenerationParameters] = None
    """Parametrization of the text generation process"""


@dataclass_with_extra
class TextToAudioInput(BaseInferenceType):
    """Inputs for Text To Audio inference"""

    inputs: str
    """The input text data"""
    parameters: Optional[TextToAudioParameters] = None
    """Additional inference parameters for Text To Audio"""


@dataclass_with_extra
class TextToAudioOutput(BaseInferenceType):
    """Outputs of inference for the Text To Audio task"""

    audio: Any
    """The generated audio waveform."""
    sampling_rate: float
    """The sampling rate of the generated audio waveform."""
@@ -0,0 +1,50 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class TextToImageParameters(BaseInferenceType):
    """Additional inference parameters for Text To Image"""

    guidance_scale: Optional[float] = None
    """A higher guidance scale value encourages the model to generate images closely linked to
    the text prompt, but values too high may cause saturation and other artifacts.
    """
    height: Optional[int] = None
    """The height in pixels of the output image"""
    negative_prompt: Optional[str] = None
    """One prompt to guide what NOT to include in image generation."""
    num_inference_steps: Optional[int] = None
    """The number of denoising steps. More denoising steps usually lead to a higher quality
    image at the expense of slower inference.
    """
    scheduler: Optional[str] = None
    """Override the scheduler with a compatible one."""
    seed: Optional[int] = None
    """Seed for the random number generator."""
    width: Optional[int] = None
    """The width in pixels of the output image"""


@dataclass_with_extra
class TextToImageInput(BaseInferenceType):
    """Inputs for Text To Image inference"""

    inputs: str
    """The input text data (sometimes called "prompt")"""
    parameters: Optional[TextToImageParameters] = None
    """Additional inference parameters for Text To Image"""


@dataclass_with_extra
class TextToImageOutput(BaseInferenceType):
    """Outputs of inference for the Text To Image task"""

    image: Any
    """The generated image returned as raw bytes in the payload."""
@@ -0,0 +1,100 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union

from .base import BaseInferenceType, dataclass_with_extra


TextToSpeechEarlyStoppingEnum = Literal["never"]


@dataclass_with_extra
class TextToSpeechGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process"""

    do_sample: Optional[bool] = None
    """Whether to use sampling instead of greedy decoding when generating new tokens."""
    early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None
    """Controls the stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """If set to float strictly between 0 and 1, only tokens with a conditional probability
    greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """The maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """The maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """The minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """The minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups to divide num_beams into in order to ensure diversity among different
    groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
    """
    num_beams: Optional[int] = None
    """Number of beams to use for beam search."""
    penalty_alpha: Optional[float] = None
    """The value balances the model confidence and the degeneration penalty in contrastive
    search decoding.
    """
    temperature: Optional[float] = None
    """The value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """If set to float < 1, only the smallest set of most probable tokens with probabilities
    that add up to top_p or higher are kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token next,
    given the partial text already generated. If set to float < 1, the smallest set of the
    most locally typical tokens with probabilities that add up to typical_p or higher are
    kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding"""


@dataclass_with_extra
class TextToSpeechParameters(BaseInferenceType):
    """Additional inference parameters for Text To Speech"""

    # Will be deprecated in the future when the renaming to `generation_parameters` is implemented in transformers
    generate_kwargs: Optional[TextToSpeechGenerationParameters] = None
    """Parametrization of the text generation process"""


@dataclass_with_extra
class TextToSpeechInput(BaseInferenceType):
    """Inputs for Text To Speech inference"""

    inputs: str
    """The input text data"""
    parameters: Optional[TextToSpeechParameters] = None
    """Additional inference parameters for Text To Speech"""


@dataclass_with_extra
class TextToSpeechOutput(BaseInferenceType):
    """Outputs of inference for the Text To Speech task"""

    audio: Any
    """The generated audio"""
    sampling_rate: Optional[float] = None
    """The sampling rate of the generated audio waveform."""
@@ -0,0 +1,46 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, List, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class TextToVideoParameters(BaseInferenceType):
    """Additional inference parameters for Text To Video"""

    guidance_scale: Optional[float] = None
    """A higher guidance scale value encourages the model to generate videos closely linked to
    the text prompt, but values too high may cause saturation and other artifacts.
    """
    negative_prompt: Optional[List[str]] = None
    """One or several prompt to guide what NOT to include in video generation."""
    num_frames: Optional[float] = None
    """The num_frames parameter determines how many video frames are generated."""
    num_inference_steps: Optional[int] = None
    """The number of denoising steps. More denoising steps usually lead to a higher quality
    video at the expense of slower inference.
    """
    seed: Optional[int] = None
    """Seed for the random number generator."""


@dataclass_with_extra
class TextToVideoInput(BaseInferenceType):
    """Inputs for Text To Video inference"""

    inputs: str
    """The input text data (sometimes called "prompt")"""
    parameters: Optional[TextToVideoParameters] = None
    """Additional inference parameters for Text To Video"""


@dataclass_with_extra
class TextToVideoOutput(BaseInferenceType):
    """Outputs of inference for the Text To Video task"""

    video: Any
    """The generated video returned as raw bytes in the payload."""
@@ -0,0 +1,51 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import List, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


TokenClassificationAggregationStrategy = Literal["none", "simple", "first", "average", "max"]


@dataclass_with_extra
class TokenClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Token Classification"""

    aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None
    """The strategy used to fuse tokens based on model predictions"""
    ignore_labels: Optional[List[str]] = None
    """A list of labels to ignore"""
    stride: Optional[int] = None
    """The number of overlapping tokens between chunks when splitting the input text."""


@dataclass_with_extra
class TokenClassificationInput(BaseInferenceType):
    """Inputs for Token Classification inference"""

    inputs: str
    """The input text data"""
    parameters: Optional[TokenClassificationParameters] = None
    """Additional inference parameters for Token Classification"""


@dataclass_with_extra
class TokenClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Token Classification task"""

    end: int
    """The character position in the input where this group ends."""
    score: float
    """The associated score / probability"""
    start: int
    """The character position in the input where this group begins."""
    word: str
    """The corresponding text"""
    entity: Optional[str] = None
    """The predicted label for a single token"""
    entity_group: Optional[str] = None
    """The predicted label for a group of one or more tokens"""
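Illustrative consumption of a token classification result (values invented); with an aggregation strategy other than "none", `entity_group` is filled instead of `entity`:

ner_hit = TokenClassificationOutputElement(
    word="Ada Lovelace", start=0, end=12, score=0.99, entity_group="PER"
)
span = original_text[ner_hit.start:ner_hit.end]  # assuming `original_text` holds the classified input string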
@@ -0,0 +1,49 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Dict, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


TranslationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"]


@dataclass_with_extra
class TranslationParameters(BaseInferenceType):
    """Additional inference parameters for Translation"""

    clean_up_tokenization_spaces: Optional[bool] = None
    """Whether to clean up the potential extra spaces in the text output."""
    generate_parameters: Optional[Dict[str, Any]] = None
    """Additional parametrization of the text generation algorithm."""
    src_lang: Optional[str] = None
    """The source language of the text. Required for models that can translate from multiple
    languages.
    """
    tgt_lang: Optional[str] = None
    """Target language to translate to. Required for models that can translate to multiple
    languages.
    """
    truncation: Optional["TranslationTruncationStrategy"] = None
    """The truncation strategy to use."""


@dataclass_with_extra
class TranslationInput(BaseInferenceType):
    """Inputs for Translation inference"""

    inputs: str
    """The text to translate."""
    parameters: Optional[TranslationParameters] = None
    """Additional inference parameters for Translation"""


@dataclass_with_extra
class TranslationOutput(BaseInferenceType):
    """Outputs of inference for the Translation task"""

    translation_text: str
    """The translated text."""
@@ -0,0 +1,45 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional

from .base import BaseInferenceType, dataclass_with_extra


VideoClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]


@dataclass_with_extra
class VideoClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Video Classification"""

    frame_sampling_rate: Optional[int] = None
    """The sampling rate used to select frames from the video."""
    function_to_apply: Optional["VideoClassificationOutputTransform"] = None
    """The function to apply to the model outputs in order to retrieve the scores."""
    num_frames: Optional[int] = None
    """The number of sampled frames to consider for classification."""
    top_k: Optional[int] = None
    """When specified, limits the output to the top K most probable classes."""


@dataclass_with_extra
class VideoClassificationInput(BaseInferenceType):
    """Inputs for Video Classification inference"""

    inputs: Any
    """The input video data"""
    parameters: Optional[VideoClassificationParameters] = None
    """Additional inference parameters for Video Classification"""


@dataclass_with_extra
class VideoClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Video Classification task"""

    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@@ -0,0 +1,49 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class VisualQuestionAnsweringInputData(BaseInferenceType):
    """One (image, question) pair to answer"""

    image: Any
    """The image."""
    question: str
    """The question to answer based on the image."""


@dataclass_with_extra
class VisualQuestionAnsweringParameters(BaseInferenceType):
    """Additional inference parameters for Visual Question Answering"""

    top_k: Optional[int] = None
    """The number of answers to return (will be chosen by order of likelihood). Note that we
    return less than topk answers if there are not enough options available within the
    context.
    """


@dataclass_with_extra
class VisualQuestionAnsweringInput(BaseInferenceType):
    """Inputs for Visual Question Answering inference"""

    inputs: VisualQuestionAnsweringInputData
    """One (image, question) pair to answer"""
    parameters: Optional[VisualQuestionAnsweringParameters] = None
    """Additional inference parameters for Visual Question Answering"""


@dataclass_with_extra
class VisualQuestionAnsweringOutputElement(BaseInferenceType):
    """Outputs of inference for the Visual Question Answering task"""

    score: float
    """The associated score / probability"""
    answer: Optional[str] = None
    """The answer to the question"""
@@ -0,0 +1,45 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import List, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class ZeroShotClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Zero Shot Classification"""

    candidate_labels: List[str]
    """The set of possible class labels to classify the text into."""
    hypothesis_template: Optional[str] = None
    """The sentence used in conjunction with `candidate_labels` to attempt the text
    classification by replacing the placeholder with the candidate labels.
    """
    multi_label: Optional[bool] = None
    """Whether multiple candidate labels can be true. If false, the scores are normalized such
    that the sum of the label likelihoods for each sequence is 1. If true, the labels are
    considered independent and probabilities are normalized for each candidate.
    """


@dataclass_with_extra
class ZeroShotClassificationInput(BaseInferenceType):
    """Inputs for Zero Shot Classification inference"""

    inputs: str
    """The text to classify"""
    parameters: ZeroShotClassificationParameters
    """Additional inference parameters for Zero Shot Classification"""


@dataclass_with_extra
class ZeroShotClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Classification task"""

    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@@ -0,0 +1,40 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import List, Optional

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class ZeroShotImageClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Zero Shot Image Classification"""

    candidate_labels: List[str]
    """The candidate labels for this image"""
    hypothesis_template: Optional[str] = None
    """The sentence used in conjunction with `candidate_labels` to attempt the image
    classification by replacing the placeholder with the candidate labels.
    """


@dataclass_with_extra
class ZeroShotImageClassificationInput(BaseInferenceType):
    """Inputs for Zero Shot Image Classification inference"""

    inputs: str
    """The input image data to classify as a base64-encoded string."""
    parameters: ZeroShotImageClassificationParameters
    """Additional inference parameters for Zero Shot Image Classification"""


@dataclass_with_extra
class ZeroShotImageClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Image Classification task"""

    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@@ -0,0 +1,52 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import List

from .base import BaseInferenceType, dataclass_with_extra


@dataclass_with_extra
class ZeroShotObjectDetectionParameters(BaseInferenceType):
    """Additional inference parameters for Zero Shot Object Detection"""

    candidate_labels: List[str]
    """The candidate labels for this image"""


@dataclass_with_extra
class ZeroShotObjectDetectionInput(BaseInferenceType):
    """Inputs for Zero Shot Object Detection inference"""

    inputs: str
    """The input image data as a base64-encoded string."""
    parameters: ZeroShotObjectDetectionParameters
    """Additional inference parameters for Zero Shot Object Detection"""


@dataclass_with_extra
class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """

    xmax: int
    xmin: int
    ymax: int
    ymin: int


@dataclass_with_extra
class ZeroShotObjectDetectionOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Object Detection task"""

    box: ZeroShotObjectDetectionBoundingBox
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """
    label: str
    """A candidate label"""
    score: float
    """The associated score / probability"""
@@ -0,0 +1,141 @@
from typing import Dict, Literal

from ._common import TaskProviderHelper
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cerebras import CerebrasConversationalTask
from .cohere import CohereConversationalTask
from .fal_ai import (
    FalAIAutomaticSpeechRecognitionTask,
    FalAITextToImageTask,
    FalAITextToSpeechTask,
    FalAITextToVideoTask,
)
from .fireworks_ai import FireworksAIConversationalTask
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .openai import OpenAIConversationalTask
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask


PROVIDER_T = Literal[
    "black-forest-labs",
    "cerebras",
    "cohere",
    "fal-ai",
    "fireworks-ai",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "openai",
    "replicate",
    "sambanova",
    "together",
]

PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
    "black-forest-labs": {
        "text-to-image": BlackForestLabsTextToImageTask(),
    },
    "cerebras": {
        "conversational": CerebrasConversationalTask(),
    },
    "cohere": {
        "conversational": CohereConversationalTask(),
    },
    "fal-ai": {
        "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
        "text-to-image": FalAITextToImageTask(),
        "text-to-speech": FalAITextToSpeechTask(),
        "text-to-video": FalAITextToVideoTask(),
    },
    "fireworks-ai": {
        "conversational": FireworksAIConversationalTask(),
    },
    "hf-inference": {
        "text-to-image": HFInferenceTask("text-to-image"),
        "conversational": HFInferenceConversational(),
        "text-generation": HFInferenceTask("text-generation"),
        "text-classification": HFInferenceTask("text-classification"),
        "question-answering": HFInferenceTask("question-answering"),
        "audio-classification": HFInferenceBinaryInputTask("audio-classification"),
        "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
        "fill-mask": HFInferenceTask("fill-mask"),
        "feature-extraction": HFInferenceTask("feature-extraction"),
        "image-classification": HFInferenceBinaryInputTask("image-classification"),
        "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
        "document-question-answering": HFInferenceTask("document-question-answering"),
        "image-to-text": HFInferenceBinaryInputTask("image-to-text"),
        "object-detection": HFInferenceBinaryInputTask("object-detection"),
        "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
        "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
        "zero-shot-classification": HFInferenceTask("zero-shot-classification"),
        "image-to-image": HFInferenceBinaryInputTask("image-to-image"),
        "sentence-similarity": HFInferenceTask("sentence-similarity"),
        "table-question-answering": HFInferenceTask("table-question-answering"),
        "tabular-classification": HFInferenceTask("tabular-classification"),
        "text-to-speech": HFInferenceTask("text-to-speech"),
        "token-classification": HFInferenceTask("token-classification"),
        "translation": HFInferenceTask("translation"),
        "summarization": HFInferenceTask("summarization"),
        "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
    },
    "hyperbolic": {
        "text-to-image": HyperbolicTextToImageTask(),
        "conversational": HyperbolicTextGenerationTask("conversational"),
        "text-generation": HyperbolicTextGenerationTask("text-generation"),
    },
    "nebius": {
        "text-to-image": NebiusTextToImageTask(),
        "conversational": NebiusConversationalTask(),
        "text-generation": NebiusTextGenerationTask(),
    },
    "novita": {
        "text-generation": NovitaTextGenerationTask(),
        "conversational": NovitaConversationalTask(),
        "text-to-video": NovitaTextToVideoTask(),
    },
    "openai": {
        "conversational": OpenAIConversationalTask(),
    },
    "replicate": {
        "text-to-image": ReplicateTask("text-to-image"),
        "text-to-speech": ReplicateTextToSpeechTask(),
        "text-to-video": ReplicateTask("text-to-video"),
    },
    "sambanova": {
        "conversational": SambanovaConversationalTask(),
    },
    "together": {
        "text-to-image": TogetherTextToImageTask(),
        "conversational": TogetherConversationalTask(),
        "text-generation": TogetherTextGenerationTask(),
    },
}


def get_provider_helper(provider: PROVIDER_T, task: str) -> TaskProviderHelper:
    """Get provider helper instance by name and task.

    Args:
        provider (str): Name of the provider
        task (str): Name of the task

    Returns:
        TaskProviderHelper: Helper instance for the specified provider and task

    Raises:
        ValueError: If provider or task is not supported
    """
    if provider not in PROVIDERS:
        raise ValueError(f"Provider '{provider}' not supported. Available providers: {list(PROVIDERS.keys())}")
    if task not in PROVIDERS[provider]:
        raise ValueError(
            f"Task '{task}' not supported for provider '{provider}'. "
            f"Available tasks: {list(PROVIDERS[provider].keys())}"
        )
    return PROVIDERS[provider][task]
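A hedged usage sketch (not part of the commit) of the registry above: `get_provider_helper` resolves a helper from `PROVIDERS`, and the helper builds the outgoing request. The model ID and token below are placeholders, not values from this diff.

# Illustrative only: resolve a helper and prepare a request with it.
helper = get_provider_helper("together", "conversational")
request = helper.prepare_request(
    inputs=[{"role": "user", "content": "Hello!"}],
    parameters={"max_tokens": 64},
    headers={},
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder HF model ID
    api_key="hf_xxx",  # placeholder token; an hf_ token routes through the HF proxy
)
# request.url, request.json / request.data and request.headers are then used for the HTTP call.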
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,245 @@
from functools import lru_cache
from typing import Any, Dict, Optional, Union

from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters
from huggingface_hub.utils import build_hf_headers, get_token, logging


logger = logging.get_logger(__name__)


# Dev purposes only.
# If you want to try to run inference for a new model locally before it's registered on huggingface.co
# for a given Inference Provider, you can add it to the following dictionary.
HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
    # "HF model ID" => "Model ID on Inference Provider's side"
    #
    # Example:
    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
    "cerebras": {},
    "cohere": {},
    "fal-ai": {},
    "fireworks-ai": {},
    "hf-inference": {},
    "hyperbolic": {},
    "nebius": {},
    "replicate": {},
    "sambanova": {},
    "together": {},
}


def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
    return {k: v for k, v in d.items() if v is not None}


class TaskProviderHelper:
    """Base class for task-specific provider helpers."""

    def __init__(self, provider: str, base_url: str, task: str) -> None:
        self.provider = provider
        self.task = task
        self.base_url = base_url

    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters:
        """
        Prepare the request to be sent to the provider.

        Each step (api_key, model, headers, url, payload) can be customized in subclasses.
        """
        # api_key from user, or local token, or raise error
        api_key = self._prepare_api_key(api_key)

        # mapped model from HF model ID
        mapped_model = self._prepare_mapped_model(model)

        # default HF headers + user headers (to customize in subclasses)
        headers = self._prepare_headers(headers, api_key)

        # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
        url = self._prepare_url(api_key, mapped_model)

        # prepare payload (to customize in subclasses)
        payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
        if payload is not None:
            payload = recursive_merge(payload, extra_payload or {})

        # body data (to customize in subclasses)
        data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)

        # check if both payload and data are set and return
        if payload is not None and data is not None:
            raise ValueError("Both payload and data cannot be set in the same request.")
        if payload is None and data is None:
            raise ValueError("Either payload or data must be set in the request.")
        return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """
        Return the response in the expected format.

        Override this method in subclasses for customized response handling."""
        return response

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        """Return the API key to use for the request.

        Usually not overwritten in subclasses."""
        if api_key is None:
            api_key = get_token()
        if api_key is None:
            raise ValueError(
                f"You must provide an api_key to work with {self.provider} API or log in with `huggingface-cli login`."
            )
        return api_key

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        """Return the mapped model ID to use for the request.

        Usually not overwritten in subclasses."""
        if model is None:
            raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")

        # hardcoded mapping for local testing
        if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
            return HARDCODED_MODEL_ID_MAPPING[self.provider][model]

        provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
        if provider_mapping is None:
            raise ValueError(f"Model {model} is not supported by provider {self.provider}.")

        if provider_mapping.task != self.task:
            raise ValueError(
                f"Model {model} is not supported for task {self.task} and provider {self.provider}. "
                f"Supported task: {provider_mapping.task}."
            )

        if provider_mapping.status == "staging":
            logger.warning(
                f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
            )
        return provider_mapping.provider_id

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Return the headers to use for the request.

        Override this method in subclasses for customized headers.
        """
        return {**build_hf_headers(token=api_key), **headers}

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        """Return the URL to use for the request.

        Usually not overwritten in subclasses."""
        base_url = self._prepare_base_url(api_key)
        route = self._prepare_route(mapped_model, api_key)
        return f"{base_url.rstrip('/')}/{route.lstrip('/')}"

    def _prepare_base_url(self, api_key: str) -> str:
        """Return the base URL to use for the request.

        Usually not overwritten in subclasses."""
        # Route to the proxy if the api_key is a HF TOKEN
        if api_key.startswith("hf_"):
            logger.info(f"Calling '{self.provider}' provider through Hugging Face router.")
            return constants.INFERENCE_PROXY_TEMPLATE.format(provider=self.provider)
        else:
            logger.info(f"Calling '{self.provider}' provider directly.")
            return self.base_url

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the route to use for the request.

        Override this method in subclasses for customized routes.
        """
        return ""

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Return the payload to use for the request, as a dict.

        Override this method in subclasses for customized payloads.
        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
        return None

    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        """Return the body to use for the request, as bytes.

        Override this method in subclasses for customized body data.
        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
        return None


class BaseConversationalTask(TaskProviderHelper):
    """
    Base class for conversational (chat completion) tasks.
    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat
    """

    def __init__(self, provider: str, base_url: str):
        super().__init__(provider=provider, base_url=base_url, task="conversational")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/chat/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}


class BaseTextGenerationTask(TaskProviderHelper):
    """
    Base class for text-generation (completion) tasks.
    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions
    """

    def __init__(self, provider: str, base_url: str):
        super().__init__(provider=provider, base_url=base_url, task="text-generation")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}


@lru_cache(maxsize=None)
def _fetch_inference_provider_mapping(model: str) -> Dict:
    """
    Fetch provider mappings for a model from the Hub.
    """
    from huggingface_hub.hf_api import HfApi

    info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
    provider_mapping = info.inference_provider_mapping
    if provider_mapping is None:
        raise ValueError(f"No provider mapping found for model {model}")
    return provider_mapping


def recursive_merge(dict1: Dict, dict2: Dict) -> Dict:
    return {
        **dict1,
        **{
            key: recursive_merge(dict1[key], value)
            if (key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict))
            else value
            for key, value in dict2.items()
        },
    }
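A small sketch (not part of the commit) of how recursive_merge above combines a payload with extra_payload in prepare_request: nested dicts are merged key by key, and keys from the second dict win on conflict. The example values are made up.

# Illustrative only: demonstrate recursive_merge semantics.
payload = {"prompt": "a cat", "options": {"width": 512, "height": 512}}
extra = {"options": {"height": 768}, "seed": 42}
merged = recursive_merge(payload, extra)
# -> {"prompt": "a cat", "options": {"width": 512, "height": 768}, "seed": 42}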
@@ -0,0 +1,66 @@
import time
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import logging
from huggingface_hub.utils._http import get_session


logger = logging.get_logger(__name__)

MAX_POLLING_ATTEMPTS = 6
POLLING_INTERVAL = 1.0


class BlackForestLabsTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            _ = headers.pop("authorization")
            headers["X-Key"] = api_key
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return f"/v1/{mapped_model}"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "num_inference_steps" in parameters:
            parameters["steps"] = parameters.pop("num_inference_steps")
        if "guidance_scale" in parameters:
            parameters["guidance"] = parameters.pop("guidance_scale")

        return {"prompt": inputs, **parameters}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """
        Polling mechanism for Black Forest Labs since the API is asynchronous.
        """
        url = _as_dict(response).get("polling_url")
        session = get_session()
        for _ in range(MAX_POLLING_ATTEMPTS):
            time.sleep(POLLING_INTERVAL)

            response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
            response.raise_for_status()  # type: ignore
            response_json: Dict = response.json()  # type: ignore
            status = response_json.get("status")
            logger.info(
                f"Polling generation result from {url}. Current status: {status}. "
                f"Will retry after {POLLING_INTERVAL} seconds if not ready."
            )

            if (
                status == "Ready"
                and isinstance(response_json.get("result"), dict)
                and (sample_url := response_json["result"].get("sample"))
            ):
                image_resp = session.get(sample_url)
                image_resp.raise_for_status()
                return image_resp.content

        raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.")
@@ -0,0 +1,6 @@
from huggingface_hub.inference._providers._common import BaseConversationalTask


class CerebrasConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="cerebras", base_url="https://api.cerebras.ai")
@@ -0,0 +1,15 @@
from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
)


_PROVIDER = "cohere"
_BASE_URL = "https://api.cohere.com"


class CohereConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/compatibility/v1/chat/completions"
@@ -0,0 +1,147 @@
import base64
import time
from abc import ABC
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
from huggingface_hub.utils.logging import get_logger


logger = get_logger(__name__)

# Arbitrary polling interval
_POLLING_INTERVAL = 0.5


class FalAITask(TaskProviderHelper, ABC):
    def __init__(self, task: str):
        super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            headers["authorization"] = f"Key {api_key}"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return f"/{mapped_model}"


class FalAIAutomaticSpeechRecognitionTask(FalAITask):
    def __init__(self):
        super().__init__("automatic-speech-recognition")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            # If input is a URL, pass it directly
            audio_url = inputs
        else:
            # If input is a file path, read it first
            if isinstance(inputs, str):
                with open(inputs, "rb") as f:
                    inputs = f.read()

            audio_b64 = base64.b64encode(inputs).decode()
            content_type = "audio/mpeg"
            audio_url = f"data:{content_type};base64,{audio_b64}"

        return {"audio_url": audio_url, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        text = _as_dict(response)["text"]
        if not isinstance(text, str):
            raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
        return text


class FalAITextToImageTask(FalAITask):
    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "width" in parameters and "height" in parameters:
            parameters["image_size"] = {
                "width": parameters.pop("width"),
                "height": parameters.pop("height"),
            }
        return {"prompt": inputs, **parameters}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        url = _as_dict(response)["images"][0]["url"]
        return get_session().get(url).content


class FalAITextToSpeechTask(FalAITask):
    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"lyrics": inputs, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        url = _as_dict(response)["audio"]["url"]
        return get_session().get(url).content


class FalAITextToVideoTask(FalAITask):
    def __init__(self):
        super().__init__("text-to-video")

    def _prepare_base_url(self, api_key: str) -> str:
        if api_key.startswith("hf_"):
            return super()._prepare_base_url(api_key)
        else:
            logger.info(f"Calling '{self.provider}' provider directly.")
            return "https://queue.fal.run"

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        if api_key.startswith("hf_"):
            # Use the queue subdomain for HF routing
            return f"/{mapped_model}?_subdomain=queue"
        return f"/{mapped_model}"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"prompt": inputs, **filter_none(parameters)}

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        response_dict = _as_dict(response)

        request_id = response_dict.get("request_id")
        if not request_id:
            raise ValueError("No request ID found in the response")
        if request_params is None:
            raise ValueError(
                "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
            )

        # extract the base url and query params
        parsed_url = urlparse(request_params.url)
        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
        query_param = f"?{parsed_url.query}" if parsed_url.query else ""

        # extracting the provider model id for status and result urls
        # from the response as it might be different from the mapped model in `request_params.url`
        model_id = urlparse(response_dict.get("response_url")).path
        status_url = f"{base_url}{str(model_id)}/status{query_param}"
        result_url = f"{base_url}{str(model_id)}{query_param}"

        status = response_dict.get("status")
        logger.info("Generating the video.. this can take several minutes.")
        while status != "COMPLETED":
            time.sleep(_POLLING_INTERVAL)
            status_response = get_session().get(status_url, headers=request_params.headers)
            hf_raise_for_status(status_response)
            status = status_response.json().get("status")

        response = get_session().get(result_url, headers=request_params.headers).json()
        url = _as_dict(response)["video"]["url"]
        return get_session().get(url).content
@@ -0,0 +1,9 @@
from ._common import BaseConversationalTask


class FireworksAIConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/inference/v1/chat/completions"
@@ -0,0 +1,167 @@
import json
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

from huggingface_hub import constants
from huggingface_hub.inference._common import _b64_encode, _open_as_binary
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status


class HFInferenceTask(TaskProviderHelper):
    """Base class for HF Inference API tasks."""

    def __init__(self, task: str):
        super().__init__(
            provider="hf-inference",
            base_url=constants.INFERENCE_PROXY_TEMPLATE.format(provider="hf-inference"),
            task=task,
        )

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        # special case: for HF Inference we allow not providing an API key
        return api_key or get_token()  # type: ignore[return-value]

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        if model is not None and model.startswith(("http://", "https://")):
            return model
        model_id = model if model is not None else _fetch_recommended_models().get(self.task)
        if model_id is None:
            raise ValueError(
                f"Task {self.task} has no recommended model for HF Inference. Please specify a model"
                " explicitly. Visit https://huggingface.co/tasks for more info."
            )
        _check_supported_task(model_id, self.task)
        return model_id

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
        if mapped_model.startswith(("http://", "https://")):
            return mapped_model
        return (
            # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
            f"{self.base_url}/pipeline/{self.task}/{mapped_model}"
            if self.task in ("feature-extraction", "sentence-similarity")
            # Otherwise, we use the default endpoint
            else f"{self.base_url}/models/{mapped_model}"
        )

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        if isinstance(inputs, bytes):
            raise ValueError(f"Unexpected binary input for task {self.task}.")
        if isinstance(inputs, Path):
            raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
        return {"inputs": inputs, "parameters": filter_none(parameters)}


class HFInferenceBinaryInputTask(HFInferenceTask):
    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return None

    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        parameters = filter_none({k: v for k, v in parameters.items() if v is not None})
        extra_payload = extra_payload or {}
        has_parameters = len(parameters) > 0 or len(extra_payload) > 0

        # Raise if not a binary object or a local path or a URL.
        if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str):
            raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}")

        # Send inputs as raw content when no parameters are provided
        if not has_parameters:
            with _open_as_binary(inputs) as data:
                data_as_bytes = data if isinstance(data, bytes) else data.read()
                return data_as_bytes

        # Otherwise encode as b64
        return json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8")


class HFInferenceConversational(HFInferenceTask):
    def __init__(self):
        super().__init__("conversational")

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        payload_model = parameters.get("model") or mapped_model

        if payload_model is None or payload_model.startswith(("http://", "https://")):
            payload_model = "dummy"

        return {**filter_none(parameters), "model": payload_model, "messages": inputs}

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        base_url = (
            mapped_model
            if mapped_model.startswith(("http://", "https://"))
            else f"{constants.INFERENCE_PROXY_TEMPLATE.format(provider='hf-inference')}/models/{mapped_model}"
        )
        return _build_chat_completion_url(base_url)


def _build_chat_completion_url(model_url: str) -> str:
    # Strip trailing /
    model_url = model_url.rstrip("/")

    # Append /chat/completions if not already present
    if model_url.endswith("/v1"):
        model_url += "/chat/completions"

    # Append /v1/chat/completions if not already present
    if not model_url.endswith("/chat/completions"):
        model_url += "/v1/chat/completions"

    return model_url


@lru_cache(maxsize=1)
def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers())
    hf_raise_for_status(response)
    return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()}


@lru_cache(maxsize=None)
def _check_supported_task(model: str, task: str) -> None:
    from huggingface_hub.hf_api import HfApi

    model_info = HfApi().model_info(model)
    pipeline_tag = model_info.pipeline_tag
    tags = model_info.tags or []
    is_conversational = "conversational" in tags
    if task in ("text-generation", "conversational"):
        if pipeline_tag == "text-generation":
            # text-generation + conversational tag -> both tasks allowed
            if is_conversational:
                return
            # text-generation without conversational tag -> only text-generation allowed
            if task == "text-generation":
                return
            raise ValueError(f"Model '{model}' doesn't support task '{task}'.")

        if pipeline_tag == "text2text-generation":
            if task == "text-generation":
                return
            raise ValueError(f"Model '{model}' doesn't support task '{task}'.")

        if pipeline_tag == "image-text-to-text":
            if is_conversational and task == "conversational":
                return  # Only conversational allowed if tagged as conversational
            raise ValueError("Non-conversational image-text-to-text task is not supported.")

    if (
        task in ("feature-extraction", "sentence-similarity")
        and pipeline_tag in ("feature-extraction", "sentence-similarity")
        and task in tags
    ):
        # feature-extraction and sentence-similarity are interchangeable for HF Inference
        return

    # For all other tasks, just check pipeline tag
    if pipeline_tag != task:
        raise ValueError(
            f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'"
        )
    return
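A quick sketch (not from the commit) of how _build_chat_completion_url above normalizes endpoint URLs; the hostnames are made-up examples:

# Illustrative only: both inputs end up pointing at a /v1/chat/completions route.
assert (
    _build_chat_completion_url("https://example.endpoints.huggingface.cloud")
    == "https://example.endpoints.huggingface.cloud/v1/chat/completions"
)
assert (
    _build_chat_completion_url("https://example.endpoints.huggingface.cloud/v1")
    == "https://example.endpoints.huggingface.cloud/v1/chat/completions"
)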
@@ -0,0 +1,43 @@
import base64
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none


class HyperbolicTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task="text-to-image")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/v1/images/generations"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "num_inference_steps" in parameters:
            parameters["steps"] = parameters.pop("num_inference_steps")
        if "guidance_scale" in parameters:
            parameters["cfg_scale"] = parameters.pop("guidance_scale")
        # For Hyperbolic, the width and height are required parameters
        if "width" not in parameters:
            parameters["width"] = 512
        if "height" not in parameters:
            parameters["height"] = 512
        return {"prompt": inputs, "model_name": mapped_model, **parameters}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        response_dict = _as_dict(response)
        return base64.b64decode(response_dict["images"][0]["image"])


class HyperbolicTextGenerationTask(BaseConversationalTask):
    """
    Special case for Hyperbolic, where text-generation task is handled as a conversational task.
    """

    def __init__(self, task: str):
        super().__init__(
            provider="hyperbolic",
            base_url="https://api.hyperbolic.xyz",
        )
        self.task = task
Some files were not shown because too many files have changed in this diff.