Composite RoPE partial-rotary mode pairs contiguously — incompatible with transformers partial/proportional rotary (partial_rotary_factor models) #66

Description

@kylejfrost

Summary

The composite RoPE partial-rotary mode (coreai_torch.composite_ops.RoPE with dims < head_dim, as wrapped by coreai_models/primitives/macos/rope.py) pairs dimensions in a contiguous block: dim i pairs with dim i + dims/2 inside the first dims dimensions, and the remaining dimensions pass through unrotated. HuggingFace transformers' partial / "proportional" rotary (used by any model with partial_rotary_factor < 1) instead pairs across the full head_dim half-split (dim i with dim i + head_dim/2), with only the first rope_angles frequencies non-zero (inv_freq is zero-padded). The frequencies are identical; only the dim pairing differs, so the result is silently wrong for partial-rotary models.

Surfaced while porting Gemma-4 (partial_rotary_factor = 0.25, global_head_dim = 512) — its global/full-attention layers use partial rotary. (Gemma-4 isn't in this repo yet; filing because the issue is in the shared primitive and affects any partial-rotary model.)
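
To make the two pairings concrete, here is a minimal sketch in plain Python that lists which dimensions each layout rotates for the Gemma-4 numbers quoted in this issue (head_dim = 512, rope_angles = 64). The helper names are hypothetical and are not part of coreai_torch or transformers; they only reproduce the index arithmetic described in the Summary.

```python
def contiguous_pairs(dims):
    """Composite partial mode as described in this issue:
    dim i pairs with dim i + dims/2 inside the first `dims` dimensions."""
    half = dims // 2
    return [(i, i + half) for i in range(half)]


def half_split_pairs(head_dim, rope_angles):
    """transformers layout as described in this issue:
    dim i pairs with dim i + head_dim/2 for the first `rope_angles` indices."""
    half = head_dim // 2
    return [(i, i + half) for i in range(rope_angles)]


def rotated_dims(pairs):
    """Sorted list of every dimension that takes part in a rotation."""
    return sorted({d for pair in pairs for d in pair})


head_dim, rope_angles = 512, 64  # values quoted in this issue
composite = rotated_dims(contiguous_pairs(2 * rope_angles))
reference = rotated_dims(half_split_pairs(head_dim, rope_angles))

print("composite rotates", composite[0], "to", composite[-1])
print("reference rotates", reference[0], "to", reference[63],
      "and", reference[64], "to", reference[-1])
```

Both layouts rotate the same number of dimensions; only the second member of each pair lands in a different place.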

Where

A recipe that folds the proportional factor into the composite as RoPE(base**prf, dims=2*rope_angles) gets the contiguous-block partial mode (_rope_with_cos_and_sin_impl), not the transformers layout.
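
The folding itself is sound for the frequencies, which is why the mismatch is easy to miss. The sketch below checks that numerically, assuming the composite uses the standard RoPE formula 1 / base'**(2k / dims) over its dims (an assumption, not confirmed from the coreai_torch source) and using a hypothetical base of 1e6 purely for illustration.

```python
import math


def inv_freq_full_head(base, head_dim, rope_angles):
    """Frequencies as in the reference fix: exponent 2k / head_dim, first rope_angles only."""
    return [1.0 / base ** (2 * k / head_dim) for k in range(rope_angles)]


def inv_freq_folded(base, head_dim, rope_angles):
    """Frequencies for RoPE(base**prf, dims=2*rope_angles), assuming the
    standard 1 / base'**(2k / dims) formula inside the composite."""
    dims = 2 * rope_angles
    prf = dims / head_dim                  # partial_rotary_factor
    folded_base = base ** prf
    return [1.0 / folded_base ** (2 * k / dims) for k in range(rope_angles)]


base, head_dim, rope_angles = 1_000_000.0, 512, 64   # base is a hypothetical value
direct = inv_freq_full_head(base, head_dim, rope_angles)
folded = inv_freq_folded(base, head_dim, rope_angles)

same = all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(direct, folded))
print("frequencies agree:", same)
```

Since (base**prf)**(2k/dims) equals base**(2k/head_dim) when dims = prf * head_dim, the two lists agree up to floating-point rounding; the divergence comes entirely from which dimensions get paired.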

Repro (bit-level)

Apply the composite partial RoPE vs transformers Gemma4TextRotaryEmbedding(..., "full_attention") + apply_rotary_pos_emb to the same tensor at any position > 0:

  • Sliding (full-rotary) layers match bit-exact (PSNR ∞).
  • Global (partial-rotary) layer diverges: PSNR ≈ 21.6 dB, max-abs ≈ 8.2 (gemma-4-26B-A4B, head_dim=512, partial_rotary_factor=0.25 → the reference rotates dims {0..63}∪{256..319}, whereas the composite rotates {0..127}).

Impact

Silently degrades every full-attention layer of the exported model. Generation stays coherent (global layers are ~1/6 of the stack), so it passes a smoke test — but it isn't faithful to the reference, and it breaks EAGLE/MTP speculative-draft acceptance (a draft trained against the reference layout stops matching the mis-rotated target).

Reference fix (verified bit-exact)

Replace the partial branch with a small module that reproduces transformers exactly:

inv_freq = cat([1/base**(arange(0, 2*rope_angles, 2)/head_dim),
                zeros(head_dim//2 - rope_angles)])           # zero-padded, full-head
ang = positions[..., None] * inv_freq
emb = cat([ang, ang], dim=-1)                                # full-head half-split
x_rot = x * emb.cos() + rotate_half(x) * emb.sin()

After this, both sliding and global RoPE match the reference bit-exact (PSNR ∞).
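
For readers without the full repro, the following dependency-free sketch applies both layouts to a single toy head vector. It is an illustration under stated assumptions, not the verified module: the function names are hypothetical, the base value is arbitrary, and the contiguous variant follows the description of the composite given in this issue rather than its source.

```python
import math


def partial_rope(x, position, base, rope_angles, pair_offset):
    """Rotate pairs (i, i + pair_offset) for i < rope_angles; other dims pass through.

    Skipping i >= rope_angles is equivalent to zero-padding inv_freq,
    because a zero frequency gives cos = 1 and sin = 0.
    """
    head_dim = len(x)
    out = list(x)
    for i in range(rope_angles):
        angle = position / base ** (2 * i / head_dim)  # same frequencies in both layouts
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + pair_offset]
        out[i] = a * c - b * s                # x * cos + rotate_half(x) * sin, first of the pair
        out[i + pair_offset] = b * c + a * s  # same formula, second of the pair
    return out


def reference_rope(x, position, base, rope_angles):
    """Full-head half-split pairing (dim i with dim i + head_dim/2)."""
    return partial_rope(x, position, base, rope_angles, pair_offset=len(x) // 2)


def contiguous_rope(x, position, base, rope_angles):
    """Contiguous-block pairing (dim i with dim i + dims/2, dims = 2 * rope_angles)."""
    return partial_rope(x, position, base, rope_angles, pair_offset=rope_angles)


x = [float(v) for v in range(1, 9)]  # toy head with head_dim = 8
for rope_angles in (4, 1):           # 4 = full rotary, 1 = partial rotary (factor 0.25)
    ref = reference_rope(x, 3, 10000.0, rope_angles)
    got = contiguous_rope(x, 3, 10000.0, rope_angles)
    print(rope_angles, max(abs(r - g) for r, g in zip(ref, got)))
```

With full rotary the two pairings coincide, which is consistent with the sliding layers matching; with partial rotary they differ at any non-zero position.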

Secondary (related precision footgun)

If inv_freq is stored as a registered buffer, model.to(bfloat16) downcasts it, and bf16's 8-bit significand (roughly 2 to 3 decimal digits) corrupts the frequencies (cos error ≈ 0.35 at position 200). Recomputing inv_freq in fp32 inside forward avoids this. It is worth a guard or note for any rotary buffer in a bf16-compute path.
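
The size of this effect can be estimated without any framework. The sketch below emulates bf16 rounding with the standard library and measures the worst cosine error over the rotary frequencies at one position. The helper names are hypothetical, the base of 1e6 is an assumed value for illustration, and the comparison is against float64 rather than fp32, so it will not necessarily reproduce the exact figure quoted above.

```python
import math
import struct


def to_bf16(x):
    """Round a finite Python float to the nearest bfloat16 value
    (round-to-nearest-even) and return it as a float."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round half to even on the low 16 bits
    bits &= 0xFFFF0000                                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def worst_cos_error(base, head_dim, rope_angles, position):
    """Largest |cos(p * f) - cos(p * bf16(f))| over the non-zero frequencies."""
    worst = 0.0
    for k in range(rope_angles):
        f = 1.0 / base ** (2 * k / head_dim)
        err = abs(math.cos(position * f) - math.cos(position * to_bf16(f)))
        worst = max(worst, err)
    return worst


# head_dim, rope_angles and position come from this issue; base is hypothetical.
print(worst_cos_error(base=1_000_000.0, head_dim=512, rope_angles=64, position=200))
```

The relative rounding error of each frequency is at most 2**-8, but it is multiplied by the position before the cosine, so the phase error grows linearly with sequence position.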


Found while building caix (native Core AI serving + speculative decoding). Happy to provide the full repro script or open a PR if Gemma-4 / a partial-rotary model lands upstream.
