Skip to content

Commit 114c656

Browse files
committed
remove checkpoint
1 parent 6e7969d commit 114c656

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

‎resnet_cifar.py‎

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
BASE_LEARNING_RATE = 0.1
2424
LR_SCHEDULE = [(0.1, 4), (0.01, 8), (0.001, 10)]
25-
MODEL_PATH = "log/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5"
2625

2726
def preprocess(x, y):
2827
x = tf.image.per_image_standardization(x)
@@ -57,19 +56,20 @@ def schedule(epoch):
5756
train_loader = train_loader.map(augmentation).map(preprocess).shuffle(num_train_samples).batch(bs_per_gpu * num_gpus, drop_remainder=True)
5857
test_loader = test_loader.map(preprocess).batch(bs_per_gpu * num_gpus, drop_remainder=True)
5958

59+
opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
6060

6161
if num_gpus == 1:
6262
model = resnet.resnet56(classes=num_classes)
6363
model.compile(
64-
optimizer=keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
64+
optimizer=opt,
6565
loss='sparse_categorical_crossentropy',
6666
metrics=['accuracy'])
6767
else:
6868
mirrored_strategy = tf.distribute.MirroredStrategy()
6969
with mirrored_strategy.scope():
7070
model = resnet.resnet56(classes=num_classes)
7171
model.compile(
72-
optimizer=keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
72+
optimizer=opt,
7373
loss='sparse_categorical_crossentropy',
7474
metrics=['accuracy'])
7575

@@ -82,15 +82,12 @@ def schedule(epoch):
8282
histogram_freq=1)
8383

8484
lr_schedule_callback = keras.callbacks.LearningRateScheduler(schedule)
85-
ckpt_callback = keras.callbacks.ModelCheckpoint(
86-
MODEL_PATH, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
87-
8885

8986
model.fit(train_loader,
9087
epochs=num_epochs,
9188
validation_data=test_loader,
9289
validation_freq=1,
93-
callbacks=[tensorboard_callback, lr_schedule_callback, ckpt_callback])
90+
callbacks=[tensorboard_callback, lr_schedule_callback])
9491
model.evaluate(test_loader)
9592

9693
model.save('model.h5')

0 commit comments

Comments
 (0)