structure saas with tools
This commit is contained in:
199
.venv/lib/python3.10/site-packages/huggingface_hub/utils/_xet.py
Normal file
199
.venv/lib/python3.10/site-packages/huggingface_hub/utils/_xet.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from .. import constants
|
||||
from . import get_session, hf_raise_for_status, validate_hf_hub_args
|
||||
|
||||
|
||||
class XetTokenType(str, Enum):
    """Kind of xet access token that can be requested from the Hub.

    Members are also plain strings (the class mixes in `str`), so a member
    compares equal to its lowercase value.
    """

    # Token used to read from xet storage.
    READ = "read"
    # Token used to write to xet storage.
    WRITE = "write"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class XetFileData:
    """Immutable xet metadata describing a single file.

    Instances are built from an HTTP response by
    `parse_xet_file_data_from_response`.
    """

    # Xet hash of the file, as read from the response headers.
    file_hash: str
    # Route (relative path or full URL) used to refresh the xet connection info.
    refresh_route: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class XetConnectionInfo:
    """Immutable set of values needed to talk to the xet storage service.

    Instances are built from HTTP headers by
    `parse_xet_connection_info_from_headers`.
    """

    # Access token to present to the xet storage service.
    access_token: str
    # Expiration of the access token, parsed from the headers as an integer.
    expiration_unix_epoch: int
    # URL of the xet storage service.
    endpoint: str
|
||||
|
||||
|
||||
def parse_xet_file_data_from_response(response: requests.Response) -> Optional[XetFileData]:
    """
    Build a `XetFileData` from the headers and links of an HTTP response.

    The file hash is always read from the response headers. The refresh route is taken from
    the response links when the xet auth link is present, and from the response headers otherwise.

    Args:
        response (`requests.Response`):
            The HTTP response whose headers dict and links dict carry the XET metadata.
    Returns:
        `Optional[XetFileData]`:
            The file hash and refresh route when both are available. `None` if `response` is
            `None` or if any required entry is missing.
    """
    if response is None:
        return None

    response_headers = response.headers
    try:
        xet_hash = response_headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]

        # The link relation takes precedence over the plain header when both are present.
        auth_link_key = constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY
        response_links = response.links
        route = (
            response_links[auth_link_key]["url"]
            if auth_link_key in response_links
            else response_headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
        )
    except KeyError:
        # Any missing header or link entry means the response carries no usable xet metadata.
        return None

    return XetFileData(file_hash=xet_hash, refresh_route=route)
|
||||
|
||||
|
||||
def parse_xet_connection_info_from_headers(headers: Dict[str, str]) -> Optional[XetConnectionInfo]:
    """
    Build a `XetConnectionInfo` from HTTP headers, or return `None` when they are incomplete.

    Args:
        headers (`Dict`):
            HTTP headers expected to hold the xet endpoint, access token and expiration.
    Returns:
        `XetConnectionInfo` or `None`:
            The information needed to connect to the XET storage service.
            `None` if a header is missing or the expiration cannot be converted to an integer.
    """
    try:
        xet_endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
        token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
        expires_at = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
    except (KeyError, ValueError, TypeError):
        # Missing header (KeyError) or non-integer expiration (ValueError / TypeError).
        connection_info = None
    else:
        connection_info = XetConnectionInfo(
            endpoint=xet_endpoint,
            access_token=token,
            expiration_unix_epoch=expires_at,
        )
    return connection_info
|
||||
|
||||
|
||||
@validate_hf_hub_args
def refresh_xet_connection_info(
    *,
    file_data: XetFileData,
    headers: Dict[str, str],
    endpoint: Optional[str] = None,
) -> XetConnectionInfo:
    """
    Request fresh xet connection information from the Hub using a file's refresh route.

    The result includes the access token, its expiration, and the XET service URL.

    Args:
        file_data: (`XetFileData`):
            The file data holding the refresh route to call.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to the Hub endpoint.
    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the file data has no refresh route or the Hub API response is improperly formatted.
    """
    route = file_data.refresh_route
    if route is None:
        raise ValueError("The provided xet metadata does not contain a refresh endpoint.")

    if endpoint is None:
        endpoint = constants.ENDPOINT

    # TODO: An upcoming version of hub will prepend the endpoint to the refresh route in
    # the headers. Once that's deployed we can call fetch on the refresh route directly.
    # Until then, a route starting with "/" is relative and gets the endpoint prepended.
    refresh_url = f"{endpoint}{route}" if route.startswith("/") else route

    return _fetch_xet_connection_info_with_url(refresh_url, headers)
|
||||
|
||||
|
||||
@validate_hf_hub_args
def fetch_xet_connection_info_from_repo_info(
    *,
    token_type: XetTokenType,
    repo_id: str,
    repo_type: str,
    revision: Optional[str] = None,
    headers: Dict[str, str],
    endpoint: Optional[str] = None,
    params: Optional[Dict[str, str]] = None,
) -> XetConnectionInfo:
    """
    Uses the repo info to request a xet access token from Hub.

    Args:
        token_type (`XetTokenType`):
            Type of the token to request: `"read"` or `"write"`. Both the enum member and its
            plain string value are accepted.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        revision (`str`, `optional`):
            The revision of the repo to get the token for. Defaults to the default revision
            of the Hub when not provided.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to the Hub endpoint.
        params (`Dict[str, str]`, `optional`):
            Additional parameters to pass with the request.
    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `token_type` is not a valid `XetTokenType` value, or if the Hub API response is
            improperly formatted.
    """
    endpoint = endpoint if endpoint is not None else constants.ENDPOINT

    # `revision` is optional in the signature; interpolating `None` into the f-string would
    # request the literal path segment ".../None". Fall back to the default revision instead.
    # NOTE(review): relies on `constants.DEFAULT_REVISION` being defined in the package's
    # constants module — confirm against the installed version.
    if revision is None:
        revision = constants.DEFAULT_REVISION

    # Coerce through the enum so that both `XetTokenType.READ` and the plain string "read"
    # work, and an unknown value fails early with a `ValueError`.
    token_kind = XetTokenType(token_type).value

    url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_kind}-token/{revision}"
    return _fetch_xet_connection_info_with_url(url, headers, params)
|
||||
|
||||
|
||||
@validate_hf_hub_args
def _fetch_xet_connection_info_with_url(
    url: str,
    headers: Dict[str, str],
    params: Optional[Dict[str, str]] = None,
) -> XetConnectionInfo:
    """
    Perform a GET on `url` and read the xet connection info from the response headers.

    The connection info consists of the access token, its expiration time, and the endpoint
    to use for the xet storage service.

    Args:
        url: (`str`):
            The access token endpoint URL.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        params (`Dict[str, str]`, `optional`):
            Additional parameters to pass with the request.
    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    session = get_session()
    response = session.get(headers=headers, url=url, params=params)
    hf_raise_for_status(response)

    connection_info = parse_xet_connection_info_from_headers(response.headers)  # type: ignore
    if connection_info is not None:
        return connection_info

    # The request succeeded but the expected xet headers are missing or malformed.
    raise ValueError("Xet headers have not been correctly set by the server.")
|
||||
Reference in New Issue
Block a user