|
"""
This is a simple example of transfer learning using VGG.
Fine tune a CNN from a classifier to regressor.
Generate some fake data for describing cat and tiger length.

Fake length setting:
Cat - Normal distribution (40, 8)
Tiger - Normal distribution (100, 30)

The VGG model and parameters are adopted from:
https://github.com/machrisaa/tensorflow-vgg

Learn more, visit my tutorial site: [莫烦Python](https://morvanzhou.github.io)
"""
| 15 | + |
| 16 | +from urllib.request import urlretrieve |
| 17 | +import os |
| 18 | +import numpy as np |
| 19 | +import tensorflow as tf |
| 20 | +import skimage.io |
| 21 | +import skimage.transform |
| 22 | +import matplotlib.pyplot as plt |
| 23 | + |
| 24 | + |
def download():
    """Download tiger and kittycat images from the imagenet URL lists.

    For each category, reads './for_transfer_learning/imagenet_<category>.txt'
    (one URL per line) and saves every image into
    './for_transfer_learning/data/<category>/'. Failed downloads are reported
    and skipped so one dead link does not abort the whole run.
    """
    categories = ['tiger', 'kittycat']
    for category in categories:
        os.makedirs('./for_transfer_learning/data/%s' % category, exist_ok=True)
        with open('./for_transfer_learning/imagenet_%s.txt' % category, 'r') as url_file:
            urls = url_file.readlines()
        n_urls = len(urls)
        for i, url in enumerate(urls):
            url = url.strip()
            try:
                # fix: original saved to './data/%s', a directory that was
                # never created (makedirs above builds ./for_transfer_learning/data/%s)
                urlretrieve(url, './for_transfer_learning/data/%s/%s' % (category, url.split('/')[-1]))
                print('%s %i/%i' % (category, i, n_urls))
            except Exception:
                # narrowed from a bare `except:` so Ctrl-C still aborts;
                # any download/IO error just means "skip this image"
                print('%s %i/%i' % (category, i, n_urls), 'no image')
| 38 | + |
| 39 | + |
def load_img(path):
    """Read an image, centre-crop it to a square and resize to 224x224.

    Returns an array of shape [1, 224, 224, 3] with values scaled to [0, 1],
    ready to be fed into the network as a one-image batch.
    """
    raw = skimage.io.imread(path) / 255.0
    height, width = raw.shape[0], raw.shape[1]
    # crop the largest centred square out of the image
    side = min(height, width)
    top = int((height - side) / 2)
    left = int((width - side) / 2)
    square = raw[top: top + side, left: left + side]
    # resize to the VGG input size and prepend a batch axis
    return skimage.transform.resize(square, (224, 224))[None, :, :, :]
| 52 | + |
| 53 | + |
def load_data():
    """Load up to 400 images per category and generate fake length labels.

    Returns:
        (tiger_imgs, cat_imgs, tiger_lengths, cat_lengths) where the image
        entries are lists of [1, 224, 224, 3] arrays and the lengths are
        [n, 1] arrays sampled from N(100, 30) for tigers and N(40, 8) for
        cats, floored at 20 and 10 respectively.
    """
    imgs = {'tiger': [], 'kittycat': []}
    for k in imgs.keys():
        # fix: read from the directory tree that download() actually creates
        # ('./for_transfer_learning/data/...'), not the stale './data/...';
        # also avoid shadowing the builtins `dir` and `file`
        folder = './for_transfer_learning/data/' + k
        for fname in os.listdir(folder):
            if not fname.lower().endswith('.jpg'):
                continue
            try:
                resized_img = load_img(os.path.join(folder, fname))
            except OSError:
                continue  # unreadable/corrupt image: skip it
            imgs[k].append(resized_img)  # [1, height, width, depth] * n
            if len(imgs[k]) == 400:  # only use 400 imgs to reduce memory load
                break
    # fake length data for tiger and cat
    tigers_y = np.maximum(20, np.random.randn(len(imgs['tiger']), 1) * 30 + 100)
    cat_y = np.maximum(10, np.random.randn(len(imgs['kittycat']), 1) * 8 + 40)
    return imgs['tiger'], imgs['kittycat'], tigers_y, cat_y
