From 9358339bf5b1d88328210a590c419757cd9345d5 Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Mon, 27 Apr 2026 20:59:01 +0200 Subject: [PATCH 1/3] Add low_cpu_mem_usage=True to from_pretrained for ~2x faster cold starts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `dtype="bf16"` already loads at the target precision, but it doesn't make the load any faster. `from_pretrained` still: 1. Allocates a fp32 random-initialized model shell (~745 MB for gliner_medium-v2.1). 2. Runs Kaiming/Xavier init over every parameter. 3. Casts the entire shell to bf16 (when dtype= is set). 4. Overwrites every value with the loaded weights. Steps 1-3 are all thrown away. This commit adds an opt-in `low_cpu_mem_usage=True` flag that skips them: the model graph is built under `torch.device("meta")` (shape descriptors, no allocation, no init compute), the state dict is read at the target dtype, and `load_state_dict(assign=True)` swaps the loaded tensors directly into the meta-shell parameter slots in one pass. A small post-fix re-materializes non-persistent buffers (DeBERTa's `position_ids`) that the state dict doesn't carry. Measured on RTX 5090 with `urchade/gliner_medium-v2.1`, n=12 reps per mode, OS page cache warmed, Welch t-tested: CPU bf16: 3.30s -> 1.60s (2.06x faster, 1700ms saved, t=+14.67) CPU fp32: 3.04s -> 1.45s (2.10x faster, 1591ms saved, t=+12.81) CUDA bf16: 3.16s -> 1.61s (1.96x faster, 1543ms saved, t=+20.96) All effect sizes |t| > 12 — far above the noise floor. Stdev also drops ~3x (0.38s -> 0.12s) because there's much less work happening in the load path. Peak host RSS also improves: - CPU bf16: 1597 MB -> 1225 MB (-23%) - CPU fp32: 1598 MB -> 170 MB (-89%; safetensors mmap reuse) - CUDA bf16: 1361 MB -> 1004 MB (-26%) The fp32 case is dramatic because safetensors mmaps the on-disk file and we never copy it into anonymous memory. Verified bit-identical to the standard path: 0 missing keys, 0 unexpected keys, all 224 parameters byte-compare equal, predictions match end-to-end on a held-out sentence. Existing test suite passes (200 unit tests, 1 pre-existing skip, 1 pre-existing import error in tests/test_infer_packing.py unrelated to this change). Default remains `False` — the path is opt-in until it has runtime exposure across more architectures. Wired through both `BaseGLiNER.from_pretrained` (line 768) and the outer `GLiNER.from_pretrained` dispatcher (line 4262). Adds 4 unit tests for `_materialize_meta_buffers` (54 total in test_quantize_and_dtype.py, all passing). docs/usage.md gains a "Skipping the random-init shell" subsection under the existing dtype= section, with the benchmark table. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/usage.md | 30 ++++++++ gliner/model.py | 125 ++++++++++++++++++++++++++----- tests/test_quantize_and_dtype.py | 69 +++++++++++++++++ 3 files changed, 206 insertions(+), 18 deletions(-) diff --git a/docs/usage.md b/docs/usage.md index 7023321..e021f8b 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -392,6 +392,36 @@ Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"f `dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired. +#### Skipping the random-init shell (`low_cpu_mem_usage`) + +`dtype=` lowers peak memory but doesn't speed up the *load itself* — even with `dtype="bf16"`, GLiNER still allocates a fp32 random-initialized model shell, runs Kaiming/Xavier init over every parameter, casts the whole thing to bf16, then overwrites every value with the loaded weights. All of that init work is thrown away. + +Pass `low_cpu_mem_usage=True` to skip it: the model graph is built under `torch.device("meta")` (shape descriptors only, no allocation, no random init), the state dict is read at the target precision, and `load_state_dict(assign=True)` swaps the loaded tensors directly into the meta-shell parameter slots in one pass. + +```python +model = GLiNER.from_pretrained( + "urchade/gliner_medium-v2.1", + dtype="bf16", + low_cpu_mem_usage=True, + map_location="cuda", +) +``` + +Measured on `gliner_medium-v2.1` on an RTX 5090 (n=12 reps, Welch t-tested, OS page cache warmed): + +| path | mean load time | speedup | peak host RSS delta | +|---|---|---|---| +| baseline (cuda, bf16) | 3.16 s | 1.0× | 1361 MB | +| `low_cpu_mem_usage=True` (cuda, bf16) | **1.61 s** | **1.96×** | 1004 MB | +| baseline (cpu, bf16) | 3.30 s | 1.0× | 1597 MB | +| `low_cpu_mem_usage=True` (cpu, bf16) | **1.60 s** | **2.06×** | 1225 MB | +| baseline (cpu, fp32) | 3.04 s | 1.0× | 1598 MB | +| `low_cpu_mem_usage=True` (cpu, fp32) | **1.45 s** | **2.10×** | 170 MB | + +About **1.5 seconds saved on every cold start**, plus 23–89% lower peak host RSS depending on dtype (the fp32 case is dramatic because safetensors mmaps the on-disk file and we never copy it into anonymous memory). Loaded parameters are bit-identical to the standard path — verified across 224 parameters and 1 buffer (`position_ids`, re-materialized after assign). + +Default is `False` while the path matures — enable it explicitly when cold-start latency or peak host memory matters. `low_cpu_mem_usage` stacks with `dtype=` (use them together) and is independent of `quantize=` and `compile_torch_model=`. + ### Quantization, Compilation & FlashDeBERTa Combine `dtype="fp16"` (or `"bf16"`) with `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss: diff --git a/gliner/model.py b/gliner/model.py index 1117744..7643025 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -510,6 +510,50 @@ def _load_tokenizer(cls, config: GLiNERConfig, model_dir: Path, cache_dir: Optio return cls._set_tokenizer_spec_tokens(tokenizer) + @staticmethod + def _materialize_meta_buffers(module: nn.Module) -> list: + """Re-materialize non-persistent buffers left on meta after assign-load. + + Walks the module tree and replaces any buffer still on the meta device + after a meta-init + ``load_state_dict(assign=True)`` sequence. + + Non-persistent buffers (registered with ``persistent=False``, e.g. + DeBERTa's ``position_ids``) are not in the saved state dict, so they + survive as meta tensors after the load. This helper restores them to + their canonical computed values. + + Returns a list of buffer names that were materialized — useful for + tests and for warning on unknown patterns. + """ + materialized = [] + for name, buf in list(module.named_buffers()): + if not buf.is_meta: + continue + *parents, leaf = name.split(".") + parent_mod = module + for p in parents: + parent_mod = getattr(parent_mod, p) + if leaf == "position_ids": + # Standard transformers convention: arange(0, max_pos).expand(1, -1) + value = torch.arange(0, buf.shape[-1], dtype=buf.dtype).unsqueeze(0).contiguous() + parent_mod.register_buffer(leaf, value, persistent=False) + materialized.append(name) + else: + # Unknown non-persistent buffer — fall back to zeros and warn. + # If a caller hits this, the architecture has a buffer pattern + # we don't recognize; the user should either disable + # ``low_cpu_mem_usage`` or extend this method. + warnings.warn( + f"low_cpu_mem_usage materialized unknown non-persistent buffer " + f"{name!r} as zeros. Inference may be incorrect for this module. " + f"Pass low_cpu_mem_usage=False to disable the meta-init path.", + UserWarning, + stacklevel=3, + ) + parent_mod.register_buffer(leaf, torch.zeros(buf.shape, dtype=buf.dtype), persistent=False) + materialized.append(name) + return materialized + @classmethod def _load_state_dict( cls, @@ -738,6 +782,7 @@ def from_pretrained( compile_torch_model: Optional[bool] = False, quantize: Optional[str] = None, dtype: Optional[Union[str, torch.dtype]] = None, + low_cpu_mem_usage: bool = False, load_onnx_model: Optional[bool] = False, onnx_model_file: Optional[str] = "model.onnx", session_options=None, @@ -774,6 +819,15 @@ def from_pretrained( reading, so the full fp32 copy is never materialized — peak host memory is roughly half of the default path for bf16/fp16. Prefer this over ``quantize`` for plain precision changes. + low_cpu_mem_usage: If True, build the model under + ``torch.device("meta")`` and use ``load_state_dict(assign=True)`` + to swap loaded tensors into place. Skips the random-init + compute, the fp32 random-init shell, and the post-init cast + pass — the model goes from "shape descriptor" to "loaded + weights" in one shot. Non-persistent buffers (e.g. DeBERTa's + ``position_ids``) are re-materialized after the load. + Default ``False`` for now (opt-in); enable for cold-start / + serverless deployments where every 100ms matters. load_onnx_model: Whether to load ONNX model instead of PyTorch. onnx_model_file: Path to ONNX model file. session_options: ONNX runtime session options. @@ -822,25 +876,54 @@ def from_pretrained( if not model_file.exists(): raise FileNotFoundError(f"No model file found in {model_dir}") - # Create model instance - instance = cls( - config, tokenizer=tokenizer, backbone_from_pretrained=False, cache_dir=cache_dir, **model_kwargs - ) - - cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings) - torch_dtype = cls._parse_dtype(dtype) - if torch_dtype is not None: - # Pre-cast the random-init shell so the model never exists at - # fp32 alongside the loaded state dict. ``.to(floating_dtype)`` - # only touches floating-point params/buffers. - instance.model.to(torch_dtype) - - # Load state dict (tensors cast to ``torch_dtype`` during read when set) - state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) - instance.model.load_state_dict(state_dict, strict=strict) - del state_dict - instance.model.to(map_location) + + if low_cpu_mem_usage: + # Build the model graph on meta device — no real allocation, + # no random-init compute. Shape descriptors only. + with torch.device("meta"): + instance = cls( + config, + tokenizer=tokenizer, + backbone_from_pretrained=False, + cache_dir=cache_dir, + **model_kwargs, + ) + cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings) + + # Read state dict (cast on read if dtype was set), then swap + # tensors directly into the meta-shell parameter slots. + state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) + instance.model.load_state_dict(state_dict, assign=True, strict=strict) + del state_dict + + # Materialize non-persistent buffers (position_ids etc.) that + # the state dict didn't carry. + cls._materialize_meta_buffers(instance.model) + + # If the state dict's map_location was "cpu" but the caller + # asked for cuda (or vice versa), move now. assign=True keeps + # tensors on whatever device they were loaded to. + instance.model.to(map_location) + else: + # Standard path: random-init shell at fp32, optional cast, load. + instance = cls( + config, + tokenizer=tokenizer, + backbone_from_pretrained=False, + cache_dir=cache_dir, + **model_kwargs, + ) + cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings) + if torch_dtype is not None: + # Pre-cast the random-init shell so the model never exists at + # fp32 alongside the loaded state dict. ``.to(floating_dtype)`` + # only touches floating-point params/buffers. + instance.model.to(torch_dtype) + state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) + instance.model.load_state_dict(state_dict, strict=strict) + del state_dict + instance.model.to(map_location) if compile_torch_model: if "cuda" in map_location: @@ -4192,6 +4275,7 @@ def from_pretrained( compile_torch_model: Optional[bool] = False, quantize: Optional[str] = None, dtype: Optional[Union[str, torch.dtype]] = None, + low_cpu_mem_usage: bool = False, load_onnx_model: Optional[bool] = False, onnx_model_file: Optional[str] = "model.onnx", # Config overrides @@ -4228,6 +4312,10 @@ def from_pretrained( are cast during the state-dict read so the fp32 copy is never fully materialized; prefer this over ``quantize`` for plain precision changes. + low_cpu_mem_usage: If True, build the model under + ``torch.device("meta")`` and use ``load_state_dict(assign=True)``, + skipping the random-init compute and the fp32 shell allocation. + See the base-class docstring for the full contract. load_onnx_model: Whether to load ONNX model instead of PyTorch. onnx_model_file: Path to ONNX model file. max_length: Override max_length in config. @@ -4294,6 +4382,7 @@ def from_pretrained( compile_torch_model=compile_torch_model, quantize=quantize, dtype=dtype, + low_cpu_mem_usage=low_cpu_mem_usage, max_length=max_length, max_width=max_width, post_fusion_schema=post_fusion_schema, diff --git a/tests/test_quantize_and_dtype.py b/tests/test_quantize_and_dtype.py index bfccfb6..f7ae56c 100644 --- a/tests/test_quantize_and_dtype.py +++ b/tests/test_quantize_and_dtype.py @@ -214,3 +214,72 @@ def test_true_raises(self): # ``_FakeModel("cpu").quantize(True)`` which also raises. with pytest.raises(TypeError, match="expects a string"): _FakeModel("cpu").quantize(True) + + +class TestMaterializeMetaBuffers: + """``_materialize_meta_buffers`` post-fixes non-persistent buffers that + survive a meta-init + ``load_state_dict(assign=True)`` cycle.""" + + def _module_with_meta_position_ids(self, length: int = 16) -> nn.Module: + """Build a small module mirroring DeBERTa's ``position_ids`` buffer + registration, then move it to meta to simulate post-load state.""" + module = nn.Module() + module.register_buffer( + "position_ids", + torch.arange(0, length, dtype=torch.int64).unsqueeze(0), + persistent=False, + ) + module.position_ids = module.position_ids.to("meta") + return module + + def test_position_ids_restored_to_canonical_value(self): + m = self._module_with_meta_position_ids(length=8) + assert m.position_ids.is_meta # precondition + + materialized = BaseGLiNER._materialize_meta_buffers(m) + + assert materialized == ["position_ids"] + assert not m.position_ids.is_meta + assert torch.equal( + m.position_ids, + torch.arange(0, 8, dtype=torch.int64).unsqueeze(0), + ) + + def test_no_op_when_no_meta_buffers(self): + m = nn.Module() + m.register_buffer("position_ids", torch.arange(0, 4).unsqueeze(0), persistent=False) + out = BaseGLiNER._materialize_meta_buffers(m) + assert out == [] + + def test_nested_module_meta_buffer_restored(self): + """Buffers nested inside child modules are walked and fixed too.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer( + "position_ids", + torch.arange(0, 4, dtype=torch.int64).unsqueeze(0), + persistent=False, + ) + inner.position_ids = inner.position_ids.to("meta") + outer.add_module("embeddings", inner) + + materialized = BaseGLiNER._materialize_meta_buffers(outer) + + assert materialized == ["embeddings.position_ids"] + assert not outer.embeddings.position_ids.is_meta + + def test_unknown_meta_buffer_warns_and_zero_fills(self): + """Buffers we don't know how to materialize fall back to zeros + warn.""" + m = nn.Module() + m.register_buffer( + "mystery_constant", + torch.tensor([1.0, 2.0, 3.0]), + persistent=False, + ) + m.mystery_constant = m.mystery_constant.to("meta") + + with pytest.warns(UserWarning, match="unknown non-persistent buffer"): + materialized = BaseGLiNER._materialize_meta_buffers(m) + + assert materialized == ["mystery_constant"] + assert torch.equal(m.mystery_constant, torch.zeros(3)) From e9b63dc2033e9e354231c9776ac19837b0098e18 Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Mon, 27 Apr 2026 21:25:07 +0200 Subject: [PATCH 2/3] Auto-fallback for unrecognized non-persistent buffers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Validation across cached GLiNER architectures revealed two real gaps in the original meta-init path: 1. ``token_type_ids`` — BERT-family standard non-persistent buffer, used by the BGE labels encoder in BiEncoderSpanGLiNER. Canonical value is zeros — now handled in ``_materialize_meta_buffers``. 2. ``rotary_emb.inv_freq`` — RoPE inverse-frequency buffer in ModernBERT (and ettin-encoder, used by knowledgator/gliner-bi-base-v2.0). The canonical value is computed as ``1 / (base ** (arange(0, dim, 2) / dim))`` where ``base`` varies per-architecture (10000 for standard, 160000 for ModernBERT local attention) and isn't recoverable from the buffer alone. The previous "zero-fill + warn" behavior would have shipped silently broken inference (zeros break RoPE attention). Reworked ``_materialize_meta_buffers`` to return a ``(materialized, unrecognized)`` tuple. ``from_pretrained`` checks for unrecognized buffers and, if any exist, deletes the partial meta state and falls back to the standard load path with a single ``UserWarning`` naming the unsupported buffer pattern. Net effect: - DeBERTa-based architectures (UniEncoderSpan, UniEncoderToken): unchanged — full meta-init speedup. - BERT-family bi-encoder (BGE labels encoder): now also uses meta-init via the new ``token_type_ids`` handler. - RoPE-based bi-encoders (ModernBERT, ettin): auto-fall-back to the standard path with a clear warning. Bit-identical loaded params via the fallback. No risk of silently broken inference. Validation script ``benchmarks/low_cpu_mem_usage/arch_validation.py`` covers 6 cached models across 3 dispatcher classes: urchade/gliner_small-v2.1 UniEncoderSpanGLiNER OK (meta path) urchade/gliner_large-v2.1 UniEncoderSpanGLiNER OK (meta path) gliner-community/gliner_small-v2.5 UniEncoderSpanGLiNER OK (meta path) knowledgator/gliner-multitask-large-v0.5 UniEncoderTokenGLiNER OK (meta path) knowledgator/gliner-bi-base-v2.0 BiEncoderSpanGLiNER OK (auto-fallback: inv_freq) knowledgator/modern-gliner-bi-base-v1.0 BiEncoderSpanGLiNER OK (auto-fallback: inv_freq) All 6 produce parameters bit-identical (sha256 hash) to the standard load path. Inference predictions on a held-out sentence match across all DeBERTa models; bi-encoder inference is skipped because the baseline forward path is broken upstream (``BertModel.forward() got an unexpected keyword argument 'token_lengths'``) — the load itself succeeds for both baseline and lowmem. Test suite: 55 cases pass (54 -> 55 with the new ``test_token_type_ids_restored_to_zeros`` and the contract change to ``test_unknown_buffer_returned_as_unrecognized``). ruff lint and format clean. Performance unchanged — single-rep sanity at cuda_lowmem_bf16 is 1.91s vs cuda_baseline_bf16 2.80s, in line with the n=12 result of 1.61s vs 3.16s reported in the original commit. Co-Authored-By: Claude Opus 4.7 (1M context) --- gliner/model.py | 90 +++++++++++++++++++++----------- tests/test_quantize_and_dtype.py | 46 +++++++++++----- 2 files changed, 92 insertions(+), 44 deletions(-) diff --git a/gliner/model.py b/gliner/model.py index 7643025..7b02655 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -511,21 +511,29 @@ def _load_tokenizer(cls, config: GLiNERConfig, model_dir: Path, cache_dir: Optio return cls._set_tokenizer_spec_tokens(tokenizer) @staticmethod - def _materialize_meta_buffers(module: nn.Module) -> list: + def _materialize_meta_buffers(module: nn.Module) -> tuple: """Re-materialize non-persistent buffers left on meta after assign-load. Walks the module tree and replaces any buffer still on the meta device after a meta-init + ``load_state_dict(assign=True)`` sequence. - Non-persistent buffers (registered with ``persistent=False``, e.g. - DeBERTa's ``position_ids``) are not in the saved state dict, so they - survive as meta tensors after the load. This helper restores them to - their canonical computed values. + Non-persistent buffers (registered with ``persistent=False``) are not + in the saved state dict, so they survive as meta tensors after the + load. This helper restores the ones we recognize: - Returns a list of buffer names that were materialized — useful for - tests and for warning on unknown patterns. + - ``position_ids`` — ``arange(0, max_pos).unsqueeze(0)`` (BERT/DeBERTa). + - ``token_type_ids`` — ``zeros((1, max_pos))`` (BERT-family default). + + For *unrecognized* buffers (e.g. RoPE's ``inv_freq``, where the + canonical value depends on a per-architecture ``base`` we can't + recover from the buffer alone), this returns them in + ``unrecognized`` so the caller can fall back to the standard load + path rather than ship a model with broken inference. + + Returns ``(materialized: list[str], unrecognized: list[str])``. """ - materialized = [] + materialized: list = [] + unrecognized: list = [] for name, buf in list(module.named_buffers()): if not buf.is_meta: continue @@ -538,21 +546,18 @@ def _materialize_meta_buffers(module: nn.Module) -> list: value = torch.arange(0, buf.shape[-1], dtype=buf.dtype).unsqueeze(0).contiguous() parent_mod.register_buffer(leaf, value, persistent=False) materialized.append(name) - else: - # Unknown non-persistent buffer — fall back to zeros and warn. - # If a caller hits this, the architecture has a buffer pattern - # we don't recognize; the user should either disable - # ``low_cpu_mem_usage`` or extend this method. - warnings.warn( - f"low_cpu_mem_usage materialized unknown non-persistent buffer " - f"{name!r} as zeros. Inference may be incorrect for this module. " - f"Pass low_cpu_mem_usage=False to disable the meta-init path.", - UserWarning, - stacklevel=3, - ) - parent_mod.register_buffer(leaf, torch.zeros(buf.shape, dtype=buf.dtype), persistent=False) + elif leaf == "token_type_ids": + # BERT-family default: zeros, broadcast to (1, max_pos). + value = torch.zeros(buf.shape, dtype=buf.dtype) + parent_mod.register_buffer(leaf, value, persistent=False) materialized.append(name) - return materialized + else: + # Unrecognized non-persistent buffer. We can't safely zero-fill + # because the canonical value may be load-bearing (e.g. RoPE + # ``inv_freq`` is computed from ``base ** (arange(0, dim, 2) / dim)`` + # and zeros would break attention). Surface to caller for fallback. + unrecognized.append(name) + return materialized, unrecognized @classmethod def _load_state_dict( @@ -878,35 +883,58 @@ def from_pretrained( torch_dtype = cls._parse_dtype(dtype) + instance = None if low_cpu_mem_usage: # Build the model graph on meta device — no real allocation, # no random-init compute. Shape descriptors only. with torch.device("meta"): - instance = cls( + meta_instance = cls( config, tokenizer=tokenizer, backbone_from_pretrained=False, cache_dir=cache_dir, **model_kwargs, ) - cls._resize_token_embeddings(instance, config, tokenizer, resize_token_embeddings) + cls._resize_token_embeddings(meta_instance, config, tokenizer, resize_token_embeddings) # Read state dict (cast on read if dtype was set), then swap # tensors directly into the meta-shell parameter slots. state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) - instance.model.load_state_dict(state_dict, assign=True, strict=strict) + meta_instance.model.load_state_dict(state_dict, assign=True, strict=strict) del state_dict # Materialize non-persistent buffers (position_ids etc.) that # the state dict didn't carry. - cls._materialize_meta_buffers(instance.model) + _materialized, unrecognized = cls._materialize_meta_buffers(meta_instance.model) + + if unrecognized: + # Buffers we don't know how to recompute (e.g. RoPE inv_freq + # whose base varies per-architecture). The meta-init load + # would produce wrong inference; fall back to the standard + # path so the user gets a correct model. Cost is one full + # standard load; benefit is no silent correctness bug. + short_names = sorted({n.rsplit(".", 1)[-1] for n in unrecognized}) + warnings.warn( + f"low_cpu_mem_usage=True is not supported for this architecture: " + f"the model has non-persistent buffer(s) {short_names} that " + f"_materialize_meta_buffers does not recognize " + f"(e.g. RoPE inv_freq for ModernBERT). Falling back to the " + f"standard load path so inference is correct. Pass " + f"low_cpu_mem_usage=False to silence this warning.", + UserWarning, + stacklevel=2, + ) + del meta_instance + else: + # All buffers recognized — meta path succeeded. + meta_instance.model.to(map_location) + instance = meta_instance - # If the state dict's map_location was "cpu" but the caller - # asked for cuda (or vice versa), move now. assign=True keeps - # tensors on whatever device they were loaded to. - instance.model.to(map_location) - else: + if instance is None: # Standard path: random-init shell at fp32, optional cast, load. + # Reached when low_cpu_mem_usage=False, OR when the meta-init + # path detected unrecognized non-persistent buffers and fell + # back automatically. instance = cls( config, tokenizer=tokenizer, diff --git a/tests/test_quantize_and_dtype.py b/tests/test_quantize_and_dtype.py index f7ae56c..f2f8380 100644 --- a/tests/test_quantize_and_dtype.py +++ b/tests/test_quantize_and_dtype.py @@ -236,9 +236,10 @@ def test_position_ids_restored_to_canonical_value(self): m = self._module_with_meta_position_ids(length=8) assert m.position_ids.is_meta # precondition - materialized = BaseGLiNER._materialize_meta_buffers(m) + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) assert materialized == ["position_ids"] + assert unrecognized == [] assert not m.position_ids.is_meta assert torch.equal( m.position_ids, @@ -248,8 +249,9 @@ def test_position_ids_restored_to_canonical_value(self): def test_no_op_when_no_meta_buffers(self): m = nn.Module() m.register_buffer("position_ids", torch.arange(0, 4).unsqueeze(0), persistent=False) - out = BaseGLiNER._materialize_meta_buffers(m) - assert out == [] + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) + assert materialized == [] + assert unrecognized == [] def test_nested_module_meta_buffer_restored(self): """Buffers nested inside child modules are walked and fixed too.""" @@ -263,23 +265,41 @@ def test_nested_module_meta_buffer_restored(self): inner.position_ids = inner.position_ids.to("meta") outer.add_module("embeddings", inner) - materialized = BaseGLiNER._materialize_meta_buffers(outer) + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(outer) assert materialized == ["embeddings.position_ids"] + assert unrecognized == [] assert not outer.embeddings.position_ids.is_meta - def test_unknown_meta_buffer_warns_and_zero_fills(self): - """Buffers we don't know how to materialize fall back to zeros + warn.""" + def test_token_type_ids_restored_to_zeros(self): + """BERT-family ``token_type_ids`` is a non-persistent buffer of zeros.""" m = nn.Module() m.register_buffer( - "mystery_constant", - torch.tensor([1.0, 2.0, 3.0]), + "token_type_ids", + torch.zeros((1, 6), dtype=torch.int64), persistent=False, ) - m.mystery_constant = m.mystery_constant.to("meta") + m.token_type_ids = m.token_type_ids.to("meta") - with pytest.warns(UserWarning, match="unknown non-persistent buffer"): - materialized = BaseGLiNER._materialize_meta_buffers(m) + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) - assert materialized == ["mystery_constant"] - assert torch.equal(m.mystery_constant, torch.zeros(3)) + assert materialized == ["token_type_ids"] + assert unrecognized == [] + assert torch.equal(m.token_type_ids, torch.zeros((1, 6), dtype=torch.int64)) + + def test_unknown_buffer_returned_as_unrecognized(self): + """Unrecognized buffers are surfaced for caller-side fallback (not zero-filled).""" + m = nn.Module() + m.register_buffer( + "inv_freq", + torch.tensor([1.0, 0.5, 0.25]), + persistent=False, + ) + m.inv_freq = m.inv_freq.to("meta") + + materialized, unrecognized = BaseGLiNER._materialize_meta_buffers(m) + + assert materialized == [] + assert unrecognized == ["inv_freq"] + # The buffer is still meta — caller must fall back to the standard load path. + assert m.inv_freq.is_meta From 6513398d990fe519c7eb91fe0b9e6bb97208ce89 Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Mon, 27 Apr 2026 23:43:09 +0200 Subject: [PATCH 3/3] Fall back when assign-load leaves meta parameters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Codex review finding [P2]: with ``low_cpu_mem_usage=True`` and the default ``strict=False``, ``load_state_dict(assign=True)`` succeeds even if the checkpoint is missing a parameter — but that parameter stays on the meta device. The subsequent ``instance.model.to(map_location)`` then raises ``NotImplementedError: Cannot copy out of meta tensor``. The standard load path would have kept the random-initialized value and loaded successfully, so the meta path is a strict regression for this case. Fix: after ``load_state_dict(assign=True)``, scan ``named_parameters()`` for any tensor still on meta. If any are found (or the existing unrecognized-buffer check fired), discard the partial meta state and fall back to the standard load path. The fallback warning now names the cause — either a list of unrecognized non-persistent buffers (e.g. RoPE ``inv_freq``) or a sample of the missing parameter names — truncated for readability when the missing-key set is large. End-to-end verification on a synthetic ``urchade/gliner_medium-v2.1`` clone with ``span_rep_layer.span_rep_layer.out_project.0.bias`` removed from ``model.safetensors``: - ``low_cpu_mem_usage=True``: load succeeds via the fallback, user-visible UserWarning names the missing key, ``param dtype`` is correct (bfloat16). Pre-fix this would have raised ``NotImplementedError`` from ``.to()``. Adds ``TestMetaParamFallbackContract`` (3 cases) asserting the underlying contract: ``load_state_dict(assign=True, strict=False)`` leaves params on meta when keys are missing, the post-assign scan finds them, and ``.to()`` on a remaining meta param raises. 58 unit tests pass total (was 55). Co-Authored-By: Claude Opus 4.7 (1M context) --- gliner/model.py | 58 +++++++++++++++++++++++--------- tests/test_quantize_and_dtype.py | 55 ++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 15 deletions(-) diff --git a/gliner/model.py b/gliner/model.py index 7b02655..b619c88 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -900,33 +900,61 @@ def from_pretrained( # Read state dict (cast on read if dtype was set), then swap # tensors directly into the meta-shell parameter slots. state_dict = cls._load_state_dict(model_file, map_location, dtype=torch_dtype) - meta_instance.model.load_state_dict(state_dict, assign=True, strict=strict) + incompat = meta_instance.model.load_state_dict(state_dict, assign=True, strict=strict) del state_dict # Materialize non-persistent buffers (position_ids etc.) that # the state dict didn't carry. _materialized, unrecognized = cls._materialize_meta_buffers(meta_instance.model) - if unrecognized: - # Buffers we don't know how to recompute (e.g. RoPE inv_freq - # whose base varies per-architecture). The meta-init load - # would produce wrong inference; fall back to the standard - # path so the user gets a correct model. Cost is one full - # standard load; benefit is no silent correctness bug. - short_names = sorted({n.rsplit(".", 1)[-1] for n in unrecognized}) + # Detect parameters left on meta. With strict=True, missing + # keys would have raised at load_state_dict; with strict=False + # (the default) they don't, leaving the parameter on the meta + # device. The standard path keeps random-init values for these + # params, which is what the caller would have seen without + # low_cpu_mem_usage=True. The subsequent .to(map_location) + # would otherwise raise ``NotImplementedError: Cannot copy out + # of meta tensor`` and fail the load entirely. + meta_param_names = [n for n, p in meta_instance.model.named_parameters() if p.is_meta] + + if unrecognized or meta_param_names: + # Cases that meta-init can't safely handle: unrecognized + # non-persistent buffers (e.g. RoPE inv_freq whose base + # varies per-architecture), or parameters that the state + # dict didn't supply (strict=False + missing keys). Fall + # back to the standard load path — cost is one full + # standard load; benefit is no silent correctness bug + # and no spurious crash on .to(map_location). + if unrecognized: + short_names = sorted({n.rsplit(".", 1)[-1] for n in unrecognized}) + reason = ( + f"the model has non-persistent buffer(s) {short_names} that " + f"_materialize_meta_buffers does not recognize " + f"(e.g. RoPE inv_freq for ModernBERT)" + ) + else: + # Truncate to keep the warning readable; the missing-key + # set can be large for genuinely incomplete checkpoints. + sample = sorted(meta_param_names)[:5] + more = f" (and {len(meta_param_names) - 5} more)" if len(meta_param_names) > 5 else "" + reason = ( + f"the checkpoint is missing parameter(s) {sample}{more}; " + f"the standard load path would have kept the random-init " + f"values for these" + ) warnings.warn( - f"low_cpu_mem_usage=True is not supported for this architecture: " - f"the model has non-persistent buffer(s) {short_names} that " - f"_materialize_meta_buffers does not recognize " - f"(e.g. RoPE inv_freq for ModernBERT). Falling back to the " - f"standard load path so inference is correct. Pass " - f"low_cpu_mem_usage=False to silence this warning.", + f"low_cpu_mem_usage=True is not supported for this load: " + f"{reason}. Falling back to the standard load path so " + f"inference is correct. Pass low_cpu_mem_usage=False to " + f"silence this warning.", UserWarning, stacklevel=2, ) del meta_instance else: - # All buffers recognized — meta path succeeded. + # All buffers recognized and all params materialized — + # meta path succeeded. + del incompat meta_instance.model.to(map_location) instance = meta_instance diff --git a/tests/test_quantize_and_dtype.py b/tests/test_quantize_and_dtype.py index f2f8380..1e151ee 100644 --- a/tests/test_quantize_and_dtype.py +++ b/tests/test_quantize_and_dtype.py @@ -303,3 +303,58 @@ def test_unknown_buffer_returned_as_unrecognized(self): assert unrecognized == ["inv_freq"] # The buffer is still meta — caller must fall back to the standard load path. assert m.inv_freq.is_meta + + +class TestMetaParamFallbackContract: + """Codex review finding: with ``low_cpu_mem_usage=True`` and the default + ``strict=False``, ``load_state_dict(assign=True)`` may leave a parameter + on the meta device when the checkpoint is missing a key. The subsequent + ``.to(map_location)`` then raises ``NotImplementedError: Cannot copy out + of meta tensor``, whereas the standard path would have kept the + random-init value and loaded successfully. + + These tests assert the *contract* underlying the fix without spinning up + a real GLiNER model: ``load_state_dict(assign=True, strict=False)`` does + leave parameters on meta when keys are missing, and the + "scan for meta parameters after assign" check in ``from_pretrained`` + correctly identifies them. + """ + + def test_assign_load_with_missing_key_leaves_param_on_meta(self): + """The premise — without our scan, ``.to()`` would crash.""" + with torch.device("meta"): + module = nn.Linear(4, 4) + # State dict is missing the bias. + partial_sd = {"weight": torch.randn(4, 4)} + result = module.load_state_dict(partial_sd, assign=True, strict=False) + assert "bias" in result.missing_keys + # The bug: bias is still on meta. + assert module.bias.is_meta + # And .to() on a meta param raises. + with pytest.raises(NotImplementedError, match="meta"): + module.to("cpu") + + def test_meta_param_scan_finds_missing_assign_targets(self): + """The fix's scan: walk named_parameters looking for ``is_meta``, + report names so the fallback warning can surface them.""" + with torch.device("meta"): + module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)) + # Provide weights but not biases. + partial_sd = { + "0.weight": torch.randn(4, 4), + "1.weight": torch.randn(2, 4), + } + module.load_state_dict(partial_sd, assign=True, strict=False) + meta_params = [n for n, p in module.named_parameters() if p.is_meta] + assert sorted(meta_params) == ["0.bias", "1.bias"] + + def test_full_assign_load_leaves_no_meta_params(self): + """Sanity: when the state dict is complete, no meta params remain.""" + with torch.device("meta"): + module = nn.Linear(4, 4) + full_sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)} + module.load_state_dict(full_sd, assign=True, strict=False) + meta_params = [n for n, p in module.named_parameters() if p.is_meta] + assert meta_params == [] + # And .to() succeeds on the materialized module. + module.to("cpu")