Skip to content

Commit ccced1a

Browse files
committed
transfer learning
1 parent 0f5497a commit ccced1a

8 files changed

Lines changed: 3056 additions & 1 deletion

File tree

‎README.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ All methods mentioned below have their video and text tutorial in Chinese. Visit
3636
* [AutoEncoder](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/404_AutoEncoder.py)
3737
* [DQN Reinforcement Learning](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/405_DQN_reinforcement_learning.py)
3838
* [GAN (Generative Adversarial Nets)](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/406_GAN.py) / [Conditional GAN](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/406_conditional_GAN.py)
39-
39+
* [Transfer Learning](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/407_transfer_learning.py)
4040
* Others (WIP)
4141
* [Dropout](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/501_dropout.py)
4242
* [Batch Normalization](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/502_batch_normalization.py)
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""
2+
This is a simple example of transfer learning using VGG.
3+
Fine-tune a CNN from a classifier into a regressor.
4+
Generate some fake data for describing cat and tiger length.
5+
6+
Fake length setting:
7+
Cat - Normal distribution (40, 8)
8+
Tiger - Normal distribution (100, 30)
9+
10+
The VGG model and parameters are adopted from:
11+
https://github.com/machrisaa/tensorflow-vgg
12+
13+
Learn more, visit my tutorial site: [莫烦Python](https://morvanzhou.github.io)
14+
"""
15+
16+
from urllib.request import urlretrieve
17+
import os
18+
import numpy as np
19+
import tensorflow as tf
20+
import skimage.io
21+
import skimage.transform
22+
import matplotlib.pyplot as plt
23+
24+
25+
def download():  # download tiger and kittycat image
    """Download tiger and kittycat images from the imagenet URL lists.

    Reads ./for_transfer_learning/imagenet_<category>.txt (one URL per line)
    and stores each image under ./for_transfer_learning/data/<category>/.
    Downloads are best-effort: a failed URL is reported and skipped.
    """
    categories = ['tiger', 'kittycat']
    for category in categories:
        os.makedirs('./for_transfer_learning/data/%s' % category, exist_ok=True)
        with open('./for_transfer_learning/imagenet_%s.txt' % category, 'r') as file:
            urls = file.readlines()
            n_urls = len(urls)
            for i, url in enumerate(urls):
                url = url.strip()
                try:
                    # BUG FIX: save into the directory created above.  The
                    # original wrote to './data/%s', which makedirs never
                    # created, so every download failed silently.
                    urlretrieve(url, './for_transfer_learning/data/%s/%s' % (category, url.split('/')[-1]))
                    print('%s %i/%i' % (category, i, n_urls))
                except Exception:
                    # narrowed from a bare `except:` (which also swallowed
                    # KeyboardInterrupt); behavior stays best-effort
                    print('%s %i/%i' % (category, i, n_urls), 'no image')
38+
39+
40+
def load_img(path):
    """Load an image, center-crop it square, and resize to 224x224.

    Returns a float array of shape [1, 224, 224, 3] with values in [0, 1],
    ready to feed as a one-image batch into the VGG input placeholder.
    """
    img = skimage.io.imread(path) / 255.0
    # crop a centered square along the shorter edge
    height, width = img.shape[0], img.shape[1]
    edge = min(height, width)
    top = (height - edge) // 2
    left = (width - edge) // 2
    square = img[top:top + edge, left:left + edge]
    # resize to the VGG input size and prepend a batch axis -> [1, 224, 224, 3]
    return skimage.transform.resize(square, (224, 224))[np.newaxis, :, :, :]
52+
53+
54+
def load_data():
    """Load up to 400 tiger and 400 kittycat images plus fake length labels.

    Returns:
        (tiger_imgs, cat_imgs, tigers_y, cat_y): the image lists hold
        [1, 224, 224, 3] arrays from load_img(); the label arrays are fake
        lengths drawn from N(100, 30) for tigers and N(40, 8) for cats,
        clipped from below at 20 and 10 respectively.
    """
    imgs = {'tiger': [], 'kittycat': []}
    for k in imgs.keys():
        # BUG FIX: read from the directory download() creates and eval()
        # references ('./for_transfer_learning/data'); the original read
        # './data/', which no other part of the file writes to.
        # Also renamed `dir` -> `folder` to stop shadowing the builtin.
        folder = './for_transfer_learning/data/' + k
        for file_name in os.listdir(folder):
            if not file_name.lower().endswith('.jpg'):
                continue
            try:
                resized_img = load_img(os.path.join(folder, file_name))
            except OSError:
                continue  # skip unreadable/corrupt image files
            imgs[k].append(resized_img)  # [1, height, width, depth] * n
            if len(imgs[k]) == 400:     # only use 400 imgs to reduce my memory load
                break
    # fake length data for tiger and cat (clipped normal distributions)
    tigers_y = np.maximum(20, np.random.randn(len(imgs['tiger']), 1) * 30 + 100)
    cat_y = np.maximum(10, np.random.randn(len(imgs['kittycat']), 1) * 8 + 40)
    return imgs['tiger'], imgs['kittycat'], tigers_y, cat_y
72+
73+
74+
class Vgg16:
    """VGG16 feature extractor with a small trainable regression head.

    The convolutional layers are built from pre-trained vgg16.npy parameters
    and are fixed (their weights enter the graph as constants, not
    variables); only the new dense layers (fc6, out) are trainable, turning
    the original classifier into a single-value length regressor.
    """
    # per-channel BGR means used by the original VGG preprocessing
    vgg_mean = [103.939, 116.779, 123.68]

    def __init__(self, vgg16_npy_path=None, restore_from=None):
        """Build the graph.

        Args:
            vgg16_npy_path: path to the pre-trained vgg16.npy parameter file.
            restore_from: optional checkpoint path; when given, the trained
                fc-layer variables are restored instead of initialized, and
                no training ops are created.

        Raises:
            FileNotFoundError: if vgg16_npy_path does not exist.
        """
        # pre-trained parameters
        try:
            self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
        except FileNotFoundError:
            print('Please download VGG16 parameters at here https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM')
            # BUG FIX: re-raise instead of continuing with self.data_dict
            # unset, which previously surfaced later as a confusing
            # AttributeError inside conv_layer().
            raise

        self.tfx = tf.placeholder(tf.float32, [None, 224, 224, 3])  # RGB in [0, 1]
        self.tfy = tf.placeholder(tf.float32, [None, 1])            # target length

        # Convert RGB to BGR and subtract the VGG training means
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=self.tfx * 255.0)
        bgr = tf.concat(axis=3, values=[
            blue - self.vgg_mean[0],
            green - self.vgg_mean[1],
            red - self.vgg_mean[2],
        ])

        # pre-trained VGG layers are fixed in fine-tune
        conv1_1 = self.conv_layer(bgr, "conv1_1")
        conv1_2 = self.conv_layer(conv1_1, "conv1_2")
        pool1 = self.max_pool(conv1_2, 'pool1')

        conv2_1 = self.conv_layer(pool1, "conv2_1")
        conv2_2 = self.conv_layer(conv2_1, "conv2_2")
        pool2 = self.max_pool(conv2_2, 'pool2')

        conv3_1 = self.conv_layer(pool2, "conv3_1")
        conv3_2 = self.conv_layer(conv3_1, "conv3_2")
        conv3_3 = self.conv_layer(conv3_2, "conv3_3")
        pool3 = self.max_pool(conv3_3, 'pool3')

        conv4_1 = self.conv_layer(pool3, "conv4_1")
        conv4_2 = self.conv_layer(conv4_1, "conv4_2")
        conv4_3 = self.conv_layer(conv4_2, "conv4_3")
        pool4 = self.max_pool(conv4_3, 'pool4')

        conv5_1 = self.conv_layer(pool4, "conv5_1")
        conv5_2 = self.conv_layer(conv5_1, "conv5_2")
        conv5_3 = self.conv_layer(conv5_2, "conv5_3")
        pool5 = self.max_pool(conv5_3, 'pool5')

        # detach original VGG fc layers and
        # reconstruct your own fc layers serve for your own purpose
        self.flatten = tf.reshape(pool5, [-1, 7*7*512])  # 224 / 2^5 = 7 after five poolings
        self.fc6 = tf.layers.dense(self.flatten, 256, tf.nn.relu, name='fc6')
        self.out = tf.layers.dense(self.fc6, 1, name='out')

        self.sess = tf.Session()
        if restore_from:
            saver = tf.train.Saver()
            saver.restore(self.sess, restore_from)
        else:  # training graph
            self.loss = tf.losses.mean_squared_error(labels=self.tfy, predictions=self.out)
            self.train_op = tf.train.RMSPropOptimizer(0.001).minimize(self.loss)
            self.sess.run(tf.global_variables_initializer())

    def max_pool(self, bottom, name):
        """2x2 max pooling with stride 2 (halves the spatial resolution)."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        """Convolution + bias + ReLU using the fixed pre-trained weights for `name`."""
        with tf.variable_scope(name):   # filter is constant (not trainable)
            conv = tf.nn.conv2d(bottom, self.data_dict[name][0], [1, 1, 1, 1], padding='SAME')
            lout = tf.nn.relu(tf.nn.bias_add(conv, self.data_dict[name][1]))
            return lout

    def train(self, x, y):
        """Run one optimization step on batch (x, y); return the MSE loss."""
        loss, _ = self.sess.run([self.loss, self.train_op], {self.tfx: x, self.tfy: y})
        return loss

    def predict(self, paths):
        """Predict and plot the length for two image paths side by side."""
        fig, axs = plt.subplots(1, 2)
        for i, path in enumerate(paths):
            x = load_img(path)
            length = self.sess.run(self.out, {self.tfx: x})
            axs[i].imshow(x[0])
            # BUG FIX: extract the scalar from the (1, 1) output array;
            # %-formatting a size-1 ndarray is deprecated/removed in newer numpy.
            axs[i].set_title('Len: %.1f cm' % length[0, 0])
            axs[i].set_xticks(()); axs[i].set_yticks(())
        plt.show()

    def save(self, path='./for_transfer_learning/model/transfer_learn'):
        """Save the session's variables (the trained fc layers) to `path`."""
        saver = tf.train.Saver()
        saver.save(self.sess, path, write_meta_graph=False)
160+
161+
162+
def train():
    """Train the transfer-learning regressor on the downloaded images."""
    tigers_x, cats_x, tigers_y, cats_y = load_data()

    # plot fake length distribution
    plt.hist(tigers_y, bins=20, label='Tigers')
    plt.hist(cats_y, bins=10, label='Cats')
    plt.legend()
    plt.xlabel('length')
    plt.show()

    # stack the per-image [1, 224, 224, 3] arrays into one batch dimension
    xs = np.concatenate(tigers_x + cats_x, axis=0)
    ys = np.concatenate((tigers_y, cats_y), axis=0)

    vgg = Vgg16(vgg16_npy_path='./for_transfer_learning/vgg16.npy')
    print('Net built')
    for step in range(100):
        # random mini-batch of 6 (sampled with replacement)
        batch_idx = np.random.randint(0, len(xs), 6)
        loss_value = vgg.train(xs[batch_idx], ys[batch_idx])
        print(step, 'train loss: ', loss_value)

    vgg.save('./for_transfer_learning/model/transfer_learn')  # save learned fc layers
183+
184+
185+
def eval():
    """Restore the trained model and visualize predictions on two images."""
    # NOTE(review): `eval` shadows the builtin; the name is kept so the
    # __main__ guard below keeps working.
    model = Vgg16(
        vgg16_npy_path='./for_transfer_learning/vgg16.npy',
        restore_from='./for_transfer_learning/model/transfer_learn',
    )
    model.predict([
        './for_transfer_learning/data/kittycat/000129037.jpg',
        './for_transfer_learning/data/tiger/391412.jpg',
    ])
190+
191+
192+
if __name__ == '__main__':
    # Tutorial workflow: run download() once to fetch the images, then
    # train() to fine-tune and save the model, then eval() to visualize
    # predictions.  Earlier stages stay commented out once completed.
    # download()
    # train()
    eval()
25.9 KB
Loading
37.1 KB
Loading

0 commit comments

Comments
 (0)