import tensorflow as tf
from tensorflow import keras
import numpy as np
import utils
import tensorflow_addons as tfa
import pickle


class Seq2Seq(keras.Model):
    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, attention_layer_size, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units

        # encoder
        self.enc_embeddings = keras.layers.Embedding(
            input_dim=enc_v_dim, output_dim=emb_dim,    # [enc_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.encoder = keras.layers.LSTM(units=units, return_sequences=True, return_state=True)

        # decoder
        self.dec_embeddings = keras.layers.Embedding(
            input_dim=dec_v_dim, output_dim=emb_dim,    # [dec_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.attention = tfa.seq2seq.LuongAttention(units, memory=None, memory_sequence_length=None)
        self.decoder_cell = tfa.seq2seq.AttentionWrapper(
            cell=keras.layers.LSTMCell(units=units),
            attention_mechanism=self.attention,
            attention_layer_size=attention_layer_size,
            alignment_history=True,                     # for attention visualization
        )
        decoder_dense = keras.layers.Dense(dec_v_dim)

        # train decoder
        self.decoder_train = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.TrainingSampler(),          # sampler for train
            output_layer=decoder_dense
        )
        # predict decoder
        self.decoder_eval = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),   # sampler for predict
            output_layer=decoder_dense
        )

        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(0.05, clipnorm=5.0)
        self.train_sampler = tfa.seq2seq.sampler.TrainingSampler()
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        o = self.enc_embeddings(x)
        init_s = [tf.zeros((x.shape[0], self.units)), tf.zeros((x.shape[0], self.units))]
        o, h, c = self.encoder(o, initial_state=init_s)

        # hand the encoder outputs to the attention mechanism as its memory
        self.attention.setup_memory(o)
        # wrap the final encoder state (h, c) into an AttentionWrapperState for the decoder
        s = self.decoder_cell.get_initial_state(batch_size=x.shape[0], dtype=tf.float32).clone(cell_state=[h, c])
        return s

    def inference(self, x, return_align=False):
        s = self.encode(x)
        done, i, s = self.decoder_eval.initialize(
            self.dec_embeddings.variables[0],
            start_tokens=tf.fill([x.shape[0], ], self.start_token),
            end_token=self.end_token,
            initial_state=s,
        )
        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
        for l in range(self.max_pred_len):
            o, s, i, done = self.decoder_eval.step(
                time=l, inputs=i, state=s, training=False)
            pred_id[:, l] = o.sample_id
        if return_align:
            # stacked alignments: [n_step, batch, enc_len] -> [batch, n_step, enc_len]
            return np.transpose(s.alignment_history.stack().numpy(), (1, 0, 2))
        else:
            s.alignment_history.mark_used()     # gives warning otherwise
            return pred_id

    def train_logits(self, x, y, seq_len):
        s = self.encode(x)
        dec_in = y[:, :-1]      # ignore <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)
        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
        logits = o.rnn_output
        return logits

    def step(self, x, y, seq_len):
        with tf.GradientTape() as tape:
            logits = self.train_logits(x, y, seq_len)
            dec_out = y[:, 1:]      # ignore <GO>
            _loss = self.cross_entropy(dec_out, logits)
        grads = tape.gradient(_loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return _loss.numpy()


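# How the pieces above fit together:
#   encode()        embeds the source tokens, runs the encoder LSTM, hands its
#                   outputs to the Luong attention as memory, and wraps the final
#                   (h, c) state into an AttentionWrapperState for the decoder.
#   train_logits()  teacher-forces the decoder (TrainingSampler) on the shifted
#                   target sequence and returns per-step vocabulary logits.
#   inference()     decodes greedily (GreedyEmbeddingSampler) from the start token
#                   for at most max_pred_len steps, optionally returning the
#                   attention alignment history for visualization.
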
# get and process data
data = utils.DateData(2000)
print("Chinese time order: yy/mm/dd ", data.date_cn[:3], "\nEnglish time order: dd/M/yyyy ", data.date_en[:3])
print("vocabularies: ", data.vocab)
print("x index sample: \n{}\n{}".format(data.idx2str(data.x[0]), data.x[0]),
      "\ny index sample: \n{}\n{}".format(data.idx2str(data.y[0]), data.y[0]))

model = Seq2Seq(
    data.num_word, data.num_word, emb_dim=12, units=14, attention_layer_size=16,
    max_pred_len=11, start_token=data.start_token, end_token=data.end_token)

# training
for t in range(1000):
    bx, by, decoder_len = data.sample(64)
    loss = model.step(bx, by, decoder_len)
    if t % 30 == 0:
        target = data.idx2str(by[0, 1:-1])      # strip <GO> and <EOS>
        pred = model.inference(bx[0:1])
        res = data.idx2str(pred[0])
        src = data.idx2str(bx[0])
        print(
            "t: ", t,
            "| loss: %.5f" % loss,
            "| input: ", src,
            "| target: ", target,
            "| inference: ", res,
        )

pkl_data = {"i2v": data.i2v, "x": data.x[:6], "y": data.y[:6], "align": model.inference(data.x[:6], return_align=True)}

with open("./visual_helper/attention_align.pkl", "wb") as f:
    pickle.dump(pkl_data, f)
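
# A possible way to inspect the saved alignments (a minimal sketch, not part of
# the original script; it assumes matplotlib is installed):
#
#   import matplotlib.pyplot as plt
#
#   with open("./visual_helper/attention_align.pkl", "rb") as f:
#       d = pickle.load(f)
#   align = d["align"]                      # [batch, n_step, enc_len]
#   plt.imshow(align[0], cmap="Blues")      # attention map for the first sample
#   plt.xlabel("encoder input position")
#   plt.ylabel("decoder step")
#   plt.show()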