Removing Sequence and Token classification models. Removing integration tests for now
RyanMullins committed Sep 12, 2025
commit bdccd85355e8dd70c8b8fb416fea81d1d4d52df9
2 changes: 0 additions & 2 deletions src/transformers/models/auto/modeling_auto.py
@@ -1287,7 +1287,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("tapas", "TapasForSequenceClassification"),
         ("transfo-xl", "TransfoXLForSequenceClassification"),
         ("umt5", "UMT5ForSequenceClassification"),
-        ("vaultgemma", "VaultGemmaForSequenceClassification"),
         ("xlm", "XLMForSequenceClassification"),
         ("xlm-roberta", "XLMRobertaForSequenceClassification"),
         ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
@@ -1493,7 +1492,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("t5", "T5ForTokenClassification"),
         ("t5gemma", "T5GemmaForTokenClassification"),
         ("umt5", "UMT5ForTokenClassification"),
-        ("vaultgemma", "VaultGemmaForTokenClassification"),
         ("xlm", "XLMForTokenClassification"),
         ("xlm-roberta", "XLMRobertaForTokenClassification"),
         ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
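The two deleted lines above are the auto-class mapping entries for the "vaultgemma" model type. As a rough illustration (not part of this commit) of what their removal means downstream, assuming the VaultGemma config stays registered under the "vaultgemma" model type, the classification auto-classes can no longer resolve it:

# Sketch only: effect of dropping the mapping entries removed above.
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.for_model("vaultgemma")  # VaultGemmaConfig itself is still registered
# With the mapping entry gone, this lookup is expected to fail with a ValueError
# ("Unrecognized configuration class ... for this kind of AutoModel").
model = AutoModelForSequenceClassification.from_config(config)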
22 changes: 2 additions & 20 deletions src/transformers/models/vaultgemma/modeling_vaultgemma.py
@@ -29,11 +29,7 @@
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import (
-    GenericForSequenceClassification,
-    GenericForTokenClassification,
-    GradientCheckpointingLayer,
-)
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -572,18 +568,4 @@ def forward(
         )
 
 
-class VaultGemmaForSequenceClassification(GenericForSequenceClassification, VaultGemmaPreTrainedModel):
-    pass
-
-
-class VaultGemmaForTokenClassification(GenericForTokenClassification, VaultGemmaPreTrainedModel):
-    pass
-
-
-__all__ = [
-    "VaultGemmaForCausalLM",
-    "VaultGemmaModel",
-    "VaultGemmaPreTrainedModel",
-    "VaultGemmaForSequenceClassification",
-    "VaultGemmaForTokenClassification",
-]
+__all__ = ["VaultGemmaForCausalLM", "VaultGemmaModel", "VaultGemmaPreTrainedModel"]
11 changes: 0 additions & 11 deletions src/transformers/models/vaultgemma/modular_vaultgemma.py
@@ -23,8 +23,6 @@
     Gemma2Attention,
     Gemma2DecoderLayer,
     Gemma2ForCausalLM,
-    Gemma2ForSequenceClassification,
-    Gemma2ForTokenClassification,
     Gemma2MLP,
     Gemma2Model,
     Gemma2PreTrainedModel,
@@ -112,19 +110,10 @@ class VaultGemmaForCausalLM(Gemma2ForCausalLM):
     pass
 
 
-class VaultGemmaForSequenceClassification(Gemma2ForSequenceClassification):
-    pass
-
-
-class VaultGemmaForTokenClassification(Gemma2ForTokenClassification):
-    pass
-
-
 __all__ = [
     "VaultGemmaConfig",
     "VaultGemmaForCausalLM",
     "VaultGemmaModel",
     "VaultGemmaPreTrainedModel",
-    "VaultGemmaForSequenceClassification",
-    "VaultGemmaForTokenClassification",
 ]
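Until the classification models return, a user who needs one can define it the same way the deleted generated class did and register it with the auto-classes. A hedged sketch, assuming GenericForSequenceClassification remains importable from transformers.modeling_layers and that VaultGemmaConfig and VaultGemmaPreTrainedModel keep their current import paths:

# Sketch only: user-side re-creation of the removed sequence-classification head.
from transformers import AutoModelForSequenceClassification
from transformers.modeling_layers import GenericForSequenceClassification
from transformers.models.vaultgemma import VaultGemmaConfig, VaultGemmaPreTrainedModel


class MyVaultGemmaForSequenceClassification(GenericForSequenceClassification, VaultGemmaPreTrainedModel):
    # Mirrors the class body deleted from modeling_vaultgemma.py above.
    pass


# Resolve VaultGemmaConfig to the user-defined class in the auto-class machinery.
AutoModelForSequenceClassification.register(VaultGemmaConfig, MyVaultGemmaForSequenceClassification)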