structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F401
"""Contains helpers to serialize tensors."""
from ._base import StateDictSplit, split_state_dict_into_shards_factory
from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
from ._torch import (
get_torch_storage_id,
get_torch_storage_size,
load_state_dict_from_file,
load_torch_model,
save_torch_model,
save_torch_state_dict,
split_torch_state_dict_into_shards,
)
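
For orientation, here is a minimal usage sketch of the torch-side helpers re-exported above. It assumes the upstream `huggingface_hub` API and a `torch` install; the toy `nn.Linear` model and the `./checkpoint` directory are illustrative assumptions only.

```py
# Minimal sketch (illustrative), assuming the upstream huggingface_hub torch API.
import torch
from huggingface_hub import save_torch_state_dict, split_torch_state_dict_into_shards

model = torch.nn.Linear(4, 4)

# Plan the shards without writing anything to disk.
split = split_torch_state_dict_into_shards(model.state_dict(), max_shard_size="5GB")
print(split.is_sharded)  # False for such a tiny model

# Or save directly: writes one safetensors file (or shards plus an index) into the directory.
save_torch_state_dict(model.state_dict(), save_directory="./checkpoint")
```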


@@ -0,0 +1,210 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains helpers to split tensors into shards."""
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from .. import logging
TensorT = TypeVar("TensorT")
TensorSizeFn_T = Callable[[TensorT], int]
StorageIDFn_T = Callable[[TensorT], Optional[Any]]
MAX_SHARD_SIZE = "5GB"
SIZE_UNITS = {
"TB": 10**12,
"GB": 10**9,
"MB": 10**6,
"KB": 10**3,
}
logger = logging.get_logger(__file__)
@dataclass
class StateDictSplit:
is_sharded: bool = field(init=False)
metadata: Dict[str, Any]
filename_to_tensors: Dict[str, List[str]]
tensor_to_filename: Dict[str, str]
def __post_init__(self):
self.is_sharded = len(self.filename_to_tensors) > 1
def split_state_dict_into_shards_factory(
state_dict: Dict[str, TensorT],
*,
get_storage_size: TensorSizeFn_T,
filename_pattern: str,
get_storage_id: StorageIDFn_T = lambda tensor: None,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary in shards so that each shard is smaller than a given size.
The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
[6+2+2GB], [6+2GB], [6GB].
<Tip warning={true}>
    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
size greater than `max_shard_size`.
</Tip>
Args:
state_dict (`Dict[str, Tensor]`):
The state dictionary to save.
        get_storage_size (`Callable[[Tensor], int]`):
            A function that returns the size of a tensor when saved on disk, in bytes.
        get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
            A function that returns a unique identifier for a tensor's storage. Multiple different tensors can share
            the same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's
            storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
        filename_pattern (`str`, *optional*):
            The pattern used to generate the file names in which the model will be saved. The pattern must be a
            string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword
            `suffix`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
Returns:
[`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
"""
storage_id_to_tensors: Dict[Any, List[str]] = {}
shard_list: List[Dict[str, TensorT]] = []
current_shard: Dict[str, TensorT] = {}
current_shard_size = 0
total_size = 0
if isinstance(max_shard_size, str):
max_shard_size = parse_size_to_int(max_shard_size)
for key, tensor in state_dict.items():
        # When bnb (bitsandbytes) serialization is used, the weights in the state dict can be strings.
        # See https://github.com/huggingface/transformers/pull/24416 for more details.
if isinstance(tensor, str):
logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
continue
# If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
storage_id = get_storage_id(tensor)
if storage_id is not None:
if storage_id in storage_id_to_tensors:
# We skip this tensor for now and will reassign to correct shard later
storage_id_to_tensors[storage_id].append(key)
continue
else:
# This is the first tensor with this storage_id, we create a new entry
# in the storage_id_to_tensors dict => we will assign the shard id later
storage_id_to_tensors[storage_id] = [key]
# Compute tensor size
tensor_size = get_storage_size(tensor)
# If this tensor is bigger than the maximal size, we put it in its own shard
if tensor_size > max_shard_size:
total_size += tensor_size
shard_list.append({key: tensor})
continue
        # If adding this tensor would tip the current shard over the maximal size, we split.
# Current shard already has some tensors, we add it to the list of shards and create a new one.
if current_shard_size + tensor_size > max_shard_size:
shard_list.append(current_shard)
current_shard = {}
current_shard_size = 0
# Add the tensor to the current shard
current_shard[key] = tensor
current_shard_size += tensor_size
total_size += tensor_size
# Add the last shard
if len(current_shard) > 0:
shard_list.append(current_shard)
nb_shards = len(shard_list)
# Loop over the tensors that share the same storage and assign them together
for storage_id, keys in storage_id_to_tensors.items():
# Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard
for shard in shard_list:
if keys[0] in shard:
for key in keys:
shard[key] = state_dict[key]
break
# If we only have one shard, we return it => no need to build the index
if nb_shards == 1:
filename = filename_pattern.format(suffix="")
return StateDictSplit(
metadata={"total_size": total_size},
filename_to_tensors={filename: list(state_dict.keys())},
tensor_to_filename={key: filename for key in state_dict.keys()},
)
# Now that each tensor is assigned to a shard, let's assign a filename to each shard
tensor_name_to_filename = {}
filename_to_tensors = {}
for idx, shard in enumerate(shard_list):
filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
for key in shard:
tensor_name_to_filename[key] = filename
filename_to_tensors[filename] = list(shard.keys())
# Build the index and return
return StateDictSplit(
metadata={"total_size": total_size},
filename_to_tensors=filename_to_tensors,
tensor_to_filename=tensor_name_to_filename,
)
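
As a concrete illustration of the factory above, here is a small sketch driving it with NumPy arrays standing in for framework tensors. The NumPy use case and the `.npz` filename pattern are assumptions for the example; the library's real entry points are the torch/TensorFlow wrappers.

```py
# Illustrative sketch: drive the generic factory with NumPy arrays.
import numpy as np

state_dict = {f"layer_{i}.weight": np.zeros((256, 256), dtype=np.float32) for i in range(8)}

split = split_state_dict_into_shards_factory(
    state_dict,
    get_storage_size=lambda arr: arr.nbytes,   # on-disk size of each array, in bytes
    filename_pattern="model{suffix}.npz",      # must contain the `suffix` keyword
    max_shard_size="1MB",                      # each array is ~262 kB, so three arrays fit per shard
)
print(split.is_sharded)                 # True
print(list(split.filename_to_tensors))  # ['model-00001-of-00003.npz', ...]
```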
def parse_size_to_int(size_as_str: str) -> int:
"""
Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
Supported units are "TB", "GB", "MB", "KB".
Args:
        size_as_str (`str`): The size to convert, expressed as a string made of digits followed by a unit (e.g. `"5MB"`).
Example:
```py
>>> parse_size_to_int("5MB")
5000000
```
"""
size_as_str = size_as_str.strip()
# Parse unit
unit = size_as_str[-2:].upper()
if unit not in SIZE_UNITS:
raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
multiplier = SIZE_UNITS[unit]
# Parse value
try:
value = float(size_as_str[:-2].strip())
except ValueError as e:
raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e
return int(value * multiplier)
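
A few illustrative checks of the parsing above: values may be fractional and surrounding whitespace is ignored. The `assert`s are a sketch only, assuming they run in the same module.

```py
# Illustrative checks for parse_size_to_int as defined above.
assert parse_size_to_int("5MB") == 5_000_000
assert parse_size_to_int("1.5GB") == 1_500_000_000
assert parse_size_to_int(" 10 KB ") == 10_000
# parse_size_to_int("5GiB")  # would raise ValueError: unit 'IB' is not supported
```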


@@ -0,0 +1,387 @@
import json
import logging
import mmap
import os
import shutil
import zipfile
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, Tuple, Union
from ..errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError
logger = logging.getLogger(__name__)
DDUF_ALLOWED_ENTRIES = {
# Allowed file extensions in a DDUF file
".json",
".model",
".safetensors",
".txt",
}
DDUF_FOLDER_REQUIRED_ENTRIES = {
# Each folder must contain at least one of these entries
"config.json",
"tokenizer_config.json",
"preprocessor_config.json",
"scheduler_config.json",
}
@dataclass
class DDUFEntry:
"""Object representing a file entry in a DDUF file.
See [`read_dduf_file`] for how to read a DDUF file.
Attributes:
filename (str):
The name of the file in the DDUF archive.
offset (int):
The offset of the file in the DDUF archive.
length (int):
The length of the file in the DDUF archive.
dduf_path (str):
The path to the DDUF archive (for internal use).
"""
filename: str
length: int
offset: int
dduf_path: Path = field(repr=False)
@contextmanager
def as_mmap(self) -> Generator[bytes, None, None]:
"""Open the file as a memory-mapped file.
Useful to load safetensors directly from the file.
Example:
```py
>>> import safetensors.torch
>>> with entry.as_mmap() as mm:
... tensors = safetensors.torch.load(mm)
```
"""
with self.dduf_path.open("rb") as f:
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mm:
yield mm[self.offset : self.offset + self.length]
def read_text(self, encoding: str = "utf-8") -> str:
"""Read the file as text.
Useful for '.txt' and '.json' entries.
Example:
```py
>>> import json
>>> index = json.loads(entry.read_text())
```
"""
with self.dduf_path.open("rb") as f:
f.seek(self.offset)
return f.read(self.length).decode(encoding=encoding)
def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]:
"""
Read a DDUF file and return a dictionary of entries.
    Only the metadata is read; the data is not loaded into memory.
Args:
dduf_path (`str` or `os.PathLike`):
The path to the DDUF file to read.
Returns:
`Dict[str, DDUFEntry]`:
A dictionary of [`DDUFEntry`] indexed by filename.
Raises:
- [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format).
Example:
```python
>>> import json
>>> import safetensors.torch
>>> from huggingface_hub import read_dduf_file
# Read DDUF metadata
>>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf")
# Returns a mapping filename <> DDUFEntry
>>> dduf_entries["model_index.json"]
DDUFEntry(filename='model_index.json', offset=66, length=587)
# Load model index as JSON
>>> json.loads(dduf_entries["model_index.json"].read_text())
{'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', ...
# Load VAE weights using safetensors
>>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
... state_dict = safetensors.torch.load(mm)
```
"""
entries = {}
dduf_path = Path(dduf_path)
logger.info(f"Reading DDUF file {dduf_path}")
with zipfile.ZipFile(str(dduf_path), "r") as zf:
for info in zf.infolist():
logger.debug(f"Reading entry {info.filename}")
if info.compress_type != zipfile.ZIP_STORED:
raise DDUFCorruptedFileError("Data must not be compressed in DDUF file.")
try:
_validate_dduf_entry_name(info.filename)
except DDUFInvalidEntryNameError as e:
raise DDUFCorruptedFileError(f"Invalid entry name in DDUF file: {info.filename}") from e
offset = _get_data_offset(zf, info)
entries[info.filename] = DDUFEntry(
filename=info.filename, offset=offset, length=info.file_size, dduf_path=dduf_path
)
# Consistency checks on the DDUF file
if "model_index.json" not in entries:
raise DDUFCorruptedFileError("Missing required 'model_index.json' entry in DDUF file.")
index = json.loads(entries["model_index.json"].read_text())
_validate_dduf_structure(index, entries.keys())
logger.info(f"Done reading DDUF file {dduf_path}. Found {len(entries)} entries")
return entries
def export_entries_as_dduf(
dduf_path: Union[str, os.PathLike], entries: Iterable[Tuple[str, Union[str, Path, bytes]]]
) -> None:
"""Write a DDUF file from an iterable of entries.
This is a lower-level helper than [`export_folder_as_dduf`] that allows more flexibility when serializing data.
In particular, you don't need to save the data on disk before exporting it in the DDUF file.
Args:
dduf_path (`str` or `os.PathLike`):
The path to the DDUF file to write.
entries (`Iterable[Tuple[str, Union[str, Path, bytes]]]`):
An iterable of entries to write in the DDUF file. Each entry is a tuple with the filename and the content.
The filename should be the path to the file in the DDUF archive.
            The content can be a string or a `pathlib.Path` pointing to a file on the local disk, or the raw content as bytes.
Raises:
- [`DDUFExportError`]: If anything goes wrong during the export (e.g. invalid entry name, missing 'model_index.json', etc.).
Example:
```python
# Export specific files from the local disk.
>>> from huggingface_hub import export_entries_as_dduf
>>> export_entries_as_dduf(
... dduf_path="stable-diffusion-v1-4-FP16.dduf",
... entries=[ # List entries to add to the DDUF file (here, only FP16 weights)
... ("model_index.json", "path/to/model_index.json"),
... ("vae/config.json", "path/to/vae/config.json"),
... ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"),
... ("text_encoder/config.json", "path/to/text_encoder/config.json"),
... ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"),
... # ... add more entries here
... ]
... )
```
```python
# Export state_dicts one by one from a loaded pipeline
>>> from diffusers import DiffusionPipeline
>>> from typing import Generator, Tuple
>>> import safetensors.torch
>>> from huggingface_hub import export_entries_as_dduf
>>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
... # ... do some work with the pipeline
>>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]:
    ... # Build a generator that yields the entries to add to the DDUF file.
... # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file.
... # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time)
... yield "vae/config.json", pipe.vae.to_json_string().encode()
... yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict())
... yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode()
... yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict())
... # ... add more entries here
>>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe))
```
"""
logger.info(f"Exporting DDUF file '{dduf_path}'")
filenames = set()
index = None
with zipfile.ZipFile(str(dduf_path), "w", zipfile.ZIP_STORED) as archive:
for filename, content in entries:
if filename in filenames:
raise DDUFExportError(f"Can't add duplicate entry: {filename}")
filenames.add(filename)
if filename == "model_index.json":
try:
index = json.loads(_load_content(content).decode())
except json.JSONDecodeError as e:
raise DDUFExportError("Failed to parse 'model_index.json'.") from e
try:
filename = _validate_dduf_entry_name(filename)
except DDUFInvalidEntryNameError as e:
raise DDUFExportError(f"Invalid entry name: {filename}") from e
logger.debug(f"Adding entry '{filename}' to DDUF file")
_dump_content_in_archive(archive, filename, content)
# Consistency checks on the DDUF file
if index is None:
raise DDUFExportError("Missing required 'model_index.json' entry in DDUF file.")
try:
_validate_dduf_structure(index, filenames)
except DDUFCorruptedFileError as e:
raise DDUFExportError("Invalid DDUF file structure.") from e
logger.info(f"Done writing DDUF file {dduf_path}")
def export_folder_as_dduf(dduf_path: Union[str, os.PathLike], folder_path: Union[str, os.PathLike]) -> None:
"""
Export a folder as a DDUF file.
    Uses [`export_entries_as_dduf`] under the hood.
Args:
dduf_path (`str` or `os.PathLike`):
The path to the DDUF file to write.
folder_path (`str` or `os.PathLike`):
The path to the folder containing the diffusion model.
Example:
```python
>>> from huggingface_hub import export_folder_as_dduf
>>> export_folder_as_dduf(dduf_path="FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev")
```
"""
folder_path = Path(folder_path)
def _iterate_over_folder() -> Iterable[Tuple[str, Path]]:
for path in Path(folder_path).glob("**/*"):
if not path.is_file():
continue
if path.suffix not in DDUF_ALLOWED_ENTRIES:
logger.debug(f"Skipping file '{path}' (file type not allowed)")
continue
path_in_archive = path.relative_to(folder_path)
if len(path_in_archive.parts) >= 3:
logger.debug(f"Skipping file '{path}' (nested directories not allowed)")
continue
yield path_in_archive.as_posix(), path
export_entries_as_dduf(dduf_path, _iterate_over_folder())
def _dump_content_in_archive(archive: zipfile.ZipFile, filename: str, content: Union[str, os.PathLike, bytes]) -> None:
with archive.open(filename, "w", force_zip64=True) as archive_fh:
if isinstance(content, (str, Path)):
content_path = Path(content)
with content_path.open("rb") as content_fh:
shutil.copyfileobj(content_fh, archive_fh, 1024 * 1024 * 8) # type: ignore[misc]
elif isinstance(content, bytes):
archive_fh.write(content)
else:
raise DDUFExportError(f"Invalid content type for {filename}. Must be str, Path or bytes.")
def _load_content(content: Union[str, Path, bytes]) -> bytes:
"""Load the content of an entry as bytes.
Used only for small checks (not to dump content into archive).
"""
if isinstance(content, (str, Path)):
return Path(content).read_bytes()
elif isinstance(content, bytes):
return content
else:
raise DDUFExportError(f"Invalid content type. Must be str, Path or bytes. Got {type(content)}.")
def _validate_dduf_entry_name(entry_name: str) -> str:
if "." + entry_name.split(".")[-1] not in DDUF_ALLOWED_ENTRIES:
raise DDUFInvalidEntryNameError(f"File type not allowed: {entry_name}")
if "\\" in entry_name:
raise DDUFInvalidEntryNameError(f"Entry names must use UNIX separators ('/'). Got {entry_name}.")
entry_name = entry_name.strip("/")
if entry_name.count("/") > 1:
raise DDUFInvalidEntryNameError(f"DDUF only supports 1 level of directory. Got {entry_name}.")
return entry_name
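
A short sketch of the naming rules enforced above. Illustrative only, assuming it runs in this module so the helper and the error class are in scope.

```py
# Illustrative checks of the DDUF entry-name rules.
assert _validate_dduf_entry_name("model_index.json") == "model_index.json"
assert _validate_dduf_entry_name("/vae/config.json") == "vae/config.json"  # leading '/' is stripped

for bad in ("weights.bin", "vae\\config.json", "vae/sub/config.json"):
    try:
        _validate_dduf_entry_name(bad)
    except DDUFInvalidEntryNameError:
        pass  # expected: disallowed extension, wrong separator, or too many directory levels
```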
def _validate_dduf_structure(index: Any, entry_names: Iterable[str]) -> None:
"""
Consistency checks on the DDUF file structure.
Rules:
- The 'model_index.json' entry is required and must contain a dictionary.
- Each folder name must correspond to an entry in 'model_index.json'.
- Each folder must contain at least a config file ('config.json', 'tokenizer_config.json', 'preprocessor_config.json', 'scheduler_config.json').
Args:
index (Any):
The content of the 'model_index.json' entry.
entry_names (Iterable[str]):
The list of entry names in the DDUF file.
Raises:
- [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format).
"""
if not isinstance(index, dict):
raise DDUFCorruptedFileError(f"Invalid 'model_index.json' content. Must be a dictionary. Got {type(index)}.")
dduf_folders = {entry.split("/")[0] for entry in entry_names if "/" in entry}
for folder in dduf_folders:
if folder not in index:
raise DDUFCorruptedFileError(f"Missing required entry '{folder}' in 'model_index.json'.")
if not any(f"{folder}/{required_entry}" in entry_names for required_entry in DDUF_FOLDER_REQUIRED_ENTRIES):
raise DDUFCorruptedFileError(
f"Missing required file in folder '{folder}'. Must contains at least one of {DDUF_FOLDER_REQUIRED_ENTRIES}."
)
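
To make these rules concrete, here is a minimal sketch of a structure that passes the validation above. The index content is an assumed, illustrative example.

```py
# Illustrative: a structure accepted by _validate_dduf_structure.
index = {"_class_name": "FluxPipeline", "vae": ["diffusers", "AutoencoderKL"]}
entry_names = {"model_index.json", "vae/config.json", "vae/diffusion_pytorch_model.safetensors"}
_validate_dduf_structure(index, entry_names)  # passes: 'vae' is in the index and ships a config file
# Removing 'vae/config.json' would raise DDUFCorruptedFileError (no required config in 'vae/').
```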
def _get_data_offset(zf: zipfile.ZipFile, info: zipfile.ZipInfo) -> int:
"""
Calculate the data offset for a file in a ZIP archive.
Args:
zf (`zipfile.ZipFile`):
The opened ZIP file. Must be opened in read mode.
info (`zipfile.ZipInfo`):
The file info.
Returns:
int: The offset of the file data in the ZIP archive.
"""
if zf.fp is None:
raise DDUFCorruptedFileError("ZipFile object must be opened in read mode.")
# Step 1: Get the local file header offset
header_offset = info.header_offset
# Step 2: Read the local file header
zf.fp.seek(header_offset)
local_file_header = zf.fp.read(30) # Fixed-size part of the local header
if len(local_file_header) < 30:
raise DDUFCorruptedFileError("Incomplete local file header.")
# Step 3: Parse the header fields to calculate the start of file data
# Local file header: https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers
filename_len = int.from_bytes(local_file_header[26:28], "little")
extra_field_len = int.from_bytes(local_file_header[28:30], "little")
# Data offset is after the fixed header, filename, and extra fields
data_offset = header_offset + 30 + filename_len + extra_field_len
return data_offset
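
Because DDUF entries are stored uncompressed, the computed offset can be cross-checked against `zipfile`'s own reader. A small sanity-check sketch; the archive name `example.dduf` is an assumption.

```py
# Illustrative sanity check: bytes read at the computed offset match zipfile's view.
with zipfile.ZipFile("example.dduf", "r") as zf:
    info = zf.infolist()[0]
    offset = _get_data_offset(zf, info)
    with open("example.dduf", "rb") as f:
        f.seek(offset)
        raw = f.read(info.file_size)
    assert raw == zf.read(info)  # equal because entries are ZIP_STORED (no compression)
```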


@@ -0,0 +1,95 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains tensorflow-specific helpers."""
import math
import re
from typing import TYPE_CHECKING, Dict, Union
from .. import constants
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
if TYPE_CHECKING:
import tensorflow as tf
def split_tf_state_dict_into_shards(
state_dict: Dict[str, "tf.Tensor"],
*,
filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary in shards so that each shard is smaller than a given size.
The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
[6+2+2GB], [6+2GB], [6GB].
<Tip warning={true}>
    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
size greater than `max_shard_size`.
</Tip>
Args:
state_dict (`Dict[str, Tensor]`):
The state dictionary to save.
        filename_pattern (`str`, *optional*):
            The pattern used to generate the file names in which the model will be saved. The pattern must be a
            string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword
            `suffix`. Defaults to `"tf_model{suffix}.h5"`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
Returns:
[`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
"""
return split_state_dict_into_shards_factory(
state_dict,
max_shard_size=max_shard_size,
filename_pattern=filename_pattern,
get_storage_size=get_tf_storage_size,
)
def get_tf_storage_size(tensor: "tf.Tensor") -> int:
# Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
# Better to overestimate than underestimate.
return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))
def _dtype_byte_size_tf(dtype) -> float:
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608.
NOTE: why not `tensor.numpy().nbytes`?
Example:
```py
    >>> _dtype_byte_size_tf(tf.float32)
4
```
"""
import tensorflow as tf
if dtype == tf.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
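
A few illustrative checks of the size computation above, assuming TensorFlow is installed and the code runs in this module.

```py
# Illustrative checks for the TensorFlow byte-size helpers.
import tensorflow as tf

assert _dtype_byte_size_tf(tf.float32) == 4
assert _dtype_byte_size_tf(tf.float16) == 2
assert _dtype_byte_size_tf(tf.bool) == 0.125
# get_tf_storage_size rounds up, so a single bool element still counts as one byte on disk.
assert get_tf_storage_size(tf.constant([True])) == 1
```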

File diff suppressed because it is too large.