
Commit 312daaa

Accelerate some adjustment for mixed precision (#1009)
* Use accelerator.autocast when computing loss

  According to the accelerate docs, loss computation should be performed within the
  accelerator.autocast context manager:
  https://huggingface.co/docs/accelerate/v0.21.0/en/quicktour#mixed-precision-training

  I tested whether this makes a difference by running the following notebook with fp16 precision:
  https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb

  I found no difference at all: the runtime was practically the same and the losses were
  identical. Still, I think it's better to have this than not, as it is recommended by the
  accelerate docs.

* Update LR scheduler callback to work w/ accelerate

  According to the accelerate docs:
  https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
  the LR scheduler step should sometimes be skipped when using mixed precision training,
  because accelerate may skip update steps internally. Therefore, I updated the LR scheduler
  callback to check whether the net has an accelerator and, if it does, whether a step is
  necessary.

  This is quite hard to test because the necessity of stepping depends on accelerate's
  internal logic, which we don't want to test and which might change in the future. Therefore,
  the added test just runs training with accelerate, mixed precision, and some lr schedulers,
  verifying that there is no error. When running these tests plus the normal lr scheduler
  tests locally on a machine that supports fp16, I get 100% line coverage of lr_scheduler.py,
  which I think is good enough.

* Non-functional clean ups related to lr schedulers

  While working on the fixes in this PR, I also cleaned up some lr scheduler code. These
  clean ups are non-functional:

  1. We imported CyclicLR as TorchCyclicLR. I'm not sure why, but it was somehow related to
     very old PyTorch versions that we no longer support, so I removed it.
  2. Fixed some indentation of conditional checks to improve readability.

* Reviewer comment: Simplify conditional code
1 parent 07fc260 · commit 312daaa

4 files changed: 102 additions & 29 deletions
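For context, this is roughly how the two changes fit together from a user's perspective: an AccelerateMixin net computes its loss under accelerator.autocast(), and the LRScheduler callback now skips its own step whenever accelerate skipped the optimizer update. A minimal usage sketch, not taken from the commit itself (the toy module is made up for illustration; fp16 assumes a CUDA-capable GPU):

import torch
from accelerate import Accelerator
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler
from skorch.hf import AccelerateMixin
from torch.optim.lr_scheduler import StepLR


class ToyModule(torch.nn.Module):
    # made-up classifier module, only for illustration
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 2)

    def forward(self, X):
        # NeuralNetClassifier's default criterion is NLLLoss, so return log-probabilities
        return torch.log_softmax(self.lin(X), dim=-1)


class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """NeuralNetClassifier that delegates mixed precision handling to accelerate."""


accelerator = Accelerator(mixed_precision='fp16')
net = AcceleratedNet(
    ToyModule,
    accelerator=accelerator,
    device=None,  # device placement is handled by accelerate
    callbacks=[LRScheduler(StepLR, step_size=2, step_every='batch')],
    max_epochs=5,
)
# net.fit(X, y)  # X: float32 features of shape (n, 20), y: int64 class labels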

skorch/callbacks/lr_scheduler.py

Lines changed: 44 additions & 19 deletions

@@ -9,17 +9,12 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim.lr_scheduler import CyclicLR
 from torch.optim.lr_scheduler import ExponentialLR
 from torch.optim.lr_scheduler import LambdaLR
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
-
-try:
-    from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
-except ImportError:
-    # Backward compatibility with torch >= 1.0 && < 1.1
-    TorchCyclicLR = None
 from torch.optim.optimizer import Optimizer
 from skorch.callbacks import Callback

@@ -142,6 +137,31 @@ def on_train_begin(self, net, **kwargs):
             net, self.policy_, **self.kwargs
         )

+    def _step(self, net, lr_scheduler, score=None):
+        """Helper method to step the lr scheduler.
+
+        This takes care of two things:
+
+        1. If the lr scheduler is ReduceLROnPlateau, we need to pass the score.
+        2. If the net uses AccelerateMixin, stepping has to be skipped in
+           certain conditions.
+
+        For more info on the latter, see:
+        https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
+
+        """
+        accelerator_maybe = getattr(net, 'accelerator', None)
+        accelerator_step_skipped = (
+            accelerator_maybe and accelerator_maybe.optimizer_step_was_skipped
+        )
+        if accelerator_step_skipped:
+            return
+
+        if score is None:
+            lr_scheduler.step()
+        else:
+            lr_scheduler.step(score)
+
     def on_epoch_end(self, net, **kwargs):
         if self.step_every != 'epoch':
             return
@@ -158,31 +178,36 @@ def on_epoch_end(self, net, **kwargs):
                         "should be placed before the LRScheduler callback"
                     ) from e

-            self.lr_scheduler_.step(score)
+            self._step(net, self.lr_scheduler_, score=score)
             # ReduceLROnPlateau does not expose the current lr so it can't be recorded
         else:
-            if self.event_name is not None and hasattr(
-                    self.lr_scheduler_, "get_last_lr"):
-                net.history.record(self.event_name,
-                                   self.lr_scheduler_.get_last_lr()[0])
-            self.lr_scheduler_.step()
+            if (
+                    (self.event_name is not None)
+                    and hasattr(self.lr_scheduler_, "get_last_lr")
+            ):
+                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
+            self._step(net, self.lr_scheduler_)

     def on_batch_end(self, net, training, **kwargs):
         if not training or self.step_every != 'batch':
             return
-        if self.event_name is not None and hasattr(
-                self.lr_scheduler_, "get_last_lr"):
-            net.history.record_batch(self.event_name,
-                                     self.lr_scheduler_.get_last_lr()[0])
-        self.lr_scheduler_.step()
+        if (
+                (self.event_name is not None)
+                and hasattr(self.lr_scheduler_, "get_last_lr")
+        ):
+            net.history.record_batch(
+                self.event_name, self.lr_scheduler_.get_last_lr()[0])
+        self._step(net, self.lr_scheduler_)
         self.batch_idx_ += 1

     def _get_scheduler(self, net, policy, **scheduler_kwargs):
         """Return scheduler, based on indicated policy, with appropriate
         parameters.
         """
-        if policy not in [ReduceLROnPlateau] and \
-           'last_epoch' not in scheduler_kwargs:
+        if (
+                (policy not in [ReduceLROnPlateau])
+                and ('last_epoch' not in scheduler_kwargs)
+        ):
             last_epoch = len(net.history) - 1
             scheduler_kwargs['last_epoch'] = last_epoch
skorch/hf.py

Lines changed: 4 additions & 3 deletions
@@ -1005,9 +1005,10 @@ def train_step(self, batch, **fit_params):
     def train_step_single(self, batch, **fit_params):
         self._set_training(True)
         Xi, yi = unpack_data(batch)
-        y_pred = self.infer(Xi, **fit_params)
-        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
-        self.accelerator.backward(loss)
+        with self.accelerator.autocast():
+            y_pred = self.infer(Xi, **fit_params)
+            loss = self.get_loss(y_pred, yi, X=Xi, training=True)
+            self.accelerator.backward(loss)
         return {
             'loss': loss,
             'y_pred': y_pred,
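The reason for the wrapper above: accelerate autocasts the prepared module's own forward pass, but get_loss applies the criterion outside of the module, and the quicktour recommends putting such "outside" loss computations under accelerator.autocast() as well. A hedged, minimal illustration of the same idea outside of skorch (again assuming a CUDA device for fp16):

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision='fp16')
model = accelerator.prepare(torch.nn.Linear(10, 2))  # prepare() autocasts the module's forward
criterion = torch.nn.CrossEntropyLoss()

X = torch.randn(8, 10, device=accelerator.device)
y = torch.randint(0, 2, (8,), device=accelerator.device)

with accelerator.autocast():
    y_pred = model(X)
    loss = criterion(y_pred, y)  # loss computed outside the model, hence under autocast
    accelerator.backward(loss)   # the diff above keeps backward inside the block as well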

skorch/tests/callbacks/test_lr_scheduler.py

Lines changed: 6 additions & 7 deletions
@@ -3,7 +3,6 @@

 import numpy as np
 import pytest
-import torch
 from sklearn.base import clone
 from torch.optim import SGD
 from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -12,7 +11,7 @@
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
-from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
+from torch.optim.lr_scheduler import CyclicLR

 from skorch import NeuralNetClassifier
 from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler
@@ -28,7 +27,7 @@ def test_simulate_lrs_epoch_step(self, policy):
         expected = np.array([1.0, 1.0, 0.1, 0.1, 0.01, 0.01])
         assert np.allclose(expected, lrs)

-    @pytest.mark.parametrize('policy', [TorchCyclicLR])
+    @pytest.mark.parametrize('policy', [CyclicLR])
     def test_simulate_lrs_batch_step(self, policy):
         lr_sch = LRScheduler(
             policy, base_lr=1, max_lr=5, step_size_up=4, step_every='batch')
@@ -96,7 +95,7 @@ def test_lr_callback_steps_correctly(
         assert lr_policy.lr_scheduler_.last_epoch == max_epochs

     @pytest.mark.parametrize('policy, kwargs', [
-        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
+        (CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
     ])
     def test_lr_callback_batch_steps_correctly(
             self,
@@ -125,7 +124,7 @@ def test_lr_callback_batch_steps_correctly(
         assert lr_policy.batch_idx_ == expected

     @pytest.mark.parametrize('policy, kwargs', [
-        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
+        (CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
     ])
     def test_lr_callback_batch_steps_correctly_fallback(
             self,
@@ -177,7 +176,7 @@ def test_lr_scheduler_cloneable(self):

     def test_lr_scheduler_set_params(self, classifier_module, classifier_data):
         scheduler = LRScheduler(
-            TorchCyclicLR, base_lr=123, max_lr=999, step_every='batch')
+            CyclicLR, base_lr=123, max_lr=999, step_every='batch')
         net = NeuralNetClassifier(
             classifier_module,
             max_epochs=0,
@@ -212,7 +211,7 @@ def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data
         batch_size = 128

         scheduler = LRScheduler(
-            TorchCyclicLR,
+            CyclicLR,
             base_lr=1,
             max_lr=5,
             step_size_up=4,

skorch/tests/test_hf.py

Lines changed: 48 additions & 0 deletions
@@ -802,6 +802,7 @@ class MockAccelerator:
     def __init__(self):
         self.device_placement = True
         self.print = print
+        self.optimizer_step_was_skipped = False

     def prepare(self, *args):
         for arg in args:
@@ -826,6 +827,11 @@ def wait_for_everyone(self):
     def accumulate(self, model):
         yield

+    # pylint: disable=unused-argument
+    @contextmanager
+    def autocast(self, cache_enabled=False, autocast_handler=None):
+        yield
+
 # pylint: disable=missing-docstring,arguments-differ
 class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
     def get_iterator(self, *args, **kwargs):
@@ -950,6 +956,48 @@ def train_step(self, *args, **kwargs):
         updated_expected = [False, False, True, False, False, True, True] * max_epochs
         assert updated == updated_expected

+    @pytest.mark.parametrize('mixed_precision', ['no', 'fp16', 'bf16'])
+    @pytest.mark.parametrize('scheduler', ['ReduceLROnPlateau', 'StepLR'])
+    def test_lr_scheduler_with_accelerate(
+            self, net_cls, accelerator_cls, data, mixed_precision, scheduler
+    ):
+        # This test only checks that lr schedulers work with accelerate mixed
+        # precision. The reason why this requires special handling is explained
+        # here:
+        # https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
+        # There is no test for whether the lr scheduler actually steps correctly
+        # or not, as that would require knowledge of accelerate internals, which
+        # we don't want to rely on.
+        from accelerate.utils import is_bf16_available
+        from skorch.callbacks import LRScheduler
+
+        if (mixed_precision != 'no') and not torch.cuda.is_available():
+            pytest.skip('skipping AMP test because device does not support it')
+        if (mixed_precision == 'bf16') and not is_bf16_available():
+            pytest.skip('skipping bf16 test because device does not support it')
+
+        X, y = data[0][:100], data[1][:100]
+        max_epochs = 10
+
+        if scheduler == 'ReduceLROnPlateau':
+            lr_scheduler = LRScheduler(
+                policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
+            )
+        else:
+            lr_scheduler = LRScheduler(
+                policy=torch.optim.lr_scheduler.StepLR,
+                step_size=2,
+                step_every='batch',
+            )
+
+        accelerator = accelerator_cls()
+        net = net_cls(
+            accelerator=accelerator,
+            max_epochs=max_epochs,
+            callbacks=[lr_scheduler],
+        )
+        net.fit(X, y)
+

 class MockHfApi:
     """Mock of huggingface_hub.HfAPI"""
