Skip to content

Commit 6b6c0b2

Browse files
committed
Better way to log learning rate
1 parent da8ff62 commit 6b6c0b2

2 files changed

Lines changed: 13 additions & 30 deletions

File tree

resnet_cifar.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,6 @@
2121
epochs_drop = 5.0
2222

2323

24-
class LRTensorBoard(TensorBoard):
25-
def __init__(self, log_dir, update_freq, histogram_freq): # add other arguments to __init__ if you need
26-
super(LRTensorBoard, self).__init__(log_dir=log_dir,
27-
update_freq=update_freq,
28-
histogram_freq=histogram_freq)
29-
30-
def on_epoch_end(self, epoch, logs=None):
31-
logs.update({'lr': self.model.optimizer.lr})
32-
super(LRTensorBoard, self).on_epoch_end(epoch, logs)
33-
34-
3524
def preprocess(x, y):
3625
x = tf.image.per_image_standardization(x)
3726
return x, y
@@ -55,6 +44,7 @@ def schedule(epoch):
5544
drop = 0.5
5645

5746
lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
47+
tf.summary.scalar('learning rate', data=lrate, step=epoch)
5848
return lrate
5949

6050

@@ -77,28 +67,29 @@ def schedule(epoch):
7767

7868

7969
if num_gpus == 1:
80-
model = resnet.resnet56(classes=num_classes)
70+
model = resnet.resnet32(classes=num_classes)
8171
model.compile(
8272
optimizer=keras.optimizers.Adam(learning_rate=INIT_LR),
8373
loss='sparse_categorical_crossentropy',
8474
metrics=['accuracy'])
8575
else:
8676
mirrored_strategy = tf.distribute.MirroredStrategy()
8777
with mirrored_strategy.scope():
88-
model = resnet.resnet56(classes=num_classes)
78+
model = resnet.resnet32(classes=num_classes)
8979
model.compile(
9080
optimizer=keras.optimizers.Adam(learning_rate=INIT_LR),
9181
loss='sparse_categorical_crossentropy',
9282
metrics=['accuracy'])
9383

9484
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
95-
tensorboard_callback = LRTensorBoard(
85+
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
86+
file_writer.set_as_default()
87+
tensorboard_callback = TensorBoard(
9688
log_dir=log_dir,
9789
update_freq='batch',
9890
histogram_freq=1)
9991

100-
lr_schedule_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
101-
92+
lr_schedule_callback = keras.callbacks.LearningRateScheduler(schedule)
10293

10394
model.fit(train_loader,
10495
epochs=num_epochs,

vgg_cifar.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,6 @@
2020
num_epochs = 20
2121
epochs_drop = 5.0
2222

23-
class LRTensorBoard(TensorBoard):
24-
def __init__(self, log_dir, update_freq, histogram_freq): # add other arguments to __init__ if you need
25-
super(LRTensorBoard, self).__init__(log_dir=log_dir,
26-
update_freq=update_freq,
27-
histogram_freq=histogram_freq)
28-
29-
def on_epoch_end(self, epoch, logs=None):
30-
logs.update({'lr': self.model.optimizer.lr})
31-
super(LRTensorBoard, self).on_epoch_end(epoch, logs)
32-
33-
3423
def preprocess(x, y):
3524
x = tf.image.per_image_standardization(x)
3625
return x, y
@@ -55,6 +44,7 @@ def schedule(epoch):
5544
drop = 0.5
5645

5746
lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
47+
tf.summary.scalar('learning rate', data=lrate, step=epoch)
5848
return lrate
5949

6050
def VGG16(input_shape):
@@ -195,14 +185,16 @@ def VGG16(input_shape):
195185
optimizer=keras.optimizers.Adam(learning_rate=INIT_LR),
196186
loss='sparse_categorical_crossentropy',
197187
metrics=['accuracy'])
198-
188+
199189
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
200-
tensorboard_callback = LRTensorBoard(
190+
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
191+
file_writer.set_as_default()
192+
tensorboard_callback = TensorBoard(
201193
log_dir=log_dir,
202194
update_freq='batch',
203195
histogram_freq=1)
204196

205-
lr_schedule_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
197+
lr_schedule_callback = keras.callbacks.LearningRateScheduler(schedule)
206198

207199

208200
model.fit(train_loader,

0 commit comments

Comments (0)