structure saas with tools
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F401
"""Contains helpers to serialize tensors."""

from ._base import StateDictSplit, split_state_dict_into_shards_factory
from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
from ._torch import (
    get_torch_storage_id,
    get_torch_storage_size,
    load_state_dict_from_file,
    load_torch_model,
    save_torch_model,
    save_torch_state_dict,
    split_torch_state_dict_into_shards,
)
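For orientation, here is a minimal, hedged sketch (not part of the commit) of how the torch helpers re-exported above are meant to be used; the toy two-layer model and the explicit `filename_pattern` value are assumptions made for the example:

```python
# Illustrative sketch: shard a tiny PyTorch state dict with the helper re-exported above.
import torch

from huggingface_hub.serialization import split_torch_state_dict_into_shards

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
state_dict_split = split_torch_state_dict_into_shards(
    model.state_dict(),
    filename_pattern="model{suffix}.safetensors",  # assumed pattern for this sketch
    max_shard_size="100MB",
)
print(state_dict_split.is_sharded)                 # False: the whole model fits in one shard
print(list(state_dict_split.filename_to_tensors))  # ['model.safetensors']
```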
@@ -0,0 +1,210 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains helpers to split tensors into shards."""

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from .. import logging


TensorT = TypeVar("TensorT")
TensorSizeFn_T = Callable[[TensorT], int]
StorageIDFn_T = Callable[[TensorT], Optional[Any]]

MAX_SHARD_SIZE = "5GB"
SIZE_UNITS = {
    "TB": 10**12,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}


logger = logging.get_logger(__file__)


@dataclass
class StateDictSplit:
    is_sharded: bool = field(init=False)
    metadata: Dict[str, Any]
    filename_to_tensors: Dict[str, List[str]]
    tensor_to_filename: Dict[str, str]

    def __post_init__(self):
        self.is_sharded = len(self.filename_to_tensors) > 1


def split_state_dict_into_shards_factory(
    state_dict: Dict[str, TensorT],
    *,
    get_storage_size: TensorSizeFn_T,
    filename_pattern: str,
    get_storage_id: StorageIDFn_T = lambda tensor: None,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
    """
    Split a model state dictionary into shards so that each shard is smaller than a given size.

    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
    [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
    size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, Tensor]`):
            The state dictionary to save.
        get_storage_size (`Callable[[Tensor], int]`):
            A function that returns the size of a tensor when saved on disk in bytes.
        get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
            A function that returns a unique identifier for a tensor storage. Multiple different tensors can share the
            same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
            during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
        filename_pattern (`str`, *optional*):
            The pattern used to generate the file names in which the model will be saved. The pattern must be a string
            that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
        max_shard_size (`int` or `str`, *optional*):
            The maximum size of each shard, in bytes. Defaults to 5GB.

    Returns:
        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
    """
    storage_id_to_tensors: Dict[Any, List[str]] = {}

    shard_list: List[Dict[str, TensorT]] = []
    current_shard: Dict[str, TensorT] = {}
    current_shard_size = 0
    total_size = 0

    if isinstance(max_shard_size, str):
        max_shard_size = parse_size_to_int(max_shard_size)

    for key, tensor in state_dict.items():
        # When bnb serialization is used, the weights in the state dict can be strings.
        # See https://github.com/huggingface/transformers/pull/24416 for more details.
        if isinstance(tensor, str):
            logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
            continue

        # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
        storage_id = get_storage_id(tensor)
        if storage_id is not None:
            if storage_id in storage_id_to_tensors:
                # We skip this tensor for now and will reassign it to the correct shard later
                storage_id_to_tensors[storage_id].append(key)
                continue
            else:
                # This is the first tensor with this storage_id, we create a new entry
                # in the storage_id_to_tensors dict => we will assign the shard id later
                storage_id_to_tensors[storage_id] = [key]

        # Compute tensor size
        tensor_size = get_storage_size(tensor)

        # If this tensor is bigger than the maximal size, we put it in its own shard
        if tensor_size > max_shard_size:
            total_size += tensor_size
            shard_list.append({key: tensor})
            continue

        # If this tensor would push the current shard over the maximal size, we split.
        # The current shard already has some tensors: we add it to the list of shards and create a new one.
        if current_shard_size + tensor_size > max_shard_size:
            shard_list.append(current_shard)
            current_shard = {}
            current_shard_size = 0

        # Add the tensor to the current shard
        current_shard[key] = tensor
        current_shard_size += tensor_size
        total_size += tensor_size

    # Add the last shard
    if len(current_shard) > 0:
        shard_list.append(current_shard)
    nb_shards = len(shard_list)

    # Loop over the tensors that share the same storage and assign them together
    for storage_id, keys in storage_id_to_tensors.items():
        # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard
        for shard in shard_list:
            if keys[0] in shard:
                for key in keys:
                    shard[key] = state_dict[key]
                break

    # If we only have one shard, we return it => no need to build the index
    if nb_shards == 1:
        filename = filename_pattern.format(suffix="")
        return StateDictSplit(
            metadata={"total_size": total_size},
            filename_to_tensors={filename: list(state_dict.keys())},
            tensor_to_filename={key: filename for key in state_dict.keys()},
        )

    # Now that each tensor is assigned to a shard, let's assign a filename to each shard
    tensor_name_to_filename = {}
    filename_to_tensors = {}
    for idx, shard in enumerate(shard_list):
        filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
        for key in shard:
            tensor_name_to_filename[key] = filename
        filename_to_tensors[filename] = list(shard.keys())

    # Build the index and return
    return StateDictSplit(
        metadata={"total_size": total_size},
        filename_to_tensors=filename_to_tensors,
        tensor_to_filename=tensor_name_to_filename,
    )


def parse_size_to_int(size_as_str: str) -> int:
    """
    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Supported units are "TB", "GB", "MB", "KB".

    Args:
        size_as_str (`str`): The size to convert, as a string with a numeric value and a unit (e.g. `"5MB"`).

    Example:

    ```py
    >>> parse_size_to_int("5MB")
    5000000
    ```
    """
    size_as_str = size_as_str.strip()

    # Parse unit
    unit = size_as_str[-2:].upper()
    if unit not in SIZE_UNITS:
        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
    multiplier = SIZE_UNITS[unit]

    # Parse value
    try:
        value = float(size_as_str[:-2].strip())
    except ValueError as e:
        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e

    return int(value * multiplier)
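To make the factory's contract concrete, here is a hedged sketch (not part of the commit) that drives `split_state_dict_into_shards_factory` with NumPy arrays standing in for framework tensors; the `get_storage_size` lambda and the `.npz` filename pattern are assumptions for illustration only:

```python
# Illustrative sketch: two ~4MB arrays cannot share a 5MB shard, so two shards are produced.
import numpy as np

from huggingface_hub.serialization import split_state_dict_into_shards_factory

state_dict = {
    "layer1.weight": np.zeros((1000, 1000), dtype=np.float32),  # ~4MB
    "layer1.bias": np.zeros((1000,), dtype=np.float32),
    "layer2.weight": np.zeros((1000, 1000), dtype=np.float32),  # ~4MB
}

split = split_state_dict_into_shards_factory(
    state_dict,
    get_storage_size=lambda arr: arr.nbytes,  # on-disk size approximated by in-memory size
    filename_pattern="model{suffix}.npz",     # assumed pattern for this sketch
    max_shard_size="5MB",                     # parsed by parse_size_to_int above
)
print(split.is_sharded)                 # True
print(list(split.filename_to_tensors))  # ['model-00001-of-00002.npz', 'model-00002-of-00002.npz']
```

Greedy accumulation explains the result: `layer1.weight` and `layer1.bias` fit under 5MB together, but adding `layer2.weight` would exceed the limit, so it starts the second shard.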
@@ -0,0 +1,387 @@
import json
import logging
import mmap
import os
import shutil
import zipfile
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, Tuple, Union

from ..errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError


logger = logging.getLogger(__name__)

DDUF_ALLOWED_ENTRIES = {
    # Allowed file extensions in a DDUF file
    ".json",
    ".model",
    ".safetensors",
    ".txt",
}

DDUF_FOLDER_REQUIRED_ENTRIES = {
    # Each folder must contain at least one of these entries
    "config.json",
    "tokenizer_config.json",
    "preprocessor_config.json",
    "scheduler_config.json",
}


@dataclass
class DDUFEntry:
    """Object representing a file entry in a DDUF file.

    See [`read_dduf_file`] for how to read a DDUF file.

    Attributes:
        filename (str):
            The name of the file in the DDUF archive.
        offset (int):
            The offset of the file in the DDUF archive.
        length (int):
            The length of the file in the DDUF archive.
        dduf_path (str):
            The path to the DDUF archive (for internal use).
    """

    filename: str
    length: int
    offset: int

    dduf_path: Path = field(repr=False)

    @contextmanager
    def as_mmap(self) -> Generator[bytes, None, None]:
        """Open the file as a memory-mapped file.

        Useful to load safetensors directly from the file.

        Example:
        ```py
        >>> import safetensors.torch
        >>> with entry.as_mmap() as mm:
        ...     tensors = safetensors.torch.load(mm)
        ```
        """
        with self.dduf_path.open("rb") as f:
            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mm:
                yield mm[self.offset : self.offset + self.length]

    def read_text(self, encoding: str = "utf-8") -> str:
        """Read the file as text.

        Useful for '.txt' and '.json' entries.

        Example:
        ```py
        >>> import json
        >>> index = json.loads(entry.read_text())
        ```
        """
        with self.dduf_path.open("rb") as f:
            f.seek(self.offset)
            return f.read(self.length).decode(encoding=encoding)


def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]:
    """
    Read a DDUF file and return a dictionary of entries.

    Only the metadata is read; the data is not loaded in memory.

    Args:
        dduf_path (`str` or `os.PathLike`):
            The path to the DDUF file to read.

    Returns:
        `Dict[str, DDUFEntry]`:
            A dictionary of [`DDUFEntry`] indexed by filename.

    Raises:
        - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format).

    Example:
    ```python
    >>> import json
    >>> import safetensors.torch
    >>> from huggingface_hub import read_dduf_file

    # Read DDUF metadata
    >>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf")

    # Returns a mapping filename <> DDUFEntry
    >>> dduf_entries["model_index.json"]
    DDUFEntry(filename='model_index.json', offset=66, length=587)

    # Load model index as JSON
    >>> json.loads(dduf_entries["model_index.json"].read_text())
    {'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', ...

    # Load VAE weights using safetensors
    >>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
    ...     state_dict = safetensors.torch.load(mm)
    ```
    """
    entries = {}
    dduf_path = Path(dduf_path)
    logger.info(f"Reading DDUF file {dduf_path}")
    with zipfile.ZipFile(str(dduf_path), "r") as zf:
        for info in zf.infolist():
            logger.debug(f"Reading entry {info.filename}")
            if info.compress_type != zipfile.ZIP_STORED:
                raise DDUFCorruptedFileError("Data must not be compressed in DDUF file.")

            try:
                _validate_dduf_entry_name(info.filename)
            except DDUFInvalidEntryNameError as e:
                raise DDUFCorruptedFileError(f"Invalid entry name in DDUF file: {info.filename}") from e

            offset = _get_data_offset(zf, info)

            entries[info.filename] = DDUFEntry(
                filename=info.filename, offset=offset, length=info.file_size, dduf_path=dduf_path
            )

    # Consistency checks on the DDUF file
    if "model_index.json" not in entries:
        raise DDUFCorruptedFileError("Missing required 'model_index.json' entry in DDUF file.")
    index = json.loads(entries["model_index.json"].read_text())
    _validate_dduf_structure(index, entries.keys())

    logger.info(f"Done reading DDUF file {dduf_path}. Found {len(entries)} entries")
    return entries


def export_entries_as_dduf(
    dduf_path: Union[str, os.PathLike], entries: Iterable[Tuple[str, Union[str, Path, bytes]]]
) -> None:
    """Write a DDUF file from an iterable of entries.

    This is a lower-level helper than [`export_folder_as_dduf`] that allows more flexibility when serializing data.
    In particular, you don't need to save the data on disk before exporting it in the DDUF file.

    Args:
        dduf_path (`str` or `os.PathLike`):
            The path to the DDUF file to write.
        entries (`Iterable[Tuple[str, Union[str, Path, bytes]]]`):
            An iterable of entries to write in the DDUF file. Each entry is a tuple with the filename and the content.
            The filename should be the path to the file in the DDUF archive.
            The content can be a string or a `pathlib.Path` pointing to a file on the local disk, or the content
            itself as bytes.

    Raises:
        - [`DDUFExportError`]: If anything goes wrong during the export (e.g. invalid entry name, missing 'model_index.json', etc.).

    Example:
    ```python
    # Export specific files from the local disk.
    >>> from huggingface_hub import export_entries_as_dduf
    >>> export_entries_as_dduf(
    ...     dduf_path="stable-diffusion-v1-4-FP16.dduf",
    ...     entries=[ # List entries to add to the DDUF file (here, only FP16 weights)
    ...         ("model_index.json", "path/to/model_index.json"),
    ...         ("vae/config.json", "path/to/vae/config.json"),
    ...         ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"),
    ...         ("text_encoder/config.json", "path/to/text_encoder/config.json"),
    ...         ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"),
    ...         # ... add more entries here
    ...     ]
    ... )
    ```

    ```python
    # Export state_dicts one by one from a loaded pipeline
    >>> from diffusers import DiffusionPipeline
    >>> from typing import Generator, Tuple
    >>> import safetensors.torch
    >>> from huggingface_hub import export_entries_as_dduf
    >>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    ... # ... do some work with the pipeline

    >>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]:
    ...     # Build a generator that yields the entries to add to the DDUF file.
    ...     # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file.
    ...     # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time).
    ...     yield "vae/config.json", pipe.vae.to_json_string().encode()
    ...     yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict())
    ...     yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode()
    ...     yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict())
    ...     # ... add more entries here

    >>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe))
    ```
    """
    logger.info(f"Exporting DDUF file '{dduf_path}'")
    filenames = set()
    index = None
    with zipfile.ZipFile(str(dduf_path), "w", zipfile.ZIP_STORED) as archive:
        for filename, content in entries:
            if filename in filenames:
                raise DDUFExportError(f"Can't add duplicate entry: {filename}")
            filenames.add(filename)

            if filename == "model_index.json":
                try:
                    index = json.loads(_load_content(content).decode())
                except json.JSONDecodeError as e:
                    raise DDUFExportError("Failed to parse 'model_index.json'.") from e

            try:
                filename = _validate_dduf_entry_name(filename)
            except DDUFInvalidEntryNameError as e:
                raise DDUFExportError(f"Invalid entry name: {filename}") from e
            logger.debug(f"Adding entry '{filename}' to DDUF file")
            _dump_content_in_archive(archive, filename, content)

    # Consistency checks on the DDUF file
    if index is None:
        raise DDUFExportError("Missing required 'model_index.json' entry in DDUF file.")
    try:
        _validate_dduf_structure(index, filenames)
    except DDUFCorruptedFileError as e:
        raise DDUFExportError("Invalid DDUF file structure.") from e

    logger.info(f"Done writing DDUF file {dduf_path}")


def export_folder_as_dduf(dduf_path: Union[str, os.PathLike], folder_path: Union[str, os.PathLike]) -> None:
    """
    Export a folder as a DDUF file.

    Uses [`export_entries_as_dduf`] under the hood.

    Args:
        dduf_path (`str` or `os.PathLike`):
            The path to the DDUF file to write.
        folder_path (`str` or `os.PathLike`):
            The path to the folder containing the diffusion model.

    Example:
    ```python
    >>> from huggingface_hub import export_folder_as_dduf
    >>> export_folder_as_dduf(dduf_path="FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev")
    ```
    """
    folder_path = Path(folder_path)

    def _iterate_over_folder() -> Iterable[Tuple[str, Path]]:
        for path in Path(folder_path).glob("**/*"):
            if not path.is_file():
                continue
            if path.suffix not in DDUF_ALLOWED_ENTRIES:
                logger.debug(f"Skipping file '{path}' (file type not allowed)")
                continue
            path_in_archive = path.relative_to(folder_path)
            if len(path_in_archive.parts) >= 3:
                logger.debug(f"Skipping file '{path}' (nested directories not allowed)")
                continue
            yield path_in_archive.as_posix(), path

    export_entries_as_dduf(dduf_path, _iterate_over_folder())
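Before the private helpers below, a hedged round-trip sketch (not part of the commit) tying the two public entry points together; the archive name, folder name, and JSON contents are invented to satisfy the structure checks (a 'model_index.json' plus one folder holding a config file):

```python
# Illustrative sketch: write a tiny DDUF archive from in-memory bytes, then read it back.
import json

from huggingface_hub import export_entries_as_dduf, read_dduf_file

entries = [
    ("model_index.json", json.dumps({"_class_name": "ToyPipeline", "unet": ["diffusers", "UNet2DModel"]}).encode()),
    ("unet/config.json", json.dumps({"_class_name": "UNet2DModel"}).encode()),
]
export_entries_as_dduf(dduf_path="toy.dduf", entries=entries)

dduf_entries = read_dduf_file("toy.dduf")
print(sorted(dduf_entries))  # ['model_index.json', 'unet/config.json']
print(json.loads(dduf_entries["model_index.json"].read_text())["_class_name"])  # 'ToyPipeline'
```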
def _dump_content_in_archive(archive: zipfile.ZipFile, filename: str, content: Union[str, os.PathLike, bytes]) -> None:
    with archive.open(filename, "w", force_zip64=True) as archive_fh:
        if isinstance(content, (str, Path)):
            content_path = Path(content)
            with content_path.open("rb") as content_fh:
                shutil.copyfileobj(content_fh, archive_fh, 1024 * 1024 * 8)  # type: ignore[misc]
        elif isinstance(content, bytes):
            archive_fh.write(content)
        else:
            raise DDUFExportError(f"Invalid content type for {filename}. Must be str, Path or bytes.")


def _load_content(content: Union[str, Path, bytes]) -> bytes:
    """Load the content of an entry as bytes.

    Used only for small checks (not to dump content into archive).
    """
    if isinstance(content, (str, Path)):
        return Path(content).read_bytes()
    elif isinstance(content, bytes):
        return content
    else:
        raise DDUFExportError(f"Invalid content type. Must be str, Path or bytes. Got {type(content)}.")


def _validate_dduf_entry_name(entry_name: str) -> str:
    if "." + entry_name.split(".")[-1] not in DDUF_ALLOWED_ENTRIES:
        raise DDUFInvalidEntryNameError(f"File type not allowed: {entry_name}")
    if "\\" in entry_name:
        raise DDUFInvalidEntryNameError(f"Entry names must use UNIX separators ('/'). Got {entry_name}.")
    entry_name = entry_name.strip("/")
    if entry_name.count("/") > 1:
        raise DDUFInvalidEntryNameError(f"DDUF only supports 1 level of directory. Got {entry_name}.")
    return entry_name


def _validate_dduf_structure(index: Any, entry_names: Iterable[str]) -> None:
    """
    Consistency checks on the DDUF file structure.

    Rules:
    - The 'model_index.json' entry is required and must contain a dictionary.
    - Each folder name must correspond to an entry in 'model_index.json'.
    - Each folder must contain at least a config file ('config.json', 'tokenizer_config.json', 'preprocessor_config.json', 'scheduler_config.json').

    Args:
        index (Any):
            The content of the 'model_index.json' entry.
        entry_names (Iterable[str]):
            The list of entry names in the DDUF file.

    Raises:
        - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format).
    """
    if not isinstance(index, dict):
        raise DDUFCorruptedFileError(f"Invalid 'model_index.json' content. Must be a dictionary. Got {type(index)}.")

    dduf_folders = {entry.split("/")[0] for entry in entry_names if "/" in entry}
    for folder in dduf_folders:
        if folder not in index:
            raise DDUFCorruptedFileError(f"Missing required entry '{folder}' in 'model_index.json'.")
        if not any(f"{folder}/{required_entry}" in entry_names for required_entry in DDUF_FOLDER_REQUIRED_ENTRIES):
            raise DDUFCorruptedFileError(
                f"Missing required file in folder '{folder}'. Must contain at least one of {DDUF_FOLDER_REQUIRED_ENTRIES}."
            )


def _get_data_offset(zf: zipfile.ZipFile, info: zipfile.ZipInfo) -> int:
    """
    Calculate the data offset for a file in a ZIP archive.

    Args:
        zf (`zipfile.ZipFile`):
            The opened ZIP file. Must be opened in read mode.
        info (`zipfile.ZipInfo`):
            The file info.

    Returns:
        int: The offset of the file data in the ZIP archive.
    """
    if zf.fp is None:
        raise DDUFCorruptedFileError("ZipFile object must be opened in read mode.")

    # Step 1: Get the local file header offset
    header_offset = info.header_offset

    # Step 2: Read the local file header
    zf.fp.seek(header_offset)
    local_file_header = zf.fp.read(30)  # Fixed-size part of the local header

    if len(local_file_header) < 30:
        raise DDUFCorruptedFileError("Incomplete local file header.")

    # Step 3: Parse the header fields to calculate the start of file data
    # Local file header: https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers
    filename_len = int.from_bytes(local_file_header[26:28], "little")
    extra_field_len = int.from_bytes(local_file_header[28:30], "little")

    # Data offset is after the fixed header, filename, and extra fields
    data_offset = header_offset + 30 + filename_len + extra_field_len

    return data_offset
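As a sanity check on the local-file-header arithmetic above, a hedged sketch (not part of the commit) that reproduces the same offset computation with only the standard library and verifies it recovers the stored payload of an uncompressed (ZIP_STORED) archive:

```python
# Illustrative sketch: header_offset + 30 + filename_len + extra_field_len points at the file data.
import zipfile

payload = b"hello dduf"
with zipfile.ZipFile("toy.zip", "w", zipfile.ZIP_STORED) as zf:
    zf.writestr("data.txt", payload)

with zipfile.ZipFile("toy.zip", "r") as zf:
    info = zf.getinfo("data.txt")
    zf.fp.seek(info.header_offset)
    header = zf.fp.read(30)  # fixed-size part of the local file header
    filename_len = int.from_bytes(header[26:28], "little")
    extra_field_len = int.from_bytes(header[28:30], "little")
    data_offset = info.header_offset + 30 + filename_len + extra_field_len

# Reading `file_size` bytes at `data_offset` recovers the payload: this is exactly
# what DDUFEntry relies on for offset/length-based (and mmap) access.
with open("toy.zip", "rb") as f:
    f.seek(data_offset)
    assert f.read(info.file_size) == payload
```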
@@ -0,0 +1,95 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains tensorflow-specific helpers."""

import math
import re
from typing import TYPE_CHECKING, Dict, Union

from .. import constants
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory


if TYPE_CHECKING:
    import tensorflow as tf


def split_tf_state_dict_into_shards(
    state_dict: Dict[str, "tf.Tensor"],
    *,
    filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
    """
    Split a model state dictionary into shards so that each shard is smaller than a given size.

    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
    [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
    size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, Tensor]`):
            The state dictionary to save.
        filename_pattern (`str`, *optional*):
            The pattern used to generate the file names in which the model will be saved. The pattern must be a string
            that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
            Defaults to `"tf_model{suffix}.h5"`.
        max_shard_size (`int` or `str`, *optional*):
            The maximum size of each shard, in bytes. Defaults to 5GB.

    Returns:
        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
    """
    return split_state_dict_into_shards_factory(
        state_dict,
        max_shard_size=max_shard_size,
        filename_pattern=filename_pattern,
        get_storage_size=get_tf_storage_size,
    )


def get_tf_storage_size(tensor: "tf.Tensor") -> int:
    # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
    # Better to overestimate than underestimate.
    return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))


def _dtype_byte_size_tf(dtype) -> float:
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.
    Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608.
    NOTE: why not `tensor.numpy().nbytes`?
    Example:
    ```py
    >>> _dtype_byte_size_tf(tf.float32)
    4
    ```
    """
    import tensorflow as tf

    if dtype == tf.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
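Finally, a hedged sketch (not part of the commit) of the same trailing-digits trick used by `_dtype_byte_size_tf`, applied to plain dtype names so it can run without TensorFlow installed; the helper name is hypothetical:

```python
# Illustrative sketch: parse the bit width from a dtype name and convert to bytes.
import math
import re


def dtype_byte_size(dtype_name: str) -> float:
    if dtype_name == "bool":
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", dtype_name)
    if bit_search is None:
        raise ValueError(f"`{dtype_name}` is not a valid dtype name.")
    return int(bit_search.groups()[0]) // 8


print(dtype_byte_size("float32"))               # 4
print(dtype_byte_size("bfloat16"))              # 2
print(math.ceil(10 * dtype_byte_size("bool")))  # 2 -> ten booleans round up to 2 bytes
```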
File diff suppressed because it is too large