
Vectorize CPU-path preprocessing and decoding hot loops #333

Merged
Ingvarstep merged 1 commit into urchade:main from maxwbuckley:main on Feb 23, 2026

@maxwbuckley

Contributor

Summary

This PR replaces several Python-level bottlenecks in preprocessing and postprocessing with batched tensor operations. All changes target the ~20% of wall time spent outside the GPU forward pass — the CPU overhead in span decoding, entity pair generation, span index construction, and label creation.

No observable behaviour is changed. All outputs are bit-identical to the baseline, verified across 8 equivalence tests including full end-to-end inference on both CPU and GPU.

Changes

1. Vectorize decode .item() loop (decoding/decoder.py)

The _decode_batch_item hot loop previously called .item() once per candidate span. When the probabilities live on a CUDA device, each call forces a GPU→CPU synchronization:

# Before: N individual GPU→CPU syncs
for s, k, c in zip(s_idx.tolist(), k_idx.tolist(), c_idx.tolist()):
    score = probs_i[s, k, c].item()

Now extracts all scores in a single advanced-indexing call:

# After: 1 GPU→CPU transfer
scores = probs_i[s_idx, k_idx, c_idx].tolist()
flat_idxs = (s_idx * K + k_idx).tolist()

Also replaces the per-span _is_valid_span() call with a vectorized boolean mask.
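
The equivalence of the two forms can be checked on a small CPU tensor. This is a minimal sketch rather than the code in decoder.py: the sizes, the threshold, and the validity rule (a span is valid if its end stays inside the sequence) are assumptions made for illustration, since the body of _is_valid_span() is not shown here.

```python
import torch

torch.manual_seed(0)
L, K, C = 6, 4, 3            # hypothetical sizes: tokens, max span width, classes
probs_i = torch.rand(L, K, C)
threshold = 0.5              # hypothetical decision threshold

# Candidate (start, width, class) index triples above the threshold.
s_idx, k_idx, c_idx = torch.where(probs_i > threshold)

# Baseline: one .item() call per candidate.
loop_scores = [
    probs_i[s, k, c].item()
    for s, k, c in zip(s_idx.tolist(), k_idx.tolist(), c_idx.tolist())
]

# Vectorized: one advanced-indexing call, one transfer to Python.
vec_scores = probs_i[s_idx, k_idx, c_idx].tolist()
flat_idxs = (s_idx * K + k_idx).tolist()

# Assumed validity rule, applied to all candidates at once as a boolean mask.
valid = (s_idx + k_idx) < L
valid_scores = probs_i[s_idx[valid], k_idx[valid], c_idx[valid]].tolist()
```

Both score lists hold the same Python floats in the same order, because advanced indexing visits the index triples in the order they are given.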

2. Optimize greedy overlap removal (decoding/decoder.py)

  • Caches selected span tuples in a parallel list instead of re-extracting (start, end, entity_type) from Span objects on every inner-loop comparison
  • Adds early return for empty input
  • Uses in-place list.sort() instead of sorted() for the final ordering
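
The three points above can be sketched as follows. This is a simplified illustration, not the greedy_search implementation: the Span dataclass here is a hypothetical stand-in, the overlap rule assumes inclusive end indices, and the flat_ner / multi_label options are ignored.

```python
from dataclasses import dataclass


@dataclass
class Span:  # hypothetical stand-in for the library's Span dataclass
    start: int
    end: int
    entity_type: str
    score: float


def overlaps(a, b):
    # a and b are (start, end, entity_type) tuples with inclusive ends.
    return not (a[1] < b[0] or b[1] < a[0])


def greedy_select(spans):
    if not spans:                      # early return for empty input
        return []
    selected = []
    selected_keys = []                 # parallel list of cached tuples
    for span in sorted(spans, key=lambda s: -s.score):
        key = (span.start, span.end, span.entity_type)
        # Compare against cached tuples instead of re-reading Span fields.
        if any(overlaps(key, other) for other in selected_keys):
            continue
        selected.append(span)
        selected_keys.append(key)
    selected.sort(key=lambda s: s.start)   # in-place final ordering
    return selected


example = [
    Span(0, 2, "PER", 0.9),
    Span(1, 3, "ORG", 0.8),
    Span(4, 5, "LOC", 0.7),
]
kept = greedy_select(example)
```

In this example the ORG span is dropped because it overlaps the higher-scoring PER span, and the two survivors come back ordered by start position.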

3. Vectorize entity pair generation (modeling/utils.py)

Replaces an O(E²) nested Python loop that builds pair lists element-by-element:

# Before: 10,000 Python iterations for E=100
for i in range(E):
    for j in range(E):
        if i != j:
            all_rows.append(i)
            all_cols.append(j)
rows = torch.tensor(all_rows, ...)

With a single torch.meshgrid + diagonal mask — zero Python iteration:

# After: pure tensor ops
grid_i, grid_j = torch.meshgrid(arange, arange, indexing='ij')
off_diag = grid_i != grid_j
rows, cols = grid_i[off_diag], grid_j[off_diag]
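
A self-contained comparison of the two approaches is sketched below. The function names are hypothetical and the real build_entity_pairs does more (batching, masks, thresholds); this only covers the pair-index step.

```python
import torch


def pairs_loop(E):
    # Baseline: nested Python loop over all ordered pairs with i != j.
    rows, cols = [], []
    for i in range(E):
        for j in range(E):
            if i != j:
                rows.append(i)
                cols.append(j)
    return (
        torch.tensor(rows, dtype=torch.long),
        torch.tensor(cols, dtype=torch.long),
    )


def pairs_meshgrid(E):
    # Vectorized: build the full E x E grid, then mask out the diagonal.
    arange = torch.arange(E)
    grid_i, grid_j = torch.meshgrid(arange, arange, indexing="ij")
    off_diag = grid_i != grid_j
    return grid_i[off_diag], grid_j[off_diag]


rows_a, cols_a = pairs_loop(5)
rows_b, cols_b = pairs_meshgrid(5)
```

Boolean-mask indexing walks the grid in row-major order, which is the same order the nested loop produces, so the outputs match element for element and not just as sets.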

4. Tensor-native span index generation (data_processing/utils.py)

prepare_span_idx now returns a torch.LongTensor directly via torch.arange broadcasting instead of a Python list comprehension that callers immediately converted to a tensor anyway. All callers in processor.py updated to consume the tensor directly, eliminating redundant torch.LongTensor() wrapping.
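
The broadcasting idea can be sketched as below. The signature (num_tokens, max_width) and the span layout (start i, end i + j for every width offset j) are assumptions for illustration; the actual prepare_span_idx may differ in both.

```python
import torch


def span_idx_list(num_tokens, max_width):
    # Assumed baseline: list comprehension over starts and width offsets.
    return [(i, i + j) for i in range(num_tokens) for j in range(max_width)]


def span_idx_tensor(num_tokens, max_width):
    starts = torch.arange(num_tokens).unsqueeze(1)   # shape (num_tokens, 1)
    widths = torch.arange(max_width).unsqueeze(0)    # shape (1, max_width)
    ends = starts + widths                           # broadcast to (num_tokens, max_width)
    starts_b = starts.expand(num_tokens, max_width)  # same shape as ends
    # Stack to (num_tokens, max_width, 2), then flatten to (N, 2) in row-major order.
    return torch.stack((starts_b, ends), dim=-1).reshape(-1, 2)


as_list = span_idx_list(4, 3)
as_tensor = span_idx_tensor(4, 3)
```

The result is already an int64 tensor of shape (num_tokens * max_width, 2), so callers no longer need to wrap it in torch.LongTensor().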

5. Batch .tolist() for dict construction (data_processing/processor.py)

Replaces per-element .item() calls in span_to_index / span_to_idx dict comprehensions:

# Before: 2N .item() calls
{(t[idx, 0].item(), t[idx, 1].item()): idx for idx in range(len(t))}
# After: 1 .tolist() call
{(s, e): idx for idx, (s, e) in enumerate(t.tolist())}
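
A small runnable check that both constructions yield the same dictionary, using a made-up (N, 2) span tensor t:

```python
import torch

# Hypothetical span tensor: one (start, end) pair per row.
t = torch.tensor([[0, 0], [0, 1], [1, 1], [1, 2]])

# Baseline: two .item() calls per row.
per_item = {(t[idx, 0].item(), t[idx, 1].item()): idx for idx in range(len(t))}

# Batched: a single .tolist() converts the whole tensor to nested Python lists.
batched = {(s, e): idx for idx, (s, e) in enumerate(t.tolist())}
```

Both forms produce plain Python ints as keys, so lookups such as batched[(1, 2)] behave identically.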

Micro-benchmark results (isolated code paths)

Each optimized function was benchmarked in isolation against its baseline implementation (n=50–200 reps, CPU unless noted):

| Optimization | Baseline (median) | Optimized (median) | Speedup |
| --- | --- | --- | --- |
| Decode score extraction (CPU, 1808 spans) | 2.66 ms | 0.07 ms | 39× |
| Decode score extraction (GPU tensor, 1808 spans) | 69.16 ms | 0.29 ms | 242× |
| Entity pair gen (E=50) | 0.15 ms | 0.03 ms | 4.9× |
| Entity pair gen (E=100) | 0.61 ms | 0.09 ms | 6.6× |
| Span index gen (128 tokens) | 0.12 ms | 0.01 ms | 9.1× |
| Span index gen (256 tokens) | 0.24 ms | 0.01 ms | 16.2× |
| span_to_index dict (600 spans) | 1.38 ms | 0.04 ms | 36.9× |
| span_to_index dict (1536 spans) | 3.55 ms | 0.10 ms | 34.8× |

The decode .item() elimination is the biggest win. On GPU tensors, each .item() forces a CUDA synchronization, so replacing N of them with one vectorized indexing call gives 242×. CPU tensors need no synchronization, but each .item() still pays per-element indexing and Python dispatch overhead, and removing that gives 39×.
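
The benchmark harness itself is not part of this description. A median-of-reps timer along the following lines would produce numbers of this kind; the helper names are hypothetical and the sketch covers CPU code only. Timing CUDA work would additionally need torch.cuda.synchronize() before each clock read, because kernel launches are asynchronous.

```python
import statistics
import time


def median_ms(fn, reps=50):
    # Run fn() reps times and return the median wall time in milliseconds.
    samples = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def speedup(baseline_fn, optimized_fn, reps=50):
    # Ratio of median baseline time to median optimized time.
    return median_ms(baseline_fn, reps) / median_ms(optimized_fn, reps)


# Example: time a trivial pure-Python workload.
elapsed = median_ms(lambda: sum(range(10_000)), reps=20)
```

Using the median rather than the mean keeps a single slow repetition (for example a garbage-collection pause) from skewing the result.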

End-to-end impact

End-to-end inference on gliner_medium-v2.1 (5 texts × 6 labels, n=10 reps, RTX 5090):

| Device | Baseline (median) | Optimized (median) | Speedup |
| --- | --- | --- | --- |
| CPU | ~1169 ms | ~1033 ms | ~1.13× |
| GPU | ~81–94 ms | ~85–89 ms | ~1.0× (within noise) |

This matches the expected ceiling: the model forward pass accounts for ~80% of wall time, so even eliminating all CPU overhead would yield at most ~1.25× end-to-end (1 / 0.8). On GPU the forward pass dominates even further, making the overhead savings invisible at the end-to-end level. The gains are real but modest in absolute terms, and are most visible on CPU inference and on longer sequences where preprocessing/postprocessing scales up.
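
The ceiling follows from Amdahl's law. The sketch below works through the arithmetic using the ~20% overhead share and the CPU medians quoted above; the function name is chosen here for illustration.

```python
def amdahl_speedup(fraction, local_speedup):
    # Overall speedup when `fraction` of the runtime is sped up by `local_speedup`.
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)


# Upper bound: the ~20% non-forward-pass share becomes free.
ceiling = amdahl_speedup(0.2, float("inf"))   # 1 / 0.8

# Observed CPU end-to-end speedup from the medians in the table.
observed = 1169 / 1033

# Share of baseline wall time actually removed on CPU.
removed_fraction = 1 - 1033 / 1169
```

With these inputs the ceiling is 1.25×, the observed speedup is about 1.13×, and the change removes roughly 12% of baseline CPU wall time, a bit over half of the ~20% overhead budget.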

Correctness verification

All outputs verified identical via 8 dedicated equivalence tests:

  1. prepare_span_idx — tensor values match list-of-tuples baseline for all parameter combinations
  2. span_to_index dict — keys and values identical between .item() and .tolist() construction
  3. Entity pair generation — rows and cols tensors identical (meshgrid vs nested loop) on both CPU and CUDA
  4. build_entity_pairs full function — shapes, masks, pair indices, and no-self-pair invariant verified across batch sizes, entity counts, and thresholds
  5. _decode_batch_item — every span's start, end, entity_type, and score (to <1e-7) matches baseline after greedy search
  6. greedy_search — output identical for 0–100 random spans across all flat_ner/multi_label combinations
  7. prepare_span_labels — label tensors and index tensors identical for tensor vs list input
  8. End-to-end inference — predict_entities returns identical text, label, and score (to <1e-6) on both CPU and CUDA
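
The equivalence tests themselves are not shown in this description. For the end-to-end case, a comparison helper could look like the sketch below; it assumes predict_entities returns a list of dicts with "text", "label", and "score" keys, and the helper name is hypothetical.

```python
import math


def entities_match(baseline, optimized, tol=1e-6):
    # True if both runs return the same entities in the same order,
    # with scores agreeing to within an absolute tolerance.
    if len(baseline) != len(optimized):
        return False
    for a, b in zip(baseline, optimized):
        if a["text"] != b["text"] or a["label"] != b["label"]:
            return False
        if not math.isclose(a["score"], b["score"], rel_tol=0.0, abs_tol=tol):
            return False
    return True


# Made-up predictions standing in for two runs of the model.
run_a = [{"text": "Paris", "label": "city", "score": 0.912345}]
run_b = [{"text": "Paris", "label": "city", "score": 0.9123451}]
run_c = [{"text": "Paris", "label": "country", "score": 0.912345}]
```

Here run_a and run_b match within tolerance, while run_c fails on the label.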

Existing test suite: 94 relevant tests pass. The 4 failures in TestGreedySearch predate this PR: those tests pass tuples, but greedy_search expects Span dataclass objects.

Test plan

  • All 8 output equivalence tests pass
  • 94 existing unit tests pass (test_data_processing, test_decoder, test_modeling)
  • Micro-benchmarks confirm speedup in each optimized path
  • End-to-end inference produces identical entity predictions
  • Reviewer: confirm no regressions on your hardware / model variants

🤖 Generated with Claude Code

Replace Python-level bottlenecks with batched tensor operations:

- decoder.py: Extract all span scores via single advanced-index call
  instead of N per-span .item() GPU→CPU syncs; vectorize valid-span
  check; cache greedy-search tuples to avoid re-creation in inner loop
- modeling/utils.py: Replace O(E²) nested Python loop for entity-pair
  generation with torch.meshgrid + diagonal mask
- data_processing/utils.py: Return LongTensor directly from
  prepare_span_idx using torch.arange broadcasting instead of list
  comprehension + later conversion
- data_processing/processor.py: Replace per-element .item() dict
  comprehensions with single .tolist() batch conversions; update
  callers to consume tensor from prepare_span_idx directly

All outputs verified identical across 8 equivalence tests including
full end-to-end inference on CPU and GPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Ingvarstep

Collaborator

Thank you @maxwbuckley for your contribution

@Ingvarstep merged commit 9d54c32 into urchade:main on Feb 23, 2026
