
Commit 66093ec

Author: morvanzhou
Commit message: update
1 parent f11e151 · commit 66093ec

8 files changed: 216 additions & 242 deletions

File tree

Attention.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

attention.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import tensorflow as tf
from tensorflow import keras
import numpy as np
import utils
import tensorflow_addons as tfa
import pickle


class Seq2Seq(keras.Model):
    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, attention_layer_size, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units

        # encoder
        self.enc_embeddings = keras.layers.Embedding(
            input_dim=enc_v_dim, output_dim=emb_dim,    # [enc_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.encoder = keras.layers.LSTM(units=units, return_sequences=True, return_state=True)

        # decoder
        self.dec_embeddings = keras.layers.Embedding(
            input_dim=dec_v_dim, output_dim=emb_dim,    # [dec_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.attention = tfa.seq2seq.LuongAttention(units, memory=None, memory_sequence_length=None)
        self.decoder_cell = tfa.seq2seq.AttentionWrapper(
            cell=keras.layers.LSTMCell(units=units),
            attention_mechanism=self.attention,
            attention_layer_size=attention_layer_size,
            alignment_history=True,                     # for attention visualization
        )
        decoder_dense = keras.layers.Dense(dec_v_dim)

        # train decoder
        self.decoder_train = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.TrainingSampler(),          # sampler for train
            output_layer=decoder_dense
        )
        # predict decoder
        self.decoder_eval = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),   # sampler for predict
            output_layer=decoder_dense
        )

        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(0.05, clipnorm=5.0)
        self.train_sampler = tfa.seq2seq.sampler.TrainingSampler()
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        o = self.enc_embeddings(x)
        init_s = [tf.zeros((x.shape[0], self.units)), tf.zeros((x.shape[0], self.units))]
        o, h, c = self.encoder(o, initial_state=init_s)

        # encoder output for attention to focus
        self.attention.setup_memory(o)
        # wrap state by attention wrapper
        s = self.decoder_cell.get_initial_state(batch_size=x.shape[0], dtype=tf.float32).clone(cell_state=[h, c])
        return s

    def inference(self, x, return_align=False):
        s = self.encode(x)
        done, i, s = self.decoder_eval.initialize(
            self.dec_embeddings.variables[0],
            start_tokens=tf.fill([x.shape[0], ], self.start_token),
            end_token=self.end_token,
            initial_state=s,
        )
        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
        for l in range(self.max_pred_len):
            o, s, i, done = self.decoder_eval.step(
                time=l, inputs=i, state=s, training=False)
            pred_id[:, l] = o.sample_id
        if return_align:
            return np.transpose(s.alignment_history.stack().numpy(), (1, 0, 2))
        else:
            s.alignment_history.mark_used()             # gives warning otherwise
            return pred_id

    def train_logits(self, x, y, seq_len):
        s = self.encode(x)
        dec_in = y[:, :-1]                              # ignore <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)
        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
        logits = o.rnn_output
        return logits

    def step(self, x, y, seq_len):
        with tf.GradientTape() as tape:
            logits = self.train_logits(x, y, seq_len)
            dec_out = y[:, 1:]                          # ignore <GO>
            _loss = self.cross_entropy(dec_out, logits)
        grads = tape.gradient(_loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return _loss.numpy()


# get and process data
data = utils.DateData(2000)
print("Chinese time order: yy/mm/dd ", data.date_cn[:3], "\nEnglish time order: dd/M/yyyy ", data.date_en[:3])
print("vocabularies: ", data.vocab)
print("x index sample: \n{}\n{}".format(data.idx2str(data.x[0]), data.x[0]),
      "\ny index sample: \n{}\n{}".format(data.idx2str(data.y[0]), data.y[0]))

model = Seq2Seq(
    data.num_word, data.num_word, emb_dim=12, units=14, attention_layer_size=16,
    max_pred_len=11, start_token=data.start_token, end_token=data.end_token)

# training
for t in range(1000):
    bx, by, decoder_len = data.sample(64)
    loss = model.step(bx, by, decoder_len)
    if t % 30 == 0:
        target = data.idx2str(by[0, 1:-1])
        pred = model.inference(bx[0:1])
        res = data.idx2str(pred[0])
        src = data.idx2str(bx[0])
        print(
            "t: ", t,
            "| loss: %.5f" % loss,
            "| input: ", src,
            "| target: ", target,
            "| inference: ", res,
        )

pkl_data = {"i2v": data.i2v, "x": data.x[:6], "y": data.y[:6], "align": model.inference(data.x[:6], return_align=True)}

with open("./visual_helper/attention_align.pkl", "wb") as f:
    pickle.dump(pkl_data, f)
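
The final lines pickle the per-sample alignment matrices (shape [batch, decoder_steps, encoder_steps], produced because alignment_history=True in the AttentionWrapper) together with the index-to-token map, so a separate script can plot which source characters each prediction step attended to. That plotting script is not part of this diff; the snippet below is only a minimal sketch of how the saved file could be read back, assuming matplotlib is available and that data.i2v maps integer indices to tokens as used above.

import pickle
import matplotlib.pyplot as plt

with open("./visual_helper/attention_align.pkl", "rb") as f:
    d = pickle.load(f)

i2v, x, align = d["i2v"], d["x"], d["align"]    # align: [batch, decoder_steps, encoder_steps]
src_tokens = [i2v[i] for i in x[0]]             # source characters of the first sample

# darker cells = more attention paid to that source character at that decoding step
plt.imshow(align[0], cmap="Greys", vmin=0., vmax=1.)
plt.xticks(range(len(src_tokens)), src_tokens)
plt.xlabel("source (encoder steps)")
plt.ylabel("prediction (decoder steps)")
plt.tight_layout()
plt.show()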
