Skip to content

Commit 1de102e

Browse files
committed
Add tensorboard support.
1 parent 93304f2 commit 1de102e

2 files changed

Lines changed: 45 additions & 19 deletions

File tree

‎.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
venv-tf2/*
22
__pycache__/*
33
*.h5
4+
logs/*

‎example.py‎

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tensorflow.keras import datasets, layers, optimizers, models, regularizers
66
import argparse
77
import numpy as np
8+
import datetime
89

910

1011

@@ -42,84 +43,98 @@ def VGG16(input_shape):
4243

4344
model = models.Sequential()
4445

46+
flag_BN = True
47+
4548
model.add(layers.Conv2D(64, (3, 3), padding='same',
4649
input_shape=input_shape, kernel_regularizer=regularizers.l2(weight_decay)))
4750
model.add(layers.Activation('relu'))
48-
model.add(layers.BatchNormalization())
51+
if flag_BN:
52+
model.add(layers.BatchNormalization())
4953
model.add(layers.Dropout(0.3))
5054

5155
model.add(layers.Conv2D(64, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
5256
model.add(layers.Activation('relu'))
53-
model.add(layers.BatchNormalization())
54-
57+
if flag_BN:
58+
model.add(layers.BatchNormalization())
5559
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
5660

5761
model.add(layers.Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
5862
model.add(layers.Activation('relu'))
59-
model.add(layers.BatchNormalization())
63+
if flag_BN:
64+
model.add(layers.BatchNormalization())
6065
model.add(layers.Dropout(0.4))
6166

6267
model.add(layers.Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
6368
model.add(layers.Activation('relu'))
64-
model.add(layers.BatchNormalization())
69+
if flag_BN:
70+
model.add(layers.BatchNormalization())
6571

6672
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
6773

6874
model.add(layers.Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
6975
model.add(layers.Activation('relu'))
70-
model.add(layers.BatchNormalization())
76+
if flag_BN:
77+
model.add(layers.BatchNormalization())
7178
model.add(layers.Dropout(0.4))
7279

7380
model.add(layers.Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
7481
model.add(layers.Activation('relu'))
75-
model.add(layers.BatchNormalization())
82+
if flag_BN:
83+
model.add(layers.BatchNormalization())
7684
model.add(layers.Dropout(0.4))
7785

7886
model.add(layers.Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
7987
model.add(layers.Activation('relu'))
80-
model.add(layers.BatchNormalization())
81-
88+
if flag_BN:
89+
model.add(layers.BatchNormalization())
8290
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
8391

8492

8593
model.add(layers.Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
8694
model.add(layers.Activation('relu'))
87-
model.add(layers.BatchNormalization())
95+
if flag_BN:
96+
model.add(layers.BatchNormalization())
8897
model.add(layers.Dropout(0.4))
8998

9099
model.add(layers.Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
91100
model.add(layers.Activation('relu'))
92-
model.add(layers.BatchNormalization())
101+
if flag_BN:
102+
model.add(layers.BatchNormalization())
93103
model.add(layers.Dropout(0.4))
94104

95105
model.add(layers.Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
96106
model.add(layers.Activation('relu'))
97-
model.add(layers.BatchNormalization())
107+
if flag_BN:
108+
model.add(layers.BatchNormalization())
98109

99110
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
100111

101112

102113
model.add(layers.Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
103114
model.add(layers.Activation('relu'))
104-
model.add(layers.BatchNormalization())
115+
if flag_BN:
116+
model.add(layers.BatchNormalization())
105117
model.add(layers.Dropout(0.4))
106118

107119
model.add(layers.Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
108120
model.add(layers.Activation('relu'))
109-
model.add(layers.BatchNormalization())
121+
if flag_BN:
122+
model.add(layers.BatchNormalization())
110123
model.add(layers.Dropout(0.4))
111124

112125
model.add(layers.Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))
113126
model.add(layers.Activation('relu'))
114-
model.add(layers.BatchNormalization())
127+
if flag_BN:
128+
model.add(layers.BatchNormalization())
115129

116130
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
117131
model.add(layers.Dropout(0.5))
118132

119133
model.add(layers.Flatten())
120134
model.add(layers.Dense(512,kernel_regularizer=regularizers.l2(weight_decay)))
121135
model.add(layers.Activation('relu'))
122-
model.add(layers.BatchNormalization())
136+
if flag_BN:
137+
model.add(layers.BatchNormalization())
123138

124139
model.add(layers.Dropout(0.5))
125140
model.add(layers.Dense(num_classes))
@@ -188,19 +203,29 @@ def main():
188203
loss='sparse_categorical_crossentropy',
189204
metrics=['accuracy'])
190205

206+
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
207+
tensorboard_callback = tf.keras.callbacks.TensorBoard(
208+
log_dir=log_dir,
209+
update_freq=args.bs_per_gpu * args.num_gpus * 10,
210+
histogram_freq=1)
211+
212+
191213
model.fit(train_loader,
192214
epochs=args.num_epochs,
193215
validation_data=val_loader,
194-
validation_freq=1)
216+
validation_freq=1,
217+
callbacks=[tensorboard_callback])
195218
model.evaluate(test_loader)
196219

197220
# Save & load weights
198221
# Cannot save model configuration: http://ashokrahulgade.com/coding/keras/Module1.html
199222
# Save weights to disk
200223
model.save('model.h5')
201-
new_model = keras.models.load_model('model.h5')
202-
new_model.evaluate(test_loader)
203224

225+
new_model = keras.models.load_model('model.h5')
226+
# Result will be slightly different if training uses multiple-gpus
227+
# Related to batch normalization layer
228+
new_model.evaluate(test_loader)
204229

205230
if __name__ == '__main__':
206231
main()

0 commit comments

Comments
 (0)