Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,36 @@ Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"f

`dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired.

#### Skipping the random-init shell (`low_cpu_mem_usage`)

`dtype=` lowers peak memory but doesn't speed up the *load itself* — even with `dtype="bf16"`, GLiNER still allocates a fp32 random-initialized model shell, runs Kaiming/Xavier init over every parameter, casts the whole thing to bf16, then overwrites every value with the loaded weights. All of that init work is thrown away.

Pass `low_cpu_mem_usage=True` to skip it: the model graph is built under `torch.device("meta")` (shape descriptors only, no allocation, no random init), the state dict is read at the target precision, and `load_state_dict(assign=True)` swaps the loaded tensors directly into the meta-shell parameter slots in one pass.

```python
model = GLiNER.from_pretrained(
"urchade/gliner_medium-v2.1",
dtype="bf16",
low_cpu_mem_usage=True,
map_location="cuda",
)
```

Measured on `gliner_medium-v2.1` on an RTX 5090 (n=12 reps, Welch t-tested, OS page cache warmed):

| path | mean load time | speedup | peak host RSS delta |
|---|---|---|---|
| baseline (cuda, bf16) | 3.16 s | 1.0× | 1361 MB |
| `low_cpu_mem_usage=True` (cuda, bf16) | **1.61 s** | **1.96×** | 1004 MB |
| baseline (cpu, bf16) | 3.30 s | 1.0× | 1597 MB |
| `low_cpu_mem_usage=True` (cpu, bf16) | **1.60 s** | **2.06×** | 1225 MB |
| baseline (cpu, fp32) | 3.04 s | 1.0× | 1598 MB |
| `low_cpu_mem_usage=True` (cpu, fp32) | **1.45 s** | **2.10×** | 170 MB |

About **1.5 seconds saved on every cold start**, plus 23–89% lower peak host RSS depending on dtype (the fp32 case is dramatic because safetensors mmaps the on-disk file and we never copy it into anonymous memory). Loaded parameters are bit-identical to the standard path — verified across 224 parameters and 1 buffer (`position_ids`, re-materialized after assign).

Default is `False` while the path matures — enable it explicitly when cold-start latency or peak host memory matters. `low_cpu_mem_usage` stacks with `dtype=` (use them together) and is independent of `quantize=` and `compile_torch_model=`.

#### Selective download (`variant`)

`dtype=` casts in memory but the on-disk file is still fp32, so the bytes pulled from the Hub don't shrink. If a publisher uploads a half-precision variant of the file (`model.fp16.safetensors` or `model.bf16.safetensors`, following the transformers naming convention), pass `variant=` to download *only* that file:
Expand Down
180 changes: 162 additions & 18 deletions gliner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,55 @@ def _load_tokenizer(cls, config: GLiNERConfig, model_dir: Path, cache_dir: Optio

return cls._set_tokenizer_spec_tokens(tokenizer)

@staticmethod
def _materialize_meta_buffers(module: nn.Module) -> tuple:
"""Re-materialize non-persistent buffers left on meta after assign-load.

Walks the module tree and replaces any buffer still on the meta device
after a meta-init + ``load_state_dict(assign=True)`` sequence.

Non-persistent buffers (registered with ``persistent=False``) are not
in the saved state dict, so they survive as meta tensors after the
load. This helper restores the ones we recognize:

- ``position_ids`` — ``arange(0, max_pos).unsqueeze(0)`` (BERT/DeBERTa).
- ``token_type_ids`` — ``zeros((1, max_pos))`` (BERT-family default).

For *unrecognized* buffers (e.g. RoPE's ``inv_freq``, where the
canonical value depends on a per-architecture ``base`` we can't
recover from the buffer alone), this returns them in
``unrecognized`` so the caller can fall back to the standard load
path rather than ship a model with broken inference.

Returns ``(materialized: list[str], unrecognized: list[str])``.
"""
materialized: list = []
unrecognized: list = []
for name, buf in list(module.named_buffers()):
if not buf.is_meta:
continue
*parents, leaf = name.split(".")
parent_mod = module
for p in parents:
parent_mod = getattr(parent_mod, p)
if leaf == "position_ids":
# Standard transformers convention: arange(0, max_pos).expand(1, -1)
value = torch.arange(0, buf.shape[-1], dtype=buf.dtype).unsqueeze(0).contiguous()
parent_mod.register_buffer(leaf, value, persistent=False)
materialized.append(name)
elif leaf == "token_type_ids":
# BERT-family default: zeros, broadcast to (1, max_pos).
value = torch.zeros(buf.shape, dtype=buf.dtype)
parent_mod.register_buffer(leaf, value, persistent=False)
materialized.append(name)
else:
# Unrecognized non-persistent buffer. We can't safely zero-fill
# because the canonical value may be load-bearing (e.g. RoPE
# ``inv_freq`` is computed from ``base ** (arange(0, dim, 2) / dim)``
# and zeros would break attention). Surface to caller for fallback.
unrecognized.append(name)
return materialized, unrecognized

@classmethod
def _load_state_dict(
cls,
Expand Down Expand Up @@ -985,6 +1034,7 @@ def from_pretrained(
compile_torch_model: Optional[bool] = False,
quantize: Optional[str] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
low_cpu_mem_usage: bool = False,
variant: Optional[str] = None,
load_onnx_model: Optional[bool] = False,
onnx_model_file: Optional[str] = "model.onnx",
Expand Down Expand Up @@ -1022,6 +1072,15 @@ def from_pretrained(
reading, so the full fp32 copy is never materialized — peak
host memory is roughly half of the default path for bf16/fp16.
Prefer this over ``quantize`` for plain precision changes.
low_cpu_mem_usage: If True, build the model under
``torch.device("meta")`` and use ``load_state_dict(assign=True)``
to swap loaded tensors into place. Skips the random-init
compute, the fp32 random-init shell, and the post-init cast
pass — the model goes from "shape descriptor" to "loaded
weights" in one shot. Non-persistent buffers (e.g. DeBERTa's
``position_ids``) are re-materialized after the load.
Default ``False`` for now (opt-in); enable for cold-start /
serverless deployments where every 100ms matters.
variant: If set (``"fp16"`` or ``"bf16"``), prefer
``model.{variant}.safetensors`` over the default fp32 file.
Best-effort: the loader probes the Hub (or local path) for the
Expand Down Expand Up @@ -1120,24 +1179,103 @@ def from_pretrained(
# still produces the requested precision after a fallback.
model_file, _ = cls._resolve_model_file(model_dir, variant)

# Create model instance
instance = cls(
config, tokenizer=tokenizer, backbone_from_pretrained=False, cache_dir=cache_dir, **model_kwargs
)

cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings)

if torch_dtype is not None:
# Pre-cast the random-init shell so the model never exists at
# fp32 alongside the loaded state dict. ``.to(floating_dtype)``
# only touches floating-point params/buffers.
instance.model.to(torch_dtype)

# Load state dict (tensors cast to ``torch_dtype`` during read when set)
state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype)
instance.model.load_state_dict(state_dict, strict=strict)
del state_dict
instance.model.to(map_location)
instance = None
if low_cpu_mem_usage:
# Build the model graph on meta device — no real allocation,
# no random-init compute. Shape descriptors only.
with torch.device("meta"):
meta_instance = cls(
config,
tokenizer=tokenizer,
backbone_from_pretrained=False,
cache_dir=cache_dir,
**model_kwargs,
)
cls._resize_token_embeddings(meta_instance, config, tokenizer, resize_token_embeddings)

# Read state dict (cast on read if dtype was set), then swap
# tensors directly into the meta-shell parameter slots.
state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype)
incompat = meta_instance.model.load_state_dict(state_dict, assign=True, strict=strict)
del state_dict

# Materialize non-persistent buffers (position_ids etc.) that
# the state dict didn't carry.
_materialized, unrecognized = cls._materialize_meta_buffers(meta_instance.model)

# Detect parameters left on meta. With strict=True, missing
# keys would have raised at load_state_dict; with strict=False
# (the default) they don't, leaving the parameter on the meta
# device. The standard path keeps random-init values for these
# params, which is what the caller would have seen without
# low_cpu_mem_usage=True. The subsequent .to(map_location)
# would otherwise raise ``NotImplementedError: Cannot copy out
# of meta tensor`` and fail the load entirely.
meta_param_names = [n for n, p in meta_instance.model.named_parameters() if p.is_meta]

if unrecognized or meta_param_names:
# Cases that meta-init can't safely handle: unrecognized
# non-persistent buffers (e.g. RoPE inv_freq whose base
# varies per-architecture), or parameters that the state
# dict didn't supply (strict=False + missing keys). Fall
# back to the standard load path — cost is one full
# standard load; benefit is no silent correctness bug
# and no spurious crash on .to(map_location).
if unrecognized:
short_names = sorted({n.rsplit(".", 1)[-1] for n in unrecognized})
reason = (
f"the model has non-persistent buffer(s) {short_names} that "
f"_materialize_meta_buffers does not recognize "
f"(e.g. RoPE inv_freq for ModernBERT)"
)
else:
# Truncate to keep the warning readable; the missing-key
# set can be large for genuinely incomplete checkpoints.
sample = sorted(meta_param_names)[:5]
more = f" (and {len(meta_param_names) - 5} more)" if len(meta_param_names) > 5 else ""
reason = (
f"the checkpoint is missing parameter(s) {sample}{more}; "
f"the standard load path would have kept the random-init "
f"values for these"
)
warnings.warn(
f"low_cpu_mem_usage=True is not supported for this load: "
f"{reason}. Falling back to the standard load path so "
f"inference is correct. Pass low_cpu_mem_usage=False to "
f"silence this warning.",
UserWarning,
stacklevel=2,
)
del meta_instance
else:
# All buffers recognized and all params materialized —
# meta path succeeded.
del incompat
meta_instance.model.to(map_location)
instance = meta_instance

if instance is None:
# Standard path: random-init shell at fp32, optional cast, load.
# Reached when low_cpu_mem_usage=False, OR when the meta-init
# path detected unrecognized non-persistent buffers and fell
# back automatically.
instance = cls(
config,
tokenizer=tokenizer,
backbone_from_pretrained=False,
cache_dir=cache_dir,
**model_kwargs,
)
cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings)
if torch_dtype is not None:
# Pre-cast the random-init shell so the model never exists at
# fp32 alongside the loaded state dict. ``.to(floating_dtype)``
# only touches floating-point params/buffers.
instance.model.to(torch_dtype)
state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype)
instance.model.load_state_dict(state_dict, strict=strict)
del state_dict
instance.model.to(map_location)

if compile_torch_model:
if "cuda" in map_location:
Expand Down Expand Up @@ -4489,6 +4627,7 @@ def from_pretrained(
compile_torch_model: Optional[bool] = False,
quantize: Optional[str] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
low_cpu_mem_usage: bool = False,
variant: Optional[str] = None,
load_onnx_model: Optional[bool] = False,
onnx_model_file: Optional[str] = "model.onnx",
Expand Down Expand Up @@ -4526,6 +4665,10 @@ def from_pretrained(
are cast during the state-dict read so the fp32 copy is never
fully materialized; prefer this over ``quantize`` for plain
precision changes.
low_cpu_mem_usage: If True, build the model under
``torch.device("meta")`` and use ``load_state_dict(assign=True)``,
skipping the random-init compute and the fp32 shell allocation.
See the base-class docstring for the full contract.
variant: ``"fp16"`` / ``"bf16"`` to prefer
``model.{variant}.safetensors`` over the default fp32 file.
Best-effort with graceful fallback: if the publisher uploaded
Expand Down Expand Up @@ -4636,6 +4779,7 @@ def from_pretrained(
compile_torch_model=compile_torch_model,
quantize=quantize,
dtype=dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=normalized_variant,
max_length=max_length,
max_width=max_width,
Expand Down
Loading
Loading