4. Testing Model SEQ2SEQ Model.py
########## PART 4 - TESTING THE SEQ2SEQ MODEL ##########
# This part relies on objects defined in the earlier parts of the script:
# clean_text, questionswords2int, answersints2word, batch_size,
# test_predictions, inputs and keep_prob.
import numpy as np
import tensorflow as tf

# Loading the weights and running the session
checkpoint = "./chatbot_weights.ckpt"
session = tf.InteractiveSession()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(session, checkpoint)
# Converting the questions from strings to lists of encoding integers
def convert_string2int(question, word2int):
    question = clean_text(question)
    return [word2int.get(word, word2int['<OUT>']) for word in question.split()]
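
# Example (hypothetical vocabulary, assuming clean_text lower-cases the text
# and strips punctuation as in the earlier parts):
#   word2int = {'hello': 7, 'there': 12, '<OUT>': 0, ...}
#   convert_string2int("Hello there, stranger!", word2int)  ->  [7, 12, 0]
# "stranger" is not in the vocabulary, so it is mapped to the <OUT> token id.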
# Setting up the chat
while True:
    question = input("You: ")
    if question == 'Goodbye':
        break
    # Encoding the question and padding it with <PAD> tokens up to the
    # sequence length of 25 (the question is assumed to fit in 25 tokens).
    question = convert_string2int(question, questionswords2int)
    question = question + [questionswords2int['<PAD>']] * (25 - len(question))
    # Placing the question in the first row of a dummy batch so the input
    # tensor keeps the shape (batch_size, 25) expected by the graph.
    fake_batch = np.zeros((batch_size, 25))
    fake_batch[0] = question
    # keep_prob is fed with the same 0.5 value used during training;
    # feeding 1.0 would disable dropout at inference time.
    predicted_answer = session.run(test_predictions, {inputs: fake_batch, keep_prob: 0.5})[0]
    # Converting the predicted token ids back into a readable answer.
    answer = ''
    for i in np.argmax(predicted_answer, 1):
        if answersints2word[i] == 'i':
            token = ' I'
        elif answersints2word[i] == '<EOS>':
            token = '.'
        elif answersints2word[i] == '<OUT>':
            token = 'out'
        else:
            token = ' ' + answersints2word[i]
        answer += token
        if token == '.':
            break
    print('ChatBot: ' + answer)
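
# --- Optional sketch (not part of the original script) --------------------
# The id-to-text decoding used in the chat loop can be pulled out into a
# small helper, which makes it easier to reuse or test in isolation. It
# assumes the same answersints2word-style dictionary and the same special
# tokens ('<EOS>', '<OUT>') as above.
def decode_answer(predicted_answer, ints2word):
    answer = ''
    for i in np.argmax(predicted_answer, 1):
        word = ints2word[i]
        if word == 'i':
            token = ' I'
        elif word == '<EOS>':
            token = '.'
        elif word == '<OUT>':
            token = 'out'
        else:
            token = ' ' + word
        answer += token
        if token == '.':  # stop at the end-of-sentence marker
            break
    return answer

# Once the chat loop ends, the session can be closed explicitly:
# session.close()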