Add quantize/compile support for ~1.9x GPU speedup#342
Merged
Conversation
…ality loss) Adds a `quantize=True` parameter to `from_pretrained()` and a `model.quantize()` method. On GPU, converts to fp16 half-precision for Tensor Core acceleration. On CPU, applies fp16 dynamic quantization (memory savings only, warns user). Benchmarked on CoNLL-2003 and WNUT-2017 with gliner_medium-v2.1 on RTX 5090. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ed speedup - compile() now uses dynamic=True (shape-generic kernels for variable-length NER inputs) and capture_scalar_outputs (traces through data-dependent shape ops in extract_prompt_features). - Combined quantize + compile gives 1.94x speedup over fp32 baseline. - Clean up docstrings: remove int8 references, document compile+quantize combo. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…WSL only Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… device - Remove stray print(valid_input_spans) on inference hot path (model.py:2277) - SpanCAT: use .expand() instead of .repeat() to avoid intermediate tensor allocation (~65MB at typical batch sizes) - SpanEndpointsBlock: build span indices with torch.arange on device instead of Python list comprehension + .to(device) transfer Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
quantize=True still works (backward compat → fp16). New options: - "bf16": bfloat16 with better numerical stability - "int8": int8 dynamic quantization (CPU only, non-DeBERTa) int8 on DeBERTa raises ValueError — quantization error accumulates across 12 transformer layers and collapses output scores to near-zero. Verified by isolating individual layers: each survives int8, but the full encoder stack degrades from F1=0.81 to F1=0.0. Updated README.md and docs/usage.md with new options and corrected benchmark table. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
int8 dynamic quantization is broken on all current GLiNER models (DeBERTa-based): error accumulates across 12 transformer layers and collapses F1 to 0.0. No point exposing a dead option. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
CPU: PyTorch built-in quantize_dynamic with FBGEMM int8 kernels (~1.6x speedup, no extra dependencies). GPU: torchao Int8WeightOnlyConfig (~50% memory reduction, no speed gain; torchao required). INT8 Tensor Cores cannot be used because GLiNER's internal dimensions fall below the cuBLAS _int_mm minimum (M>=16). Stock DeBERTa models lose accuracy with int8 — this is intended for models fine-tuned with quantization-aware training (QAT). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This was referenced Apr 23, 2026
maxwbuckley
added a commit
to maxwbuckley/GLiNER
that referenced
this pull request
Apr 23, 2026
Two related changes to the precision / quantization surface, landed
together because they form one coherent story.
1. Add `dtype=` to `GLiNER.from_pretrained`
Load weights directly at the target floating-point precision. Each
state-dict tensor is cast during the `safe_open` read and the
random-init model shell is pre-cast via `instance.model.to(dtype)`
before `load_state_dict`, so a full fp32 snapshot never co-exists
with the loaded weights. Accepts strings (`"bf16"`, `"fp16"`,
`"float32"`, ...) or a floating-point `torch.dtype`; non-floating
dtypes (e.g. `torch.int8`) are rejected up front with a message
pointing at `quantize="int8"` for int paths. Int / bool buffers are
preserved in the state dict.
Memory impact: for CPU-only loads peak drops from ~2x fp32 to ~1x
fp32; for `map_location="cuda"`, the saving is avoiding a
simultaneous fp32 GPU state dict + fp32 GPU model plus the separate
post-load cast pass. Matches the `dtype=` surface on
HuggingFace `transformers.PreTrainedModel.from_pretrained` (string
or `torch.dtype`, same semantics), so users coming from HF get a
familiar API.
Primary target: cold starts and scalable serverless deployments
(Lambda, Cloud Run, Modal, RunPod serverless, autoscaled k8s) where
startup latency and peak memory drive cost and SLA.
2. Deprecate the pure-downcast paths in `quantize(...)`
With `dtype=` in place, several `quantize=` values are just
`.to(dtype)` with an extra fp32 intermediate:
- `quantize="fp16"` / `True` on GPU -> `model.half()` (pure downcast)
- `quantize="bf16"` on GPU or CPU -> `model.bfloat16()` (pure downcast)
- `quantize="fp16"` on CPU -> dynamic quant of nn.Linear (real)
- `quantize="int8"` on GPU -> torchao weight-only (real)
- `quantize="int8"` on CPU -> FBGEMM dynamic quant (real)
The three pure-downcast rows now emit a `DeprecationWarning` from
`model.quantize(...)` pointing at `dtype=` / `model.to(...)`.
Behavior is preserved; removal is a future PR. CPU fp16 dynamic
quantization and both int8 paths stay silent. When both `dtype=`
and a downcast `quantize=` are passed to `from_pretrained`, a
separate warning fires because `quantize` runs after the load and
overwrites the precision.
Docs in docs/usage.md are rewritten: new "Reduced-precision loading
(`dtype`)" section; "Quantization, Compilation & FlashDeBERTa" now
shows `dtype="fp16"` as the recommended half-precision path and
includes a `quantize=` vs `dtype=` taxonomy explaining the CPU-fp16
exception.
Current PyPI release (`gliner 0.2.26`, 2026-03-19) pre-dates PR urchade#342,
so no released wheel ships the `quantize=` surface; the deprecation
only reaches users on `main` / a git SHA.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
quantizeparameter onfrom_pretrained()/from_config()andmodel.quantize()— acceptsTrue/"fp16"(float16) or"bf16"(bfloat16)compile()fixed — now usesdynamic=True(shape-generic kernels for variable-length NER inputs) andcapture_scalar_outputs=True(traces through data-dependentmax_embed_dimcomputation)torch.compile+ FlashDeBERTa fix — wraps FlashDeBERTa's Triton kernels withtorch.compiler.disableto prevent dynamo tracing failuresB*8to~8total, eliminating the per-item overhead regression from PR Vectorize CPU-path preprocessing and decoding hot loops #333print(valid_input_spans)debug statementSpanCAT.forward():.expand()over.repeat()(avoids ~65MB intermediate allocation)SpanEndpointsBlock.forward():torch.arangeon device instead of CPU list +.to(device)Quantization options
quantize=Trueis backward compatible (maps to fp16). On CPU, reduces memory but does not improve speed.Benchmark methodology
urchade/gliner_medium-v2.1(DeBERTa-v3 backbone, 195M params)torch.cuda.synchronize()Quality — zero degradation
All conditions produce identical F1 on both benchmarks:
Speed — up to 1.94x
Benchmarked on CoNLL-2003, 30 docs, 30 iterations:
Why
torch.compilewas broken before this PRThe previous
compile()calledtorch.compile(model)with default settings (static shape specialization), making it slower than not compiling at all (~0.96x):extract_prompt_features—max_embed_dim = num_class_tokens.max()extracts a data-dependent scalar, causing Dynamo to split the graphFixed by
torch.compile(model, dynamic=True)+capture_scalar_outputs = True.Other approaches evaluated
Test plan
quantize=Trueon GPU: identical predictions, 1.35x speedupquantize="bf16"on GPU: identical predictions, 1.22x speedupquantize=Trueon CPU: identical predictions, emits UserWarning about no speed benefitcompile_torch_model=Trueon GPU: identical predictions, 1.31x speedup (with dynamic=True fix)quantize=Truestill worksFuzz & edge-case testing (118 tests, all passing)
Comprehensive fuzz testing across the full API surface (
tests/test_fuzz.py):Unit-level (68 tests, no model download needed)
Tokenizer — WhitespaceTokenSplitter (23 tests)
" ","\t","\n","\r\n")Overlap utilities (14 tests)
has_overlapping: identical, adjacent, partial, nested, single-token, touching spans; multi_label modehas_overlapping_nested: containment allowed, partial rejected, identical with/without multi_labelis_nested: both directions, identical, partial overlap, no overlapGreedy search (12 tests)
SpanDecoder (19 tests)
id_to_classes,return_class_probsoutput validationinf/-inf/NaNlogits (no crash)Integration-level (50 tests, end-to-end with
gliner_small-v2.5)Empty/whitespace inputs (4 tests): empty string, whitespace-only, batch of all empties, mixed empty+valid batch
Minimal inputs (3 tests): single word, single character, single punctuation mark
Label edge cases (7 tests): empty label list, single label, 50 labels (>max_types=25), duplicate labels (verified deduplication), special characters in labels (
person/human,city (location)), unicode labels (人物/場所/組織), empty/whitespace-only labelsThreshold boundaries (4 tests): 0.0 (returns many), 1e-10, 0.999, 1.0 (returns none)
NER mode combinations (4 tests): all 4 combinations of
flat_ner×multi_label; verifiedflat_ner=Trueproduces zero overlapping spansUnicode text (4 tests): CJK, emoji, Arabic, mixed scripts in single text
Special text patterns (6 tests): only punctuation, only numbers, HTML tags, URLs, newlines, tabs
Batch processing (5 tests):
batch_size=1, batch_size > input count, 50-item batch, variable-length batch, string input auto-wrapped to listOutput invariant validation (4 tests):
entity["text"] == text[entity["start"]:entity["end"]]for every entity[0.0, 1.0]0 <= start < end <= len(text)for every entityDeterminism (1 test): identical results on repeated calls
Stress (2 tests): 1000-char single "word", 500 single-char tokens
input_spans(3 tests): basic span filtering, empty spans → empty result, misaligned spans silently dropped🤖 Generated with Claude Code