2222
2323BASE_LEARNING_RATE = 0.1
2424LR_SCHEDULE = [(0.1 , 4 ), (0.01 , 8 ), (0.001 , 10 )]
25- MODEL_PATH = "log/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5"
2625
2726def preprocess (x , y ):
2827 x = tf .image .per_image_standardization (x )
@@ -57,19 +56,20 @@ def schedule(epoch):
5756train_loader = train_loader .map (augmentation ).map (preprocess ).shuffle (num_train_samples ).batch (bs_per_gpu * num_gpus , drop_remainder = True )
5857test_loader = test_loader .map (preprocess ).batch (bs_per_gpu * num_gpus , drop_remainder = True )
5958
59+ opt = keras .optimizers .SGD (learning_rate = 0.1 , momentum = 0.9 )
6060
6161if num_gpus == 1 :
6262 model = resnet .resnet56 (classes = num_classes )
6363 model .compile (
64- optimizer = keras . optimizers . SGD ( learning_rate = 0.1 , momentum = 0.9 ) ,
64+ optimizer = opt ,
6565 loss = 'sparse_categorical_crossentropy' ,
6666 metrics = ['accuracy' ])
6767else :
6868 mirrored_strategy = tf .distribute .MirroredStrategy ()
6969 with mirrored_strategy .scope ():
7070 model = resnet .resnet56 (classes = num_classes )
7171 model .compile (
72- optimizer = keras . optimizers . SGD ( learning_rate = 0.1 , momentum = 0.9 ) ,
72+ optimizer = opt ,
7373 loss = 'sparse_categorical_crossentropy' ,
7474 metrics = ['accuracy' ])
7575
@@ -82,15 +82,12 @@ def schedule(epoch):
8282 histogram_freq = 1 )
8383
8484lr_schedule_callback = keras .callbacks .LearningRateScheduler (schedule )
85- ckpt_callback = keras .callbacks .ModelCheckpoint (
86- MODEL_PATH , monitor = 'val_loss' , verbose = 0 , save_best_only = False , save_weights_only = False , mode = 'auto' , period = 1 )
87-
8885
8986model .fit (train_loader ,
9087 epochs = num_epochs ,
9188 validation_data = test_loader ,
9289 validation_freq = 1 ,
93- callbacks = [tensorboard_callback , lr_schedule_callback , ckpt_callback ])
90+ callbacks = [tensorboard_callback , lr_schedule_callback ])
9491model .evaluate (test_loader )
9592
9693model .save ('model.h5' )
0 commit comments