Commit 0f5497a

committed: update
1 parent 946abac

2 files changed: 102 additions & 0 deletions
File tree

README.md
tutorial-contents/504_distributed_training.py

README.md (1 addition & 0 deletions)
@@ -41,6 +41,7 @@ All methods mentioned below have their video and text tutorial in Chinese. Visit
 * [Dropout](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/501_dropout.py)
 * [Batch Normalization](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/502_batch_normalization.py)
 * [Visualize Gradient Descent](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/503_visualize_gradient_descent.py)
+* [Distributed training](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/504_distributed_training.py)

 ### [Regression](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/301_simple_regression.py)

tutorial-contents/504_distributed_training.py (101 additions & 0 deletions)

@@ -0,0 +1,101 @@
"""
Know more, visit my Python tutorial page: https://morvanzhou.github.io/
My Youtube Channel: https://www.youtube.com/user/MorvanZhou

Dependencies:
tensorflow: 1.4.0
"""
import tensorflow as tf
import multiprocessing as mp
import numpy as np
import os, shutil


TRAINING = True

# training data
x = np.linspace(-1, 1, 100)[:, np.newaxis]
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise


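# work() below is executed in every process of the local cluster: one parameter
# server hosts the shared variables, while each of the three workers builds its
# own replica of the graph and trains asynchronously (between-graph replication).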
def work(job_name, task_index, step, lock):
    # addresses of every cluster member; parameter server and workers build the same ClusterSpec
    cluster = tf.train.ClusterSpec({
        "ps": ['localhost:2221', ],
        "worker": ['localhost:2222', 'localhost:2223', 'localhost:2224', ]
    })
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

    if job_name == 'ps':
        # the parameter server only hosts variables; it blocks here and serves the workers
        print('Start Parameter Server: ', task_index)
        server.join()
    else:
        print('Start Worker: ', task_index, 'pid: ', mp.current_process().pid)
        # worker job: replica_device_setter places variables on the ps and ops on this worker
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            # build network
            tf_x = tf.placeholder(tf.float32, x.shape)
            tf_y = tf.placeholder(tf.float32, y.shape)
            l1 = tf.layers.dense(tf_x, 10, tf.nn.relu)
            output = tf.layers.dense(l1, 1)
            loss = tf.losses.mean_squared_error(tf_y, output)
            global_step = tf.train.get_or_create_global_step()
            train_op = tf.train.GradientDescentOptimizer(
                learning_rate=0.001).minimize(loss, global_step=global_step)

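        # set training steps: StopAtStepHook ends every worker's session once the
        # shared global_step (incremented by each minimize() call) reaches last_step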
        hooks = [tf.train.StopAtStepHook(last_step=100000)]

        # get a managed session: MonitoredTrainingSession initializes variables,
        # writes checkpoints to checkpoint_dir on the chief (task 0) and runs the hooks
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(task_index == 0),
                                               checkpoint_dir='./tmp',
                                               hooks=hooks) as mon_sess:
            print("Start Worker Session: ", task_index)
            while not mon_sess.should_stop():
                # train
                _, loss_ = mon_sess.run([train_op, loss], {tf_x: x, tf_y: y})
                with lock:
                    step.value += 1
                    if step.value % 500 == 0:
                        print("Task: ", task_index, "| Step: ", step.value, "| Loss: ", loss_)
        print('Worker Done: ', task_index)


def parallel_train():
    if os.path.exists('./tmp'):
        shutil.rmtree('./tmp')      # start from a clean checkpoint directory
    # use multiprocessing to create a local cluster with one parameter server and three workers
    jobs = [('ps', 0), ('worker', 0), ('worker', 1), ('worker', 2)]
    step = mp.Value('i', 0)         # step counter shared across all processes
    lock = mp.Lock()
    ps = [mp.Process(target=work, args=(j, i, step, lock), ) for j, i in jobs]
    [p.start() for p in ps]
    [p.join() for p in ps]


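# Evaluation rebuilds the same two-layer network in a fresh graph and restores
# the latest checkpoint that the chief worker saved to ./tmp.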
def eval():
    tf_x = tf.placeholder(tf.float32, [None, 1])
    l1 = tf.layers.dense(tf_x, 10, tf.nn.relu)
    output = tf.layers.dense(l1, 1)
    saver = tf.train.Saver()
    sess = tf.Session()
    saver.restore(sess, tf.train.latest_checkpoint('./tmp'))
    result = sess.run(output, {tf_x: x})
    # plot
    import matplotlib.pyplot as plt
    plt.scatter(x.ravel(), y, c='b')
    plt.plot(x.ravel(), result.ravel(), c='r')
    plt.show()


if __name__ == "__main__":
    if TRAINING:
        parallel_train()
    else:
        eval()
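A note on usage: as committed, running the script trains the local cluster and writes checkpoints to ./tmp; rerunning with TRAINING = False restores the latest checkpoint and plots the fit. Below is a minimal sketch, not part of this commit, of how one cluster member could instead be launched per terminal; the file name launch_node.py, the importlib-based import, and the per-process step/lock stand-ins are all assumptions.

# launch_node.py -- hedged sketch (not in this commit): start one cluster member,
# reusing work() from 504_distributed_training.py in the same directory.
import argparse
import importlib
import multiprocessing as mp

# the tutorial file name starts with a digit, so import it by string
tutorial = importlib.import_module('504_distributed_training')

parser = argparse.ArgumentParser()
parser.add_argument('--job_name', choices=['ps', 'worker'], required=True)
parser.add_argument('--task_index', type=int, default=0)
args = parser.parse_args()

# per-process stand-ins for the shared counter/lock that parallel_train() passes in;
# the "| Step:" printout therefore counts only this process's own updates
tutorial.work(args.job_name, args.task_index, mp.Value('i', 0), mp.Lock())

# one terminal per process, for example:
#   python launch_node.py --job_name=ps --task_index=0
#   python launch_node.py --job_name=worker --task_index=0
#   python launch_node.py --job_name=worker --task_index=1
#   python launch_node.py --job_name=worker --task_index=2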
