# Interval (in epochs) between learning-rate drops used by the step-decay
# schedule() further down: lr = INIT_LR * drop ** floor((1 + epoch) / epochs_drop).
epochs_drop = 5.0
class LRTensorBoard(TensorBoard):
    """TensorBoard callback that additionally logs the optimizer's learning
    rate under the key 'lr' at the end of every epoch.

    Only the three constructor arguments used by this script are exposed;
    add more and forward them if needed.
    """

    def __init__(self, log_dir, update_freq, histogram_freq):
        super(LRTensorBoard, self).__init__(log_dir=log_dir,
                                            update_freq=update_freq,
                                            histogram_freq=histogram_freq)

    def on_epoch_end(self, epoch, logs=None):
        # Keras may invoke callbacks with logs=None; guard before update()
        # (the original crashed with AttributeError in that case).
        logs = {} if logs is None else logs
        # NOTE(review): optimizer.lr is a tf.Variable, not a plain float —
        # confirm downstream consumers accept it, or wrap in float(...).
        logs.update({'lr': self.model.optimizer.lr})
        super(LRTensorBoard, self).on_epoch_end(epoch, logs)
def preprocess(x, y):
    """Dataset map function: standardize each image to zero mean and unit
    variance (per-image); the label passes through unchanged.

    Args:
        x: image tensor (any shape accepted by tf.image.per_image_standardization).
        y: label, returned untouched.

    Returns:
        (standardized_image, label) tuple.
    """
    x = tf.image.per_image_standardization(x)
    return x, y
@@ -55,6 +44,7 @@ def schedule(epoch):
5544 drop = 0.5
5645
5746 lrate = initial_lrate * math .pow (drop , math .floor ((1 + epoch )/ epochs_drop ))
47+ tf .summary .scalar ('learning rate' , data = lrate , step = epoch )
5848 return lrate
5949
6050
@@ -77,28 +67,29 @@ def schedule(epoch):
7767
7868
# Build and compile the ResNet-32 classifier. With a single GPU we compile
# directly; with several, construction and compilation happen inside a
# MirroredStrategy scope so variables are replicated across devices.
if num_gpus == 1:
    model = resnet.resnet32(classes=num_classes)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=INIT_LR),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = resnet.resnet32(classes=num_classes)
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=INIT_LR),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])
# Per-run log directory, timestamped so successive runs don't collide.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Default summary writer: schedule() writes the learning rate via
# tf.summary.scalar, which lands in <log_dir>/metrics through this writer.
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()

tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    update_freq='batch',
    histogram_freq=1)

# Step-decay learning rate driven by the schedule() function above.
lr_schedule_callback = keras.callbacks.LearningRateScheduler(schedule)
10394model .fit (train_loader ,
10495 epochs = num_epochs ,
0 commit comments