-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_dqn.py
59 lines (45 loc) · 1.36 KB
/
main_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from flask import Flask, jsonify, render_template
from flask_cors import CORS
from flask import request
import utils
import numpy as np
from Tetris import Tetris
import torch
# Passing static_folder to the constructor is the public equivalent of
# assigning the private ``app._static_folder`` attribute after construction
# (Flask resolves it relative to this module's root path either way).
app = Flask(__name__, static_folder="./templates/static")
CORS(app, resources=r'/*')

# Board dimensions of the occupancy grid built for each move request.
GRID_HEIGHT = 20
GRID_WIDTH = 10

torch.manual_seed(42)

# Resolved relative to the process's current working directory.
MODEL_PATH = 'trained_models/tetris_model_4600_epochs'

# The checkpoint is a whole pickled module (it is called directly and has
# .eval()), so it has to be unpickled in full:
#   * weights_only=False is required on torch >= 2.6, whose default refuses
#     arbitrary pickled objects; torch versions without that keyword raise
#     TypeError and fall back to the plain call.
#   * map_location="cpu" keeps the parameters on the CPU, matching the CPU
#     FloatTensors built for every request, even if it was saved from a GPU.
# SECURITY: unpickling can execute code from the file -- only load trusted
# checkpoints from this path.
try:
    model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
except TypeError:  # older torch: no ``weights_only`` keyword
    model = torch.load(MODEL_PATH, map_location="cpu")
model.eval()

# Single environment instance reused by the move endpoint to simulate placements.
env = Tetris()
env.reset()
@app.route('/')
def index():
    """Render and return the ``index.html`` template for the site root."""
    page_template = 'index.html'
    return render_template(page_template)
@app.route('/tetris/test')
def hello():
    """Return a constant greeting; presumably a quick connectivity check."""
    greeting = 'Hello World!'
    return greeting
@app.route('/tetris/next', methods=['POST'])
def get_next_step():
    """Return the model's chosen action for the posted board and piece.

    The request body is expected to be a JSON array whose first element is
    the board (a list of rows; any cell that is not ``0`` counts as filled)
    and whose second element is an object with a ``"flag"`` key naming the
    current piece. That key indexes ``utils.ACTIONS``. Each candidate action
    listed there is applied via ``env.get_next_state``, featurised with
    ``env.get_features``, and scored by ``model``. The entry of
    ``utils.ACTIONS[shape]`` with the highest score (column 0 of the model
    output) is returned as JSON.

    Responds with HTTP 400 instead of an unhandled server error in these cases:
    the body is not JSON, it lacks that structure, the piece is unknown, or the
    board does not fit in a GRID_HEIGHT x GRID_WIDTH grid.
    """
    params = request.get_json(silent=True)
    try:
        board = params[0]
        shape = params[1]["flag"]
        candidate_actions = utils.ACTIONS[shape]
    except (TypeError, KeyError, IndexError):
        return jsonify(error='expected JSON body [board, {"flag": shape}]'), 400

    # Binary occupancy grid: 1.0 wherever the client reports a non-zero cell;
    # a smaller board is zero-padded, a larger one is rejected.
    st = np.zeros((GRID_HEIGHT, GRID_WIDTH))
    try:
        for i, row in enumerate(board):
            for j, cell in enumerate(row):
                if cell != 0:
                    st[i][j] = 1
    except (TypeError, IndexError):
        return jsonify(error='board must be a list of rows that fits the grid'), 400

    # NOTE(review): `env` is module-level state shared by every request. If the
    # server handles requests concurrently (Flask's dev server is threaded by
    # default), calls could interleave on env.board -- consider a lock or a
    # per-request environment.
    env.board = st

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        next_states = torch.stack([
            torch.FloatTensor(
                env.get_features(env.get_next_state(shape, action_index)))
            for action_index in range(len(candidate_actions))
        ])
        predictions = model(next_states)[:, 0]
    chosen_index = torch.argmax(predictions).item()
    return jsonify(candidate_actions[chosen_index])
def main():
    """Start Flask's built-in server using its default run settings."""
    app.run()


if __name__ == '__main__':
    main()