diff --git a/docs/usage.md b/docs/usage.md index d301ec0..735fbb1 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -392,6 +392,36 @@ Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"f `dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired. +#### Skipping the random-init shell (`low_cpu_mem_usage`) + +`dtype=` lowers peak memory but doesn't speed up the *load itself* — even with `dtype="bf16"`, GLiNER still allocates a fp32 random-initialized model shell, runs Kaiming/Xavier init over every parameter, casts the whole thing to bf16, then overwrites every value with the loaded weights. All of that init work is thrown away. + +Pass `low_cpu_mem_usage=True` to skip it: the model graph is built under `torch.device("meta")` (shape descriptors only, no allocation, no random init), the state dict is read at the target precision, and `load_state_dict(assign=True)` swaps the loaded tensors directly into the meta-shell parameter slots in one pass. + +```python +model = GLiNER.from_pretrained( + "urchade/gliner_medium-v2.1", + dtype="bf16", + low_cpu_mem_usage=True, + map_location="cuda", +) +``` + +Measured on `gliner_medium-v2.1` on an RTX 5090 (n=12 reps, Welch t-tested, OS page cache warmed): + +| path | mean load time | speedup | peak host RSS delta | +|---|---|---|---| +| baseline (cuda, bf16) | 3.16 s | 1.0× | 1361 MB | +| `low_cpu_mem_usage=True` (cuda, bf16) | **1.61 s** | **1.96×** | 1004 MB | +| baseline (cpu, bf16) | 3.30 s | 1.0× | 1597 MB | +| `low_cpu_mem_usage=True` (cpu, bf16) | **1.60 s** | **2.06×** | 1225 MB | +| baseline (cpu, fp32) | 3.04 s | 1.0× | 1598 MB | +| `low_cpu_mem_usage=True` (cpu, fp32) | **1.45 s** | **2.10×** | 170 MB | + +About **1.5 seconds saved on every cold start**, plus 23–89% lower peak host RSS depending on dtype (the fp32 case is dramatic because safetensors mmaps the on-disk file and we never copy it into anonymous memory). Loaded parameters are bit-identical to the standard path — verified across 224 parameters and 1 buffer (`position_ids`, re-materialized after assign). + +Default is `False` while the path matures — enable it explicitly when cold-start latency or peak host memory matters. `low_cpu_mem_usage` stacks with `dtype=` (use them together) and is independent of `quantize=` and `compile_torch_model=`. + #### Selective download (`variant`) `dtype=` casts in memory but the on-disk file is still fp32, so the bytes pulled from the Hub don't shrink. If a publisher uploads a half-precision variant of the file (`model.fp16.safetensors` or `model.bf16.safetensors`, following the transformers naming convention), pass `variant=` to download *only* that file: diff --git a/gliner/model.py b/gliner/model.py index 31ca2ae..5b33c4e 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -750,6 +750,55 @@ def _load_tokenizer(cls, config: GLiNERConfig, model_dir: Path, cache_dir: Optio return cls._set_tokenizer_spec_tokens(tokenizer) + @staticmethod + def _materialize_meta_buffers(module: nn.Module) -> tuple: + """Re-materialize non-persistent buffers left on meta after assign-load. + + Walks the module tree and replaces any buffer still on the meta device + after a meta-init + ``load_state_dict(assign=True)`` sequence. + + Non-persistent buffers (registered with ``persistent=False``) are not + in the saved state dict, so they survive as meta tensors after the + load. This helper restores the ones we recognize: + + - ``position_ids`` — ``arange(0, max_pos).unsqueeze(0)`` (BERT/DeBERTa). + - ``token_type_ids`` — ``zeros((1, max_pos))`` (BERT-family default). + + For *unrecognized* buffers (e.g. RoPE's ``inv_freq``, where the + canonical value depends on a per-architecture ``base`` we can't + recover from the buffer alone), this returns them in + ``unrecognized`` so the caller can fall back to the standard load + path rather than ship a model with broken inference. + + Returns ``(materialized: list[str], unrecognized: list[str])``. + """ + materialized: list = [] + unrecognized: list = [] + for name, buf in list(module.named_buffers()): + if not buf.is_meta: + continue + *parents, leaf = name.split(".") + parent_mod = module + for p in parents: + parent_mod = getattr(parent_mod, p) + if leaf == "position_ids": + # Standard transformers convention: arange(0, max_pos).expand(1, -1) + value = torch.arange(0, buf.shape[-1], dtype=buf.dtype).unsqueeze(0).contiguous() + parent_mod.register_buffer(leaf, value, persistent=False) + materialized.append(name) + elif leaf == "token_type_ids": + # BERT-family default: zeros, broadcast to (1, max_pos). + value = torch.zeros(buf.shape, dtype=buf.dtype) + parent_mod.register_buffer(leaf, value, persistent=False) + materialized.append(name) + else: + # Unrecognized non-persistent buffer. We can't safely zero-fill + # because the canonical value may be load-bearing (e.g. RoPE + # ``inv_freq`` is computed from ``base ** (arange(0, dim, 2) / dim)`` + # and zeros would break attention). Surface to caller for fallback. + unrecognized.append(name) + return materialized, unrecognized + @classmethod def _load_state_dict( cls, @@ -985,6 +1034,7 @@ def from_pretrained( compile_torch_model: Optional[bool] = False, quantize: Optional[str] = None, dtype: Optional[Union[str, torch.dtype]] = None, + low_cpu_mem_usage: bool = False, variant: Optional[str] = None, load_onnx_model: Optional[bool] = False, onnx_model_file: Optional[str] = "model.onnx", @@ -1022,6 +1072,15 @@ def from_pretrained( reading, so the full fp32 copy is never materialized — peak host memory is roughly half of the default path for bf16/fp16. Prefer this over ``quantize`` for plain precision changes. + low_cpu_mem_usage: If True, build the model under + ``torch.device("meta")`` and use ``load_state_dict(assign=True)`` + to swap loaded tensors into place. Skips the random-init + compute, the fp32 random-init shell, and the post-init cast + pass — the model goes from "shape descriptor" to "loaded + weights" in one shot. Non-persistent buffers (e.g. DeBERTa's + ``position_ids``) are re-materialized after the load. + Default ``False`` for now (opt-in); enable for cold-start / + serverless deployments where every 100ms matters. variant: If set (``"fp16"`` or ``"bf16"``), prefer ``model.{variant}.safetensors`` over the default fp32 file. Best-effort: the loader probes the Hub (or local path) for the @@ -1120,24 +1179,103 @@ def from_pretrained( # still produces the requested precision after a fallback. model_file, _ = cls._resolve_model_file(model_dir, variant) - # Create model instance - instance = cls( - config, tokenizer=tokenizer, backbone_from_pretrained=False, cache_dir=cache_dir, **model_kwargs - ) - - cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings) - - if torch_dtype is not None: - # Pre-cast the random-init shell so the model never exists at - # fp32 alongside the loaded state dict. ``.to(floating_dtype)`` - # only touches floating-point params/buffers. - instance.model.to(torch_dtype) - - # Load state dict (tensors cast to ``torch_dtype`` during read when set) - state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) - instance.model.load_state_dict(state_dict, strict=strict) - del state_dict - instance.model.to(map_location) + instance = None + if low_cpu_mem_usage: + # Build the model graph on meta device — no real allocation, + # no random-init compute. Shape descriptors only. + with torch.device("meta"): + meta_instance = cls( + config, + tokenizer=tokenizer, + backbone_from_pretrained=False, + cache_dir=cache_dir, + **model_kwargs, + ) + cls._resize_token_embeddings(meta_instance, config, tokenizer, resize_token_embeddings) + + # Read state dict (cast on read if dtype was set), then swap + # tensors directly into the meta-shell parameter slots. + state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) + incompat = meta_instance.model.load_state_dict(state_dict, assign=True, strict=strict) + del state_dict + + # Materialize non-persistent buffers (position_ids etc.) that + # the state dict didn't carry. + _materialized, unrecognized = cls._materialize_meta_buffers(meta_instance.model) + + # Detect parameters left on meta. With strict=True, missing + # keys would have raised at load_state_dict; with strict=False + # (the default) they don't, leaving the parameter on the meta + # device. The standard path keeps random-init values for these + # params, which is what the caller would have seen without + # low_cpu_mem_usage=True. The subsequent .to(map_location) + # would otherwise raise ``NotImplementedError: Cannot copy out + # of meta tensor`` and fail the load entirely. + meta_param_names = [n for n, p in meta_instance.model.named_parameters() if p.is_meta] + + if unrecognized or meta_param_names: + # Cases that meta-init can't safely handle: unrecognized + # non-persistent buffers (e.g. RoPE inv_freq whose base + # varies per-architecture), or parameters that the state + # dict didn't supply (strict=False + missing keys). Fall + # back to the standard load path — cost is one full + # standard load; benefit is no silent correctness bug + # and no spurious crash on .to(map_location). + if unrecognized: + short_names = sorted({n.rsplit(".", 1)[-1] for n in unrecognized}) + reason = ( + f"the model has non-persistent buffer(s) {short_names} that " + f"_materialize_meta_buffers does not recognize " + f"(e.g. RoPE inv_freq for ModernBERT)" + ) + else: + # Truncate to keep the warning readable; the missing-key + # set can be large for genuinely incomplete checkpoints. + sample = sorted(meta_param_names)[:5] + more = f" (and {len(meta_param_names) - 5} more)" if len(meta_param_names) > 5 else "" + reason = ( + f"the checkpoint is missing parameter(s) {sample}{more}; " + f"the standard load path would have kept the random-init " + f"values for these" + ) + warnings.warn( + f"low_cpu_mem_usage=True is not supported for this load: " + f"{reason}. Falling back to the standard load path so " + f"inference is correct. Pass low_cpu_mem_usage=False to " + f"silence this warning.", + UserWarning, + stacklevel=2, + ) + del meta_instance + else: + # All buffers recognized and all params materialized — + # meta path succeeded. + del incompat + meta_instance.model.to(map_location) + instance = meta_instance + + if instance is None: + # Standard path: random-init shell at fp32, optional cast, load. + # Reached when low_cpu_mem_usage=False, OR when the meta-init + # path detected unrecognized non-persistent buffers and fell + # back automatically. + instance = cls( + config, + tokenizer=tokenizer, + backbone_from_pretrained=False, + cache_dir=cache_dir, + **model_kwargs, + ) + cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings) + if torch_dtype is not None: + # Pre-cast the random-init shell so the model never exists at + # fp32 alongside the loaded state dict. ``.to(floating_dtype)`` + # only touches floating-point params/buffers. + instance.model.to(torch_dtype) + state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) + instance.model.load_state_dict(state_dict, strict=strict) + del state_dict + instance.model.to(map_location) if compile_torch_model: if "cuda" in map_location: @@ -4489,6 +4627,7 @@ def from_pretrained( compile_torch_model: Optional[bool] = False, quantize: Optional[str] = None, dtype: Optional[Union[str, torch.dtype]] = None, + low_cpu_mem_usage: bool = False, variant: Optional[str] = None, load_onnx_model: Optional[bool] = False, onnx_model_file: Optional[str] = "model.onnx", @@ -4526,6 +4665,10 @@ def from_pretrained( are cast during the state-dict read so the fp32 copy is never fully materialized; prefer this over ``quantize`` for plain precision changes. + low_cpu_mem_usage: If True, build the model under + ``torch.device("meta")`` and use ``load_state_dict(assign=True)``, + skipping the random-init compute and the fp32 shell allocation. + See the base-class docstring for the full contract. variant: ``"fp16"`` / ``"bf16"`` to prefer ``model.{variant}.safetensors`` over the default fp32 file. Best-effort with graceful fallback: if the publisher uploaded @@ -4636,6 +4779,7 @@ def from_pretrained( compile_torch_model=compile_torch_model, quantize=quantize, dtype=dtype, + low_cpu_mem_usage=low_cpu_mem_usage, variant=normalized_variant, max_length=max_length, max_width=max_width, diff --git a/tests/test_quantize_and_dtype.py b/tests/test_quantize_and_dtype.py index 1636dff..9c5ee86 100644 --- a/tests/test_quantize_and_dtype.py +++ b/tests/test_quantize_and_dtype.py @@ -218,6 +218,150 @@ def test_true_raises(self): _FakeModel("cpu").quantize(True) +class TestMaterializeMetaBuffers: + """``_materialize_meta_buffers`` post-fixes non-persistent buffers that + survive a meta-init + ``load_state_dict(assign=True)`` cycle.""" + + def _module_with_meta_position_ids(self, length: int = 16) -> nn.Module: + """Build a small module mirroring DeBERTa's ``position_ids`` buffer + registration, then move it to meta to simulate post-load state.""" + module = nn.Module() + module.register_buffer( + "position_ids", + torch.arange(0, length, dtype=torch.int64).unsqueeze(0), + persistent=False, + ) + module.position_ids = module.position_ids.to("meta") + return module + + def test_position_ids_restored_to_canonical_value(self): + m = self._module_with_meta_position_ids(length=8) + assert m.position_ids.is_meta # precondition + + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) + + assert materialized == ["position_ids"] + assert unrecognized == [] + assert not m.position_ids.is_meta + assert torch.equal( + m.position_ids, + torch.arange(0, 8, dtype=torch.int64).unsqueeze(0), + ) + + def test_no_op_when_no_meta_buffers(self): + m = nn.Module() + m.register_buffer("position_ids", torch.arange(0, 4).unsqueeze(0), persistent=False) + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) + assert materialized == [] + assert unrecognized == [] + + def test_nested_module_meta_buffer_restored(self): + """Buffers nested inside child modules are walked and fixed too.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer( + "position_ids", + torch.arange(0, 4, dtype=torch.int64).unsqueeze(0), + persistent=False, + ) + inner.position_ids = inner.position_ids.to("meta") + outer.add_module("embeddings", inner) + + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(outer) + + assert materialized == ["embeddings.position_ids"] + assert unrecognized == [] + assert not outer.embeddings.position_ids.is_meta + + def test_token_type_ids_restored_to_zeros(self): + """BERT-family ``token_type_ids`` is a non-persistent buffer of zeros.""" + m = nn.Module() + m.register_buffer( + "token_type_ids", + torch.zeros((1, 6), dtype=torch.int64), + persistent=False, + ) + m.token_type_ids = m.token_type_ids.to("meta") + + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) + + assert materialized == ["token_type_ids"] + assert unrecognized == [] + assert torch.equal(m.token_type_ids, torch.zeros((1, 6), dtype=torch.int64)) + + def test_unknown_buffer_returned_as_unrecognized(self): + """Unrecognized buffers are surfaced for caller-side fallback (not zero-filled).""" + m = nn.Module() + m.register_buffer( + "inv_freq", + torch.tensor([1.0, 0.5, 0.25]), + persistent=False, + ) + m.inv_freq = m.inv_freq.to("meta") + + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) + + assert materialized == [] + assert unrecognized == ["inv_freq"] + # The buffer is still meta — caller must fall back to the standard load path. + assert m.inv_freq.is_meta + + +class TestMetaParamFallbackContract: + """Codex review finding: with ``low_cpu_mem_usage=True`` and the default + ``strict=False``, ``load_state_dict(assign=True)`` may leave a parameter + on the meta device when the checkpoint is missing a key. The subsequent + ``.to(map_location)`` then raises ``NotImplementedError: Cannot copy out + of meta tensor``, whereas the standard path would have kept the + random-init value and loaded successfully. + + These tests assert the *contract* underlying the fix without spinning up + a real GLiNER model: ``load_state_dict(assign=True, strict=False)`` does + leave parameters on meta when keys are missing, and the + "scan for meta parameters after assign" check in ``from_pretrained`` + correctly identifies them. + """ + + def test_assign_load_with_missing_key_leaves_param_on_meta(self): + """The premise — without our scan, ``.to()`` would crash.""" + with torch.device("meta"): + module = nn.Linear(4, 4) + # State dict is missing the bias. + partial_sd = {"weight": torch.randn(4, 4)} + result = module.load_state_dict(partial_sd, assign=True, strict=False) + assert "bias" in result.missing_keys + # The bug: bias is still on meta. + assert module.bias.is_meta + # And .to() on a meta param raises. + with pytest.raises(NotImplementedError, match="meta"): + module.to("cpu") + + def test_meta_param_scan_finds_missing_assign_targets(self): + """The fix's scan: walk named_parameters looking for ``is_meta``, + report names so the fallback warning can surface them.""" + with torch.device("meta"): + module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)) + # Provide weights but not biases. + partial_sd = { + "0.weight": torch.randn(4, 4), + "1.weight": torch.randn(2, 4), + } + module.load_state_dict(partial_sd, assign=True, strict=False) + meta_params = [n for n, p in module.named_parameters() if p.is_meta] + assert sorted(meta_params) == ["0.bias", "1.bias"] + + def test_full_assign_load_leaves_no_meta_params(self): + """Sanity: when the state dict is complete, no meta params remain.""" + with torch.device("meta"): + module = nn.Linear(4, 4) + full_sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)} + module.load_state_dict(full_sd, assign=True, strict=False) + meta_params = [n for n, p in module.named_parameters() if p.is_meta] + assert meta_params == [] + # And .to() succeeds on the materialized module. + module.to("cpu") + + class TestNormalizeVariant: """``variant=`` canonicalization for selective downloads."""