Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/gemma3/configuration_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def __init__(
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.use_bidirectional_attention = use_bidirectional_attention
if use_bidirectional_attention:
self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds

self.rope_local_base_freq = rope_local_base_freq
self.rope_scaling = rope_scaling
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = config.query_pre_attn_scalar**-0.5
self.attention_dropout = self.config.attention_dropout
self.is_causal = True
self.is_causal = not self.config.use_bidirectional_attention

self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
Expand Down Expand Up @@ -450,8 +450,8 @@ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, in

def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
"""A token can attend to any other token if their absolute distance is within
half the sliding window size (distance <= sliding_window // 2)."""
return abs(q_idx - kv_idx) <= sliding_window // 2
the (exclusive) sliding window size (distance < sliding_window)."""
return abs(q_idx - kv_idx) < sliding_window

return inner_mask

Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/gemma3/modular_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def __init__(
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.use_bidirectional_attention = use_bidirectional_attention
if use_bidirectional_attention:
self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds

self.rope_local_base_freq = rope_local_base_freq
self.rope_scaling = rope_scaling
Expand Down Expand Up @@ -402,6 +404,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):

super().__init__(config, layer_idx)
self.sliding_window = config.sliding_window if self.is_sliding else None
self.is_causal = not self.config.use_bidirectional_attention

self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
Expand Down Expand Up @@ -546,8 +549,8 @@ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, in

def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
"""A token can attend to any other token if their absolute distance is within
half the sliding window size (distance <= sliding_window // 2)."""
return abs(q_idx - kv_idx) <= sliding_window // 2
the (exclusive) sliding window size (distance < sliding_window)."""
return abs(q_idx - kv_idx) < sliding_window

return inner_mask

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma3n/modular_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,7 @@ def apply_rotary_pos_emb(
class Gemma3nTextAttention(Gemma3Attention):
def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.is_causal = True
del self.attn_logit_softcapping
del self.scaling
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
Expand Down