Add quantize/compile support for ~1.9x GPU speedup #342

Merged: Ingvarstep merged 12 commits into urchade:main from maxwbuckley:add-quantize-flag on Mar 31, 2026

maxwbuckley commented Mar 20, 2026

Summary

  • quantize parameter on from_pretrained() / from_config() and model.quantize() — accepts True/"fp16" (float16) or "bf16" (bfloat16)
  • compile() fixed — now uses dynamic=True (shape-generic kernels for variable-length NER inputs) and capture_scalar_outputs=True (traces through data-dependent max_embed_dim computation)
  • torch.compile + FlashDeBERTa fix — wraps FlashDeBERTa's Triton kernels with torch.compiler.disable to prevent dynamo tracing failures
  • Batch-level span decoding: reduces CUDA kernel launches from B*8 to ~8 total (B = batch size), eliminating the per-item overhead regression from PR #333 ("Vectorize CPU-path preprocessing and decoding hot loops")
  • Minor cleanups:
    • Removed stray print(valid_input_spans) debug statement
    • SpanCAT.forward(): .expand() over .repeat() (avoids ~65MB intermediate allocation)
    • SpanEndpointsBlock.forward(): torch.arange on device instead of CPU list + .to(device)

Quantization options

from gliner import GLiNER

model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", map_location="cuda")
model.quantize()         # fp16 (default): ~1.35x GPU speedup
model.quantize("bf16")   # bfloat16: better numerical stability, ~1.2x

quantize=True is backward compatible (maps to fp16). On CPU, fp16 reduces memory use but does not improve speed.
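The accepted values can be summarized as a small lookup. This is a minimal sketch of the mapping described above, not the code in this PR: `resolve_quantize` and `_ALIASES` are hypothetical names, and dtypes are returned as plain strings so the snippet runs without PyTorch.

```python
import warnings

# Hypothetical alias table: the PR states that True and "fp16" map to
# float16 and that "bf16" maps to bfloat16.
_ALIASES = {True: "float16", "fp16": "float16", "bf16": "bfloat16"}


def resolve_quantize(value, device="cuda"):
    """Return the target dtype name for a quantize= argument, or None."""
    if value is None or value is False:
        return None  # quantization not requested
    if value not in _ALIASES:
        raise ValueError(f"unsupported quantize value: {value!r}")
    if device == "cpu":
        # Mirrors the UserWarning listed in the test plan.
        warnings.warn("quantize on CPU saves memory but not time", UserWarning)
    return _ALIASES[value]
```

For example, `resolve_quantize(True)` and `resolve_quantize("fp16")` both return `"float16"`, which is the backward-compatible behavior.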

Benchmark methodology

  • Model: urchade/gliner_medium-v2.1 (DeBERTa-v3 backbone, 195M params)
  • Datasets: CoNLL-2003 test (200 docs, 372 entities), WNUT-2017 test (200 docs, 163 entities)
  • Quality metric: nervaluate strict matching (exact boundary + type)
  • Speed: 30 iterations on 30-doc batch, 10 warmup, torch.cuda.synchronize()
  • Hardware: NVIDIA RTX 5090, PyTorch 2.8.0+cu128, CUDA 12.8
  • Statistical test: Welch's t-test
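The timing and statistics steps above can be sketched with the standard library alone. This is an illustrative harness under stated assumptions, not the benchmark script used for this PR: `time_batches` and `welch_t` are hypothetical names, and on GPU the `sync` argument would be something like `torch.cuda.synchronize`.

```python
import math
import statistics
import time


def time_batches(fn, iterations=30, warmup=10, sync=None):
    """Return per-iteration wall times in seconds, measured after warmup."""
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # drain queued work before timing starts
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous work to finish
        samples.append(time.perf_counter() - start)
    return samples


def welch_t(a, b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom."""
    va = statistics.variance(a) / len(a)
    vb = statistics.variance(b) / len(b)
    t = (statistics.mean(a) - statistics.mean(b)) / math.sqrt(va + vb)
    df = (va + vb) ** 2 / (va ** 2 / (len(a) - 1) + vb ** 2 / (len(b) - 1))
    return t, df
```

Turning `t` and `df` into a p-value needs the Student t distribution, which the standard library does not provide; a statistics package would be used for that step.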

Quality — zero degradation

All conditions produce identical F1 on both benchmarks:

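| Condition | CoNLL-2003 P | CoNLL-2003 R | CoNLL-2003 F1 | WNUT-2017 P | WNUT-2017 R | WNUT-2017 F1 |
|---|---|---|---|---|---|---|
| fp32 (baseline) | 0.8097 | 0.8118 | 0.8107 | 0.5242 | 0.7727 | 0.6247 |
| quantize (fp16) | 0.8097 | 0.8118 | 0.8107 | 0.5242 | 0.7727 | 0.6247 |
| quantize + compile | 0.8097 | 0.8118 | 0.8107 | 0.5242 | 0.7727 | 0.6247 |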

Speed — up to 1.94x

Benchmarked on CoNLL-2003, 30 docs, 30 iterations:

| Condition | Median | Docs/sec | Speedup | p-value |
|---|---|---|---|---|
| fp32 (baseline) | 0.0451s | 665 | 1.00x | |
| + quantize | 0.0334s | 899 | 1.35x | p=2.8e-05 *** |
| + compile(dynamic=True) | 0.0345s | 870 | 1.31x | p=9.6e-08 *** |
| + quantize + compile | 0.0233s | 1287 | 1.94x | p=3.7e-20 *** |
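The Speedup and Docs/sec columns follow directly from the medians: speedup is the baseline median divided by the condition's median, and docs/sec is the batch size (30) divided by the median. The sketch below recomputes both from the rounded medians shown in the table. The speedups match, while docs/sec can differ from the table by about one document per second, presumably because the table was built from unrounded medians.

```python
BATCH_DOCS = 30  # documents per timed batch, per the methodology

# Median seconds per batch, copied from the table.
medians = {
    "fp32 (baseline)": 0.0451,
    "+ quantize": 0.0334,
    "+ compile(dynamic=True)": 0.0345,
    "+ quantize + compile": 0.0233,
}

baseline = medians["fp32 (baseline)"]
speedups = {name: round(baseline / m, 2) for name, m in medians.items()}
docs_per_sec = {name: BATCH_DOCS / m for name, m in medians.items()}

for name in medians:
    print(f"{name:26s} {speedups[name]:.2f}x  {docs_per_sec[name]:7.1f} docs/sec")
```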

Why torch.compile was broken before this PR

The previous compile() called torch.compile(model) with default settings (static shape specialization), making it slower than not compiling at all (~0.96x):

  1. Static shape recompilation: GLiNER processes variable-length text inputs, so every new sequence length triggered a full recompilation
  2. Graph break in extract_prompt_features: max_embed_dim = num_class_tokens.max() extracts a data-dependent scalar, causing Dynamo to split the graph

Fixed by torch.compile(model, dynamic=True) + capture_scalar_outputs = True.
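The recompilation cost in point 1 can be illustrated with a toy cache. This is a conceptual model only, not PyTorch code or a description of Dynamo internals: `ToyCompiler` is a hypothetical class that counts how many "kernels" get built when the cache key is the exact sequence length versus a single shape-generic entry.

```python
class ToyCompiler:
    """Counts compilations for a toy kernel cache (not PyTorch internals)."""

    def __init__(self, dynamic):
        self.dynamic = dynamic
        self.cache = {}
        self.compile_count = 0

    def run(self, seq_len):
        # Static mode specializes on the exact length; dynamic mode shares
        # one shape-generic entry across all lengths.
        key = "any_length" if self.dynamic else seq_len
        if key not in self.cache:
            self.compile_count += 1  # stands in for an expensive compile
            self.cache[key] = f"kernel[{key}]"
        return self.cache[key]


lengths = [12, 37, 12, 58, 91, 37]  # variable-length inputs, some repeated
static = ToyCompiler(dynamic=False)
dynamic = ToyCompiler(dynamic=True)
for n in lengths:
    static.run(n)
    dynamic.run(n)

print(static.compile_count, dynamic.compile_count)
```

In this toy model the static cache compiles once per distinct length, while the dynamic one compiles once in total, which matches the direction of the ~0.96x versus 1.31x results reported here.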

Other approaches evaluated

| Approach | F1 | Speedup | Verdict |
|---|---|---|---|
| torch.compile (static, old behavior) | 0.8107 | ~0.96x | Slower: recompiles on every new shape |
| torch.compile (dynamic=True) | 0.8107 | 1.31x | Fixed: shape-generic kernels |
| int8 dynamic quantization (CPU) | 0.0000 | | Breaks DeBERTa (error accumulates across 12 layers) |
| bitsandbytes int8 (GPU) | 0.8200 | 0.98x | Slower than fp32 on modern GPUs |
| torchao FP8 weight-only | 0.8097 | 0.98x | No speedup, software overhead |
| torchao FP8 dynamic act+weight | 0.8054 | 0.03x | 35x slower (quantize/dequant overhead) |
| bfloat16 | 0.8134 | 1.22x | Works, slightly less speedup than fp16 |

Test plan

  • quantize=True on GPU: identical predictions, 1.35x speedup
  • quantize="bf16" on GPU: identical predictions, 1.22x speedup
  • quantize=True on CPU: identical predictions, emits UserWarning about no speed benefit
  • compile_torch_model=True on GPU: identical predictions, 1.31x speedup (with dynamic=True fix)
  • Combined quantize + compile: identical predictions, 1.94x speedup
  • FlashDeBERTa + compile: no crash (torch.compiler.disable wrapper)
  • ONNX model raises RuntimeError on quantize()
  • Quality benchmarked on CoNLL-2003 and WNUT-2017 (nervaluate strict F1)
  • Speed benchmarked with n=30, warmup=10, Welch's t-test
  • Backward compatible: quantize=True still works
  • 118 fuzz/edge-case tests — all passing (see below)

Fuzz & edge-case testing (118 tests, all passing)

Comprehensive fuzz testing across the full API surface (tests/test_fuzz.py):

Unit-level (68 tests, no model download needed)

Tokenizer — WhitespaceTokenSplitter (23 tests)

  • Empty string, whitespace-only (" ", "\t", "\n", "\r\n")
  • Single characters: letter, digit, punctuation
  • Unicode: CJK (中文), emoji (🌍🎉), combining characters (e+\u0301), zero-width chars (\u200b\u200c), Arabic (العربية), mixed scripts
  • Special patterns: hyphenated words, underscored words, multiple punctuation, null bytes, multi-codepoint emoji (👨‍👩‍👧‍👦)
  • Stress: 10,000 tokens, repeated separators, leading/trailing whitespace
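To make the tokenizer cases concrete, here is a minimal regex-based splitter. The pattern is an assumption chosen to match the behaviors listed above (hyphenated and underscored words stay together, punctuation becomes its own token, whitespace yields nothing). `split_tokens` is a hypothetical name and this is not a copy of WhitespaceTokenSplitter.

```python
import re

# Assumed pattern: a run of word characters, optionally extended by
# '-' or '_' joined runs, otherwise any single non-whitespace character.
_TOKEN_RE = re.compile(r"\w+(?:[-_]\w+)*|\S")


def split_tokens(text):
    """Yield (token, start, end) with character offsets into text."""
    for match in _TOKEN_RE.finditer(text):
        yield match.group(), match.start(), match.end()


sample = "state-of-the-art foo_bar!"
print([tok for tok, _, _ in split_tokens(sample)])
```

Empty and whitespace-only inputs produce no tokens under this pattern, and every yielded offset pair satisfies `text[start:end] == token`.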

Overlap utilities (14 tests)

  • has_overlapping: identical, adjacent, partial, nested, single-token, touching spans; multi_label mode
  • has_overlapping_nested: containment allowed, partial rejected, identical with/without multi_label
  • is_nested: both directions, identical, partial overlap, no overlap
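The three predicates can be sketched as follows. This assumes spans are `(start, end)` token indices with an inclusive end, which is an assumption about the convention rather than something stated in this PR. The function names mirror the utilities under test, but the bodies are an illustration and not the library source.

```python
def is_nested(a, b):
    """True when one span fully contains the other (identical spans count)."""
    return (a[0] <= b[0] and a[1] >= b[1]) or (b[0] <= a[0] and b[1] >= a[1])


def has_overlapping(a, b, multi_label=False):
    """True when two spans conflict under flat NER."""
    if a[:2] == b[:2]:
        return not multi_label  # same span is allowed only in multi-label mode
    if a[0] > b[1] or b[0] > a[1]:
        return False  # disjoint spans never conflict
    return True


def has_overlapping_nested(a, b, multi_label=False):
    """Like has_overlapping, but containment is allowed (nested NER)."""
    if a[:2] == b[:2]:
        return not multi_label
    if a[0] > b[1] or b[0] > a[1] or is_nested(a, b):
        return False
    return True
```

Under these assumptions, `(0, 3)` and `(2, 5)` conflict in both modes, while `(0, 5)` and `(1, 2)` conflict only in the flat variant.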

Greedy search (12 tests)

  • Empty input, single span, non-overlapping preserved, higher score wins, tie-breaking determinism
  • Multi-label same-span allowed, different-span still filtered
  • Nested NER allows containment, result sorted by start position
  • 20 overlapping spans stress test, zero/negative scores
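The behaviors listed above (higher score wins, deterministic ties, multi-label and nested exceptions, output sorted by start) fit a simple greedy selection. The sketch below is one way to get them, assuming spans are `(start, end, label, score)` tuples with inclusive ends. `greedy_select` is a hypothetical name, not the library's greedy search.

```python
def greedy_select(spans, flat_ner=True, multi_label=False):
    """Keep the highest-scoring spans that do not conflict with kept ones."""

    def conflicts(a, b):
        if a[:2] == b[:2]:
            return not multi_label  # identical span, different label
        if a[0] > b[1] or b[0] > a[1]:
            return False  # disjoint
        if not flat_ner:
            nested = (a[0] <= b[0] and a[1] >= b[1]) or (
                b[0] <= a[0] and b[1] >= a[1]
            )
            if nested:
                return False  # nested NER allows containment
        return True

    kept = []
    # sorted() is stable, so equal scores keep input order (deterministic ties).
    for cand in sorted(spans, key=lambda s: -s[3]):
        if not any(conflicts(cand, k) for k in kept):
            kept.append(cand)
    return sorted(kept, key=lambda s: s[0])


spans = [(0, 2, "person", 0.6), (1, 3, "org", 0.9), (5, 6, "loc", 0.4)]
print(greedy_select(spans))
```

Here the lower-scoring `person` span is dropped because it partially overlaps the `org` span, and the result comes back ordered by start position.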

SpanDecoder (19 tests)

  • All-zero logits (sigmoid=0.5 boundary), threshold 0.0 / 1.0 extremes
  • Single-token text, empty token list
  • Batch-size-1 fast path produces identical results to batch path (verified element-by-element)
  • Per-sample id_to_classes, return_class_probs output validation
  • inf / -inf / NaN logits (no crash)
  • Span boundary validation (no span exceeds token count)
  • 100 classes, batch of 64, variable-length sequences across batch
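The threshold and non-finite logit cases come down to how sigmoid scores compare against the threshold. The sketch below assumes a strict `score > threshold` comparison, which is consistent with threshold 1.0 returning nothing but is still an assumption about the decoder. `sigmoid` and `keep_indices` are illustrative helpers in plain Python, not SpanDecoder code.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function; NaN propagates, infinities saturate."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x is negative (or NaN), so exp cannot overflow
    return z / (1.0 + z)


def keep_indices(logits, threshold):
    """Indices whose sigmoid score is strictly above the threshold."""
    return [i for i, x in enumerate(logits) if sigmoid(x) > threshold]


print(sigmoid(0.0))                                # all-zero logits sit on 0.5
print(keep_indices([0.0, 2.0, -2.0], 0.0))         # threshold 0.0 keeps all
print(keep_indices([0.0, 50.0, float("inf")], 1.0))  # threshold 1.0 keeps none
```

A NaN logit yields a NaN score, and any comparison with NaN is false, so such entries are dropped rather than raising.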

Integration-level (50 tests, end-to-end with gliner_small-v2.5)

Empty/whitespace inputs (4 tests): empty string, whitespace-only, batch of all empties, mixed empty+valid batch

Minimal inputs (3 tests): single word, single character, single punctuation mark

Label edge cases (7 tests): empty label list, single label, 50 labels (>max_types=25), duplicate labels (verified deduplication), special characters in labels (person/human, city (location)), unicode labels (人物/場所/組織), empty/whitespace-only labels

Threshold boundaries (4 tests): 0.0 (returns many), 1e-10, 0.999, 1.0 (returns none)

NER mode combinations (4 tests): all 4 combinations of flat_ner × multi_label; verified flat_ner=True produces zero overlapping spans

Unicode text (4 tests): CJK, emoji, Arabic, mixed scripts in single text

Special text patterns (6 tests): only punctuation, only numbers, HTML tags, URLs, newlines, tabs

Batch processing (5 tests): batch_size=1, batch_size > input count, 50-item batch, variable-length batch, string input auto-wrapped to list

Output invariant validation (4 tests):

  • entity["text"] == text[entity["start"]:entity["end"]] for every entity
  • All scores in [0.0, 1.0]
  • 0 <= start < end <= len(text) for every entity
  • All returned labels are from the input label set
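These four invariants can be checked with one helper. This is a minimal sketch: `check_entities` is a hypothetical name, and the `"label"` and `"score"` keys are assumed alongside the `"text"`, `"start"`, and `"end"` keys named above.

```python
def check_entities(text, entities, labels):
    """Assert the output invariants for every predicted entity."""
    allowed = set(labels)
    for ent in entities:
        start, end = ent["start"], ent["end"]
        assert 0 <= start < end <= len(text), ent      # valid offsets
        assert ent["text"] == text[start:end], ent     # text matches offsets
        assert 0.0 <= ent["score"] <= 1.0, ent         # score is a probability
        assert ent["label"] in allowed, ent            # label was requested
    return True


text = "Ada Lovelace lived in London."
entities = [
    {"start": 0, "end": 12, "text": "Ada Lovelace", "label": "person", "score": 0.97},
    {"start": 22, "end": 28, "text": "London", "label": "city", "score": 0.91},
]
print(check_entities(text, entities, ["person", "city"]))
```

The entity dictionaries here are hand-written examples, not model output.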

Determinism (1 test): identical results on repeated calls

Stress (2 tests): 1000-char single "word", 500 single-char tokens

input_spans (3 tests): basic span filtering, empty spans → empty result, misaligned spans silently dropped

🤖 Generated with Claude Code

maxwbuckley and others added 2 commits March 20, 2026 21:53
…ality loss)

Adds a `quantize=True` parameter to `from_pretrained()` and a `model.quantize()`
method. On GPU, converts to fp16 half-precision for Tensor Core acceleration.
On CPU, applies fp16 dynamic quantization (memory savings only, warns user).

Benchmarked on CoNLL-2003 and WNUT-2017 with gliner_medium-v2.1 on RTX 5090.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ed speedup

- compile() now uses dynamic=True (shape-generic kernels for variable-length
  NER inputs) and capture_scalar_outputs (traces through data-dependent shape
  ops in extract_prompt_features).
- Combined quantize + compile gives 1.94x speedup over fp32 baseline.
- Clean up docstrings: remove int8 references, document compile+quantize combo.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@maxwbuckley maxwbuckley changed the title Add quantize flag for easy fp16 inference Mar 20, 2026
maxwbuckley and others added 7 commits March 20, 2026 22:26
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…WSL only

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… device

- Remove stray print(valid_input_spans) on inference hot path (model.py:2277)
- SpanCAT: use .expand() instead of .repeat() to avoid intermediate tensor
  allocation (~65MB at typical batch sizes)
- SpanEndpointsBlock: build span indices with torch.arange on device instead
  of Python list comprehension + .to(device) transfer

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
quantize=True still works (backward compat → fp16). New options:
- "bf16": bfloat16 with better numerical stability
- "int8": int8 dynamic quantization (CPU only, non-DeBERTa)

int8 on DeBERTa raises ValueError — quantization error accumulates
across 12 transformer layers and collapses output scores to near-zero.
Verified by isolating individual layers: each survives int8, but the
full encoder stack degrades from F1=0.81 to F1=0.0.

Updated README.md and docs/usage.md with new options and corrected
benchmark table.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@maxwbuckley maxwbuckley changed the title Add quantize flag and fix torch.compile for ~1.9x GPU speedup Mar 22, 2026
maxwbuckley and others added 2 commits March 22, 2026 20:40
int8 dynamic quantization is broken on all current GLiNER models
(DeBERTa-based): error accumulates across 12 transformer layers and
collapses F1 to 0.0. No point exposing a dead option.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
CPU: PyTorch built-in quantize_dynamic with FBGEMM int8 kernels (~1.6x
speedup, no extra dependencies).

GPU: torchao Int8WeightOnlyConfig (~50% memory reduction, no speed gain;
torchao required). INT8 Tensor Cores cannot be used because GLiNER's
internal dimensions fall below the cuBLAS _int_mm minimum (M>=16).

Stock DeBERTa models lose accuracy with int8 — this is intended for
models fine-tuned with quantization-aware training (QAT).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@urchade urchade requested a review from Ingvarstep March 25, 2026 15:03
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Ingvarstep Ingvarstep merged commit 4eefb8f into urchade:main Mar 31, 2026
maxwbuckley added a commit to maxwbuckley/GLiNER that referenced this pull request Apr 23, 2026
Two related changes to the precision / quantization surface, landed
together because they form one coherent story.

1. Add `dtype=` to `GLiNER.from_pretrained`

   Load weights directly at the target floating-point precision. Each
   state-dict tensor is cast during the `safe_open` read and the
   random-init model shell is pre-cast via `instance.model.to(dtype)`
   before `load_state_dict`, so a full fp32 snapshot never co-exists
   with the loaded weights. Accepts strings (`"bf16"`, `"fp16"`,
   `"float32"`, ...) or a floating-point `torch.dtype`; non-floating
   dtypes (e.g. `torch.int8`) are rejected up front with a message
   pointing at `quantize="int8"` for int paths. Int / bool buffers are
   preserved in the state dict.

   Memory impact: for CPU-only loads peak drops from ~2x fp32 to ~1x
   fp32; for `map_location="cuda"`, the saving is avoiding a
   simultaneous fp32 GPU state dict + fp32 GPU model plus the separate
   post-load cast pass. Matches the `dtype=` surface on
   HuggingFace `transformers.PreTrainedModel.from_pretrained` (string
   or `torch.dtype`, same semantics), so users coming from HF get a
   familiar API.

   Primary target: cold starts and scalable serverless deployments
   (Lambda, Cloud Run, Modal, RunPod serverless, autoscaled k8s) where
   startup latency and peak memory drive cost and SLA.
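   A back-of-envelope version of the CPU peak-memory claim, as a sketch under stated assumptions: it uses the 195M parameter count quoted for gliner_medium-v2.1 in the PR description, counts only weight storage (no buffers, activations, or allocator overhead), and models the peak as "model shell plus state dict" at whichever precision each path holds them. The function names are hypothetical.

```python
PARAMS = 195_000_000  # gliner_medium-v2.1, per the PR description
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weights_mb(dtype):
    """Weight storage in megabytes at the given precision."""
    return PARAMS * BYTES_PER_PARAM[dtype] / 1e6


def peak_cpu_load_mb(dtype, cast_during_read):
    """Approximate peak: model shell plus state dict held at the same time."""
    if cast_during_read:
        # Shell is pre-cast and tensors are cast as they are read.
        return 2 * weights_mb(dtype)
    # Old path: fp32 shell plus fp32 state dict, cast only after loading.
    return 2 * weights_mb("float32")


print(weights_mb("float32"))                 # one fp32 copy
print(peak_cpu_load_mb("float16", False))    # about 2x fp32
print(peak_cpu_load_mb("float16", True))     # about 1x fp32
```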

2. Deprecate the pure-downcast paths in `quantize(...)`

   With `dtype=` in place, several `quantize=` values are just
   `.to(dtype)` with an extra fp32 intermediate:

     - `quantize="fp16"` / `True` on GPU  -> `model.half()`      (pure downcast)
     - `quantize="bf16"` on GPU or CPU   -> `model.bfloat16()`  (pure downcast)
     - `quantize="fp16"` on CPU          -> dynamic quant of nn.Linear (real)
     - `quantize="int8"` on GPU          -> torchao weight-only (real)
     - `quantize="int8"` on CPU          -> FBGEMM dynamic quant (real)

   The three pure-downcast rows now emit a `DeprecationWarning` from
   `model.quantize(...)` pointing at `dtype=` / `model.to(...)`.
   Behavior is preserved; removal is a future PR. CPU fp16 dynamic
   quantization and both int8 paths stay silent. When both `dtype=`
   and a downcast `quantize=` are passed to `from_pretrained`, a
   separate warning fires because `quantize` runs after the load and
   overwrites the precision.
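   The taxonomy above can be written as a lookup that decides when to warn. This is a sketch of the described behavior, not the code in the commit: `PURE_DOWNCAST` and `check_quantize` are hypothetical names, and the warning text is illustrative.

```python
import warnings

# (quantize value, device) -> True when the path is a pure dtype downcast.
PURE_DOWNCAST = {
    ("fp16", "cuda"): True,
    ("bf16", "cuda"): True,
    ("bf16", "cpu"): True,
    ("fp16", "cpu"): False,   # dynamic quantization of nn.Linear
    ("int8", "cuda"): False,  # torchao weight-only
    ("int8", "cpu"): False,   # FBGEMM dynamic quantization
}


def check_quantize(value, device):
    """Warn for pure-downcast paths; return True when the path is deprecated."""
    key = ("fp16" if value is True else value, device)
    if key not in PURE_DOWNCAST:
        raise ValueError(f"unsupported combination: {key!r}")
    deprecated = PURE_DOWNCAST[key]
    if deprecated:
        warnings.warn(
            "pure-downcast quantize= is deprecated; prefer dtype= or model.to(...)",
            DeprecationWarning,
            stacklevel=2,
        )
    return deprecated
```

   With this table, `check_quantize(True, "cuda")` warns, while `check_quantize("fp16", "cpu")` and both int8 paths stay silent.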

Docs in docs/usage.md are rewritten: new "Reduced-precision loading
(`dtype`)" section; "Quantization, Compilation & FlashDeBERTa" now
shows `dtype="fp16"` as the recommended half-precision path and
includes a `quantize=` vs `dtype=` taxonomy explaining the CPU-fp16
exception.

Current PyPI release (`gliner 0.2.26`, 2026-03-19) pre-dates PR urchade#342,
so no released wheel ships the `quantize=` surface; the deprecation
only reaches users on `main` / a git SHA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>