| 72 | + |
| 73 | + |
class Vgg16:
    """VGG16 backbone with its fully-connected head replaced by a small
    regression head (one 256-unit ReLU layer plus a single linear output).

    The convolutional weights are loaded from a pre-trained .npy dump and
    are built into the graph as constants, so only the new fc6/out layers
    are trainable during fine-tuning.
    """

    # per-channel BGR means used by the original VGG preprocessing
    vgg_mean = [103.939, 116.779, 123.68]

    def __init__(self, vgg16_npy_path=None, restore_from=None):
        # pre-trained parameters: dict mapping layer name -> [weights, biases]
        try:
            self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
        except FileNotFoundError:
            # NOTE(review): on failure self.data_dict stays undefined, so the
            # conv_layer calls below will raise AttributeError — this print
            # only tells the user where to fetch the weight file.
            print('Please download VGG16 parameters at here https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM')

        # inputs: RGB images scaled to [0, 1], and scalar length targets
        self.tfx = tf.placeholder(tf.float32, [None, 224, 224, 3])
        self.tfy = tf.placeholder(tf.float32, [None, 1])

        # Convert RGB to BGR and subtract the per-channel VGG mean
        # (the pre-trained weights expect mean-centred BGR in [0, 255])
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=self.tfx * 255.0)
        bgr = tf.concat(axis=3, values=[
            blue - self.vgg_mean[0],
            green - self.vgg_mean[1],
            red - self.vgg_mean[2],
        ])

        # pre-trained VGG layers are fixed in fine-tune
        # (conv weights come from data_dict as constants, see conv_layer)
        conv1_1 = self.conv_layer(bgr, "conv1_1")
        conv1_2 = self.conv_layer(conv1_1, "conv1_2")
        pool1 = self.max_pool(conv1_2, 'pool1')

        conv2_1 = self.conv_layer(pool1, "conv2_1")
        conv2_2 = self.conv_layer(conv2_1, "conv2_2")
        pool2 = self.max_pool(conv2_2, 'pool2')

        conv3_1 = self.conv_layer(pool2, "conv3_1")
        conv3_2 = self.conv_layer(conv3_1, "conv3_2")
        conv3_3 = self.conv_layer(conv3_2, "conv3_3")
        pool3 = self.max_pool(conv3_3, 'pool3')

        conv4_1 = self.conv_layer(pool3, "conv4_1")
        conv4_2 = self.conv_layer(conv4_1, "conv4_2")
        conv4_3 = self.conv_layer(conv4_2, "conv4_3")
        pool4 = self.max_pool(conv4_3, 'pool4')

        conv5_1 = self.conv_layer(pool4, "conv5_1")
        conv5_2 = self.conv_layer(conv5_1, "conv5_2")
        conv5_3 = self.conv_layer(conv5_2, "conv5_3")
        pool5 = self.max_pool(conv5_3, 'pool5')

        # detach original VGG fc layers and
        # reconstruct your own fc layers serve for your own purpose
        self.flatten = tf.reshape(pool5, [-1, 7*7*512])  # pool5 is 7x7x512 for 224x224 input
        self.fc6 = tf.layers.dense(self.flatten, 256, tf.nn.relu, name='fc6')
        self.out = tf.layers.dense(self.fc6, 1, name='out')  # predicted length, no activation

        self.sess = tf.Session()
        if restore_from:
            # inference graph: restore the previously trained fc6/out weights
            saver = tf.train.Saver()
            saver.restore(self.sess, restore_from)
        else:  # training graph
            self.loss = tf.losses.mean_squared_error(labels=self.tfy, predictions=self.out)
            self.train_op = tf.train.RMSPropOptimizer(0.001).minimize(self.loss)
            self.sess.run(tf.global_variables_initializer())

    def max_pool(self, bottom, name):
        """2x2 max pooling with stride 2 (halves spatial dimensions)."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        """3x3 conv + ReLU using the frozen pre-trained kernel/bias for `name`."""
        with tf.variable_scope(name):   # filter is constant
            # data_dict[name][0] = kernel, data_dict[name][1] = bias
            conv = tf.nn.conv2d(bottom, self.data_dict[name][0], [1, 1, 1, 1], padding='SAME')
            lout = tf.nn.relu(tf.nn.bias_add(conv, self.data_dict[name][1]))
            return lout

    def train(self, x, y):
        """Run one optimisation step on batch (x, y); returns the MSE loss."""
        loss, _ = self.sess.run([self.loss, self.train_op], {self.tfx: x, self.tfy: y})
        return loss

    def predict(self, paths):
        """Predict and plot the length for each image path side by side.

        Assumes exactly two paths, matching the 1x2 subplot grid below.
        """
        fig, axs = plt.subplots(1, 2)
        for i, path in enumerate(paths):
            x = load_img(path)
            length = self.sess.run(self.out, {self.tfx: x})
            axs[i].imshow(x[0])
            axs[i].set_title('Len: %.1f cm' % length)
            axs[i].set_xticks(()); axs[i].set_yticks(())
        plt.show()

    def save(self, path='./for_transfer_learning/model/transfer_learn'):
        """Save the session's variables (the learned fc layers) to `path`."""
        saver = tf.train.Saver()
        saver.save(self.sess, path, write_meta_graph=False)
| 160 | + |
| 161 | + |
def train():
    """Fine-tune the regression head on the fake tiger/cat length data."""
    tigers_x, cats_x, tigers_y, cats_y = load_data()

    # show the two fake length distributions before training starts
    plt.hist(tigers_y, bins=20, label='Tigers')
    plt.hist(cats_y, bins=10, label='Cats')
    plt.legend()
    plt.xlabel('length')
    plt.show()

    # stack every [1, 224, 224, 3] image into one big batch array
    xs = np.concatenate(tigers_x + cats_x, axis=0)
    ys = np.concatenate((tigers_y, cats_y), axis=0)

    vgg = Vgg16(vgg16_npy_path='./for_transfer_learning/vgg16.npy')
    print('Net built')
    for step in range(100):
        batch = np.random.randint(0, len(xs), 6)  # 6 random samples per step
        loss_value = vgg.train(xs[batch], ys[batch])
        print(step, 'train loss: ', loss_value)

    vgg.save('./for_transfer_learning/model/transfer_learn')  # save learned fc layers
| 183 | + |
| 184 | + |
def eval():
    """Restore the fine-tuned model and predict lengths for two sample images."""
    model = Vgg16(
        vgg16_npy_path='./for_transfer_learning/vgg16.npy',
        restore_from='./for_transfer_learning/model/transfer_learn')
    sample_paths = [
        './for_transfer_learning/data/kittycat/000129037.jpg',
        './for_transfer_learning/data/tiger/391412.jpg',
    ]
    model.predict(sample_paths)
| 190 | + |
| 191 | + |
if __name__ == '__main__':
    # download()   # step 1: fetch the raw images
    # train()      # step 2: fine-tune and save the regression head
    eval()         # step 3: restore the model and visualise predictions