|
| 1 | +""" |
To learn more, visit my Python tutorial page: https://morvanzhou.github.io/
| 3 | +My Youtube Channel: https://www.youtube.com/user/MorvanZhou |
| 4 | +
|
| 5 | +Dependencies: |
| 6 | +tensorflow: 1.4.0 |
| 7 | +""" |
| 8 | + |
| 9 | +import tensorflow as tf |
| 10 | +import multiprocessing as mp |
| 11 | +import numpy as np |
| 12 | +import os, shutil |
| 13 | + |
| 14 | + |
# Toggle: True runs distributed training; False restores the latest
# checkpoint and plots the fit.
TRAINING = True

# Synthetic regression data as (100, 1) column vectors: y = x^2 + noise.
x = np.linspace(-1, 1, 100).reshape(-1, 1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = x ** 2 + noise
| 21 | + |
| 22 | + |
def work(job_name, task_index, step, lock):
    """Entry point for one process of the local TF distributed cluster.

    A 'ps' task hosts the shared variables and blocks forever; a 'worker'
    task builds the regression graph, opens a MonitoredTrainingSession and
    trains until the StopAtStepHook fires.

    Args:
        job_name: 'ps' or 'worker' -- which role this process plays.
        task_index: this task's index into the address list for its job.
        step: multiprocessing.Value('i') -- session-run counter shared by
            all worker processes (distinct from the TF global_step).
        lock: multiprocessing.Lock guarding updates to `step`.
    """
    # set work's ip:port, parameter server and worker are the same steps
    # (every process must declare the identical ClusterSpec)
    cluster = tf.train.ClusterSpec({
        "ps": ['localhost:2221', ],
        "worker": ['localhost:2222', 'localhost:2223', 'localhost:2224',]
    })
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

    if job_name == 'ps':
        # join parameter server
        print('Start Parameter Server: ', task_index)
        server.join()  # never returns; the ps serves variables until killed
    else:
        print('Start Worker: ', task_index, 'pid: ', mp.current_process().pid)
        # worker job: replica_device_setter pins variables to the ps and
        # compute ops to this worker's device
        with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % task_index,
            cluster=cluster)):
            # build network: one-hidden-layer MLP fitting y = x^2
            tf_x = tf.placeholder(tf.float32, x.shape)
            tf_y = tf.placeholder(tf.float32, y.shape)
            l1 = tf.layers.dense(tf_x, 10, tf.nn.relu)
            output = tf.layers.dense(l1, 1)
            loss = tf.losses.mean_squared_error(tf_y, output)
            global_step = tf.train.get_or_create_global_step()
            train_op = tf.train.GradientDescentOptimizer(
                learning_rate=0.001).minimize(loss, global_step=global_step)

        # set training steps: stop once the shared global_step hits 100000
        hooks = [tf.train.StopAtStepHook(last_step=100000)]

        # get session; the chief (task 0) handles variable init and
        # checkpointing into ./tmp for all workers
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(task_index == 0),
                                               checkpoint_dir='./tmp',
                                               hooks=hooks) as mon_sess:
            print("Start Worker Session: ", task_index)
            while not mon_sess.should_stop():
                # train
                _, loss_ = mon_sess.run([train_op, loss], {tf_x: x, tf_y: y})
                with lock:
                    step.value += 1
                    if step.value % 500 == 0:
                        print("Task: ", task_index, "| Step: ", step.value, "| Loss: ", loss_)
        print('Worker Done: ', task_index)
| 68 | + |
| 69 | + |
def parallel_train():
    """Launch the local cluster (1 parameter server + 3 workers) and train.

    Removes any stale checkpoint directory first so the chief worker starts
    from scratch, joins the worker processes until StopAtStepHook ends
    training, then terminates the parameter server.
    """
    # start from a clean checkpoint dir so we don't restore stale variables
    if os.path.exists('./tmp'):
        shutil.rmtree('./tmp')
    # local cluster: 1 parameter server and 3 workers, matching the
    # ClusterSpec addresses declared in work()
    jobs = [('ps', 0), ('worker', 0), ('worker', 1), ('worker', 2)]
    step = mp.Value('i', 0)  # run counter shared by all workers
    lock = mp.Lock()
    procs = [(job, mp.Process(target=work, args=(job, idx, step, lock)))
             for job, idx in jobs]
    for _, p in procs:
        p.start()
    # workers exit on their own once the StopAtStepHook fires
    for job, p in procs:
        if job == 'worker':
            p.join()
    # the ps process blocks in server.join() forever; terminate it
    # explicitly, otherwise joining it would hang this function for good
    for job, p in procs:
        if job == 'ps':
            p.terminate()
            p.join()
| 80 | + |
| 81 | + |
def eval():
    """Restore the latest checkpoint from ./tmp and plot the fitted curve.

    NOTE(review): this function shadows the builtin `eval`; renaming it
    would require updating the caller under __main__.
    """
    # rebuild the same architecture used in work() so the checkpointed
    # weights can be restored by variable name
    tf_x = tf.placeholder(tf.float32, [None, 1])
    l1 = tf.layers.dense(tf_x, 10, tf.nn.relu)
    output = tf.layers.dense(l1, 1)
    saver = tf.train.Saver()
    # context manager closes the session even if restore/run raises
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('./tmp'))
        result = sess.run(output, {tf_x: x})
    # plot (imported locally so training runs don't require matplotlib)
    import matplotlib.pyplot as plt
    plt.scatter(x.ravel(), y.ravel(), c='b')
    plt.plot(x.ravel(), result.ravel(), c='r')
    plt.show()
| 95 | + |
| 96 | + |
| 97 | +if __name__ == "__main__": |
| 98 | + if TRAINING: |
| 99 | + parallel_train() |
| 100 | + else: |
| 101 | + eval() |
0 commit comments