Skip to content

Commit 6e7969d

Browse files
committed
Save checkpoint
1 parent 48e7a1f commit 6e7969d

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

‎.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ venv-tf2/*
22
__pycache__/*
33
*.h5
44
logs/*
5+
log/*

‎resnet_cifar.py‎

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
BASE_LEARNING_RATE = 0.1
2424
LR_SCHEDULE = [(0.1, 4), (0.01, 8), (0.001, 10)]
25-
25+
MODEL_PATH = "log/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5"
2626

2727
def preprocess(x, y):
2828
x = tf.image.per_image_standardization(x)
@@ -82,12 +82,15 @@ def schedule(epoch):
8282
histogram_freq=1)
8383

8484
lr_schedule_callback = keras.callbacks.LearningRateScheduler(schedule)
85+
ckpt_callback = keras.callbacks.ModelCheckpoint(
86+
MODEL_PATH, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
87+
8588

8689
model.fit(train_loader,
8790
epochs=num_epochs,
8891
validation_data=test_loader,
89-
validation_freq=10,
90-
callbacks=[tensorboard_callback, lr_schedule_callback])
92+
validation_freq=1,
93+
callbacks=[tensorboard_callback, lr_schedule_callback, ckpt_callback])
9194
model.evaluate(test_loader)
9295

9396
model.save('model.h5')

0 commit comments

Comments
 (0)