Vectorize CPU-path preprocessing and decoding hot loops#333
Merged
Conversation
Replace Python-level bottlenecks with batched tensor operations: - decoder.py: Extract all span scores via single advanced-index call instead of N per-span .item() GPU→CPU syncs; vectorize valid-span check; cache greedy-search tuples to avoid re-creation in inner loop - modeling/utils.py: Replace O(E²) nested Python loop for entity-pair generation with torch.meshgrid + diagonal mask - data_processing/utils.py: Return LongTensor directly from prepare_span_idx using torch.arange broadcasting instead of list comprehension + later conversion - data_processing/processor.py: Replace per-element .item() dict comprehensions with single .tolist() batch conversions; update callers to consume tensor from prepare_span_idx directly All outputs verified identical across 8 equivalence tests including full end-to-end inference on CPU and GPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Collaborator
|
Thank you @maxwbuckley for your contribution |
11 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR replaces several Python-level bottlenecks in preprocessing and postprocessing with batched tensor operations. All changes target the ~20% of wall time spent outside the GPU forward pass — the CPU overhead in span decoding, entity pair generation, span index construction, and label creation.
No observable behaviour is changed. All outputs are bit-identical to the baseline, verified across 8 equivalence tests including full end-to-end inference on both CPU and GPU.
Changes
1. Vectorize decode
.item()loop (decoding/decoder.py)The
_decode_batch_itemhot loop previously called.item()per candidate span, each forcing a GPU→CPU synchronization barrier:Now extracts all scores in a single advanced-indexing call:
Also replaces the per-span
_is_valid_span()call with a vectorized boolean mask.2. Optimize greedy overlap removal (
decoding/decoder.py)(start, end, entity_type)fromSpanobjects on every inner-loop comparisonlist.sort()instead ofsorted()for the final ordering3. Vectorize entity pair generation (
modeling/utils.py)Replaces an O(E²) nested Python loop that builds pair lists element-by-element:
With a single
torch.meshgrid+ diagonal mask — zero Python iteration:4. Tensor-native span index generation (
data_processing/utils.py)prepare_span_idxnow returns atorch.LongTensordirectly viatorch.arangebroadcasting instead of a Python list comprehension that callers immediately converted to a tensor anyway. All callers inprocessor.pyupdated to consume the tensor directly, eliminating redundanttorch.LongTensor()wrapping.5. Batch
.tolist()for dict construction (data_processing/processor.py)Replaces per-element
.item()calls inspan_to_index/span_to_idxdict comprehensions:Micro-benchmark results (isolated code paths)
Each optimized function was benchmarked in isolation against its baseline implementation (n=50–200 reps, CPU unless noted):
The decode
.item()elimination is the biggest win — each.item()on a GPU tensor forces a full CUDA synchronization, so replacing N of them with one vectorized indexing call gives 39× on CPU tensors and 242× on GPU tensors.End-to-end impact
End-to-end inference on
gliner_medium-v2.1(5 texts × 6 labels, n=10 reps, RTX 5090):This matches the expected ceiling: the model forward pass accounts for ~80% of wall time, so even eliminating all CPU overhead would yield at most ~1.2× end-to-end. On GPU the forward pass dominates even further, making the overhead savings invisible at the end-to-end level. The gains are real but modest in absolute terms — most visible on CPU inference and on longer sequences where preprocessing/postprocessing scales up.
Correctness verification
All outputs verified identical via 8 dedicated equivalence tests:
prepare_span_idx— tensor values match list-of-tuples baseline for all parameter combinationsspan_to_indexdict — keys and values identical between.item()and.tolist()constructionrowsandcolstensors identical (meshgrid vs nested loop) on both CPU and CUDAbuild_entity_pairsfull function — shapes, masks, pair indices, and no-self-pair invariant verified across batch sizes, entity counts, and thresholds_decode_batch_item— every span's start, end, entity_type, and score (to <1e-7) matches baseline after greedy searchgreedy_search— output identical for 0–100 random spans across all flat_ner/multi_label combinationsprepare_span_labels— label tensors and index tensors identical for tensor vs list inputpredict_entitiesreturns identical text, label, and score (to <1e-6) on both CPU and CUDAExisting test suite: 94 relevant tests pass. 4 pre-existing failures in
TestGreedySearch(tests pass tuples butgreedy_searchexpectsSpandataclass objects — predates this PR).Test plan
🤖 Generated with Claude Code