Skip to content

Commit ed8a761

Browse files
committed
Add Multi-GPU
1 parent 8ec520f commit ed8a761

4 files changed

Lines changed: 222 additions & 1 deletion

File tree

‎.gitignore‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
venv-tf2/*
2+
__pycache__/*

‎README.md‎

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
1-
# tf-2-tutorial
1+
# tf2-tutorial
2+
3+
4+
# Installation
5+
6+
git clone https://github.com/chuanli11/tf2-tutorial.git
7+
cd tf2-tutorial
8+
virtualenv venv-tf2
9+
. venv-tf2/bin/activate
10+
pip install tf-nightly-gpu-2.0-preview
11+
12+
13+
14+
# Run

python example.py --num_gpus 1

# or, with multiple GPUs:
python example.py --num_gpus 2

‎example.py‎

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
import tensorflow as tf
3+
from tensorflow import keras
4+
from tensorflow.keras import datasets, layers, optimizers
5+
import argparse
6+
import numpy as np
7+
8+
9+
10+
from network import VGG16
11+
12+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # or any {'0', '1', '2'}
13+
parser = argparse.ArgumentParser()
14+
15+
16+
17+
parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train',
18+
help="Directory where to write event logs and checkpoint.")
19+
parser.add_argument('--max_steps', type=int, default=1000000,
20+
help="""Number of batches to run.""")
21+
parser.add_argument('--log_device_placement', action='store_true',
22+
help="Whether to log device placement.")
23+
parser.add_argument('--log_frequency', type=int, default=10,
24+
help="How often to log results to the console.")
25+
parser.add_argument('--num_gpus', type=int, default=1,
26+
help="How many GPUs to use.")
27+
parser.add_argument('--bs_per_gpu', type=int, default=256,
28+
help="Batch size on each GPU.")
29+
parser.add_argument('--num_epochs', type=int, default=3,
30+
help="Number of training epochs.")
31+
32+
args = parser.parse_args()
33+
34+
def normalize(X_train, X_test):
    """Normalize inputs to zero mean and unit variance.

    Both sets are first scaled to [0, 1]; the mean and standard deviation
    are computed over the training set only and then applied to both sets,
    so the test set is normalized according to the training-set statistics.

    :param X_train: training images, shape (N, H, W, C)
    :param X_test: test images, shape (M, H, W, C)
    :return: (normalized X_train, normalized X_test)
    """
    scaled_train = X_train / 255.
    scaled_test = X_test / 255.

    mean = np.mean(scaled_train, axis=(0, 1, 2, 3))
    std = np.std(scaled_train, axis=(0, 1, 2, 3))
    print('mean:', mean, 'std:', std)

    # Small epsilon guards against division by zero for constant inputs.
    eps = 1e-7
    return (scaled_train - mean) / (std + eps), (scaled_test - mean) / (std + eps)
48+
49+
def prepare_cifar(x, y):
    """Cast one (image, label) pair to the dtypes the model expects.

    :param x: image tensor, cast to float32
    :param y: label tensor, cast to int32
    :return: (x, y) as (float32, int32) tensors
    """
    return tf.cast(x, tf.float32), tf.cast(y, tf.int32)
54+
55+
56+
def main():
    """Train VGG16 on CIFAR-10, optionally across multiple GPUs.

    Loads and normalizes CIFAR-10, builds tf.data pipelines whose global
    batch size scales with the number of GPUs, trains for args.num_epochs,
    and evaluates on the held-out test set.
    """
    tf.random.set_seed(22)

    print('loading data...')
    (x, y), (x_test, y_test) = datasets.cifar10.load_data()
    x, x_test = normalize(x, x_test)

    # Global batch size scales with GPU count so each device still sees
    # args.bs_per_gpu samples per step.
    batch_size = args.bs_per_gpu * args.num_gpus

    train_loader = tf.data.Dataset.from_tensor_slices((x, y))
    test_loader = tf.data.Dataset.from_tensor_slices((x_test, y_test))

    train_loader = train_loader.map(prepare_cifar).shuffle(50000).batch(batch_size)
    test_loader = test_loader.map(prepare_cifar).shuffle(10000).batch(batch_size)

    def _build_and_compile():
        # Model construction and compilation must happen inside the
        # distribution strategy scope when training on multiple GPUs.
        model = VGG16([32, 32, 3])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1e-3),
            loss=keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy'])
        return model

    if args.num_gpus == 1:
        model = _build_and_compile()
    else:
        mirrored_strategy = tf.distribute.MirroredStrategy()
        with mirrored_strategy.scope():
            model = _build_and_compile()

    model.fit(train_loader, epochs=args.num_epochs)
    # Fix: evaluate on the held-out test set, not the training data —
    # the original evaluated train_loader, which inflates reported accuracy.
    model.evaluate(test_loader)


if __name__ == '__main__':
    main()

‎network.py‎

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
from tensorflow.keras import datasets, layers, optimizers, models
4+
from tensorflow.keras import regularizers
5+
6+
7+
class VGG16(models.Model):
    """VGG16-style convolutional network for 10-class classification.

    Wraps a keras Sequential model: five Conv/ReLU/BatchNorm blocks
    (64, 128, 256, 512, 512 filters) separated by max-pooling, followed
    by a 512-unit dense layer and a softmax over self.num_classes.
    """

    def __init__(self, input_shape):
        """
        :param input_shape: input image shape, e.g. [32, 32, 3]
        """
        super(VGG16, self).__init__()

        # l2 regularization strength; 0.0 means effectively disabled
        # (kept so it can be turned on without restructuring the model).
        weight_decay = 0.000
        self.num_classes = 10

        def _conv_unit(model, filters, dropout=None, **conv_kwargs):
            # One Conv(3x3, same) -> ReLU -> BatchNorm unit, optionally
            # followed by Dropout. Factored out because the original
            # repeated this stanza thirteen times.
            model.add(layers.Conv2D(filters, (3, 3), padding='same',
                                    kernel_regularizer=regularizers.l2(weight_decay),
                                    **conv_kwargs))
            model.add(layers.Activation('relu'))
            model.add(layers.BatchNormalization())
            if dropout is not None:
                model.add(layers.Dropout(dropout))

        model = models.Sequential()

        _conv_unit(model, 64, dropout=0.3, input_shape=input_shape)
        _conv_unit(model, 64)
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))

        _conv_unit(model, 128, dropout=0.4)
        _conv_unit(model, 128)
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))

        _conv_unit(model, 256, dropout=0.4)
        _conv_unit(model, 256, dropout=0.4)
        _conv_unit(model, 256)
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))

        _conv_unit(model, 512, dropout=0.4)
        _conv_unit(model, 512, dropout=0.4)
        _conv_unit(model, 512)
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))

        _conv_unit(model, 512, dropout=0.4)
        _conv_unit(model, 512, dropout=0.4)
        _conv_unit(model, 512)
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Dropout(0.5))

        model.add(layers.Flatten())
        model.add(layers.Dense(512, kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())

        model.add(layers.Dropout(0.5))
        model.add(layers.Dense(self.num_classes))
        model.add(layers.Activation('softmax'))

        self.model = model

    def call(self, x):
        """Forward pass: run x through the wrapped Sequential model."""
        return self.model(x)

0 commit comments

Comments
 (0)