saved_func_predict_seq2seq.py
import tensorflow as tf
from saved_func_load_data_nmt import truncate_pad
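
# For reference, the imported `truncate_pad` helper is assumed to follow the
# usual d2l definition: truncate a token list to `num_steps` tokens, or pad it
# with the padding token up to that length. A minimal sketch of that assumed
# behavior (kept as a comment so it does not shadow the import):
#
# def truncate_pad(line, num_steps, padding_token):
#     """Truncate or pad a token sequence to exactly `num_steps` tokens."""
#     if len(line) > num_steps:
#         return line[:num_steps]  # Truncate
#     return line + [padding_token] * (num_steps - len(line))  # Pad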

def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                    save_attention_weights=False):
    """Predict the target sequence for a source sentence, token by token."""
    # Tokenize the source sentence and append the end-of-sequence token
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [
        src_vocab['<eos>']]
    enc_valid_len = tf.constant([len(src_tokens)])
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    # Add the batch axis
    enc_X = tf.expand_dims(src_tokens, axis=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len, training=False)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Add the batch axis; decoding starts from the beginning-of-sequence token
    dec_X = tf.expand_dims(tf.constant([tgt_vocab['<bos>']]), axis=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state, training=False)
        # Use the token with the highest prediction likelihood as the input
        # of the decoder at the next time step (greedy decoding)
        dec_X = tf.argmax(Y, axis=2)
        pred = tf.squeeze(dec_X, axis=0)
        # Save attention weights
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Once the end-of-sequence token is predicted, the generation of the
        # output sequence is complete
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred.numpy())
    return ' '.join(tgt_vocab.to_tokens(tf.reshape(
        output_seq, shape=-1).numpy().tolist())), attention_weight_seq
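
# --- Usage sketch (illustrative, not from the original file) ---
# The names below are hypothetical: `net` would be a trained encoder-decoder
# model exposing `.encoder`, `.decoder`, and `.decoder.init_state`, and
# `src_vocab`/`tgt_vocab` would be vocabulary objects from the data-loading
# utilities. With those in scope, a call might look like:
#
# translation, attention_weight_seq = predict_seq2seq(
#     net, 'i like apples .', src_vocab, tgt_vocab, num_steps=10,
#     save_attention_weights=False)
# print(translation)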