Skip to content

Commit

Permalink
fixed and added support for cylindric
Browse files Browse the repository at this point in the history
  • Loading branch information
stonescenter committed Feb 17, 2020
1 parent 94dbe35 commit 3bb42c5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3,778 deletions.
4 changes: 2 additions & 2 deletions config-cnn.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
"filename": "./dataset/phi025-025_eta025-025_filtered.csv",
"train_split": 0.70,
"normalise": true,
"cilyndrical": true
"cilyndrical": false
},
"training": {
"epochs": 20,
"batch_size": 32,
"save_model": true,
"load_model": false,
"load_model": true,
"use_gpu": true
},
"model": {
Expand Down
6 changes: 3 additions & 3 deletions core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import datetime as dt

import tensorflow as tf
from tensorflow import set_random_seed

import keras.backend as K
from keras.backend.tensorflow_backend import set_session
Expand Down Expand Up @@ -40,8 +39,9 @@ def __init__(self, configs):
config=tf.ConfigProto(log_device_placement=True)
sess = tf.Session(config=config)
set_session(sess)

set_random_seed(42)

#set_random_seed(42)
tf.compat.v1.set_random_seed(0)

def load_model(self):
if self.exist_model(self.save_fnameh5):
Expand Down
22 changes: 18 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,30 @@ def main():

#Save data to plot
X, y = data.prepare_training_data(FeatureType.Positions, normalise=False,
cilyndrical=True)
cilyndrical=cilyndrical)
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=1-split, random_state=42)

y_pred = pd.DataFrame(y_predicted_orig)
y_true = pd.DataFrame(y_test_orig)

y_true.to_csv(os.path.join(output_path, 'y_true.csv'), header=False, index=False)
y_pred.to_csv(os.path.join(output_path, 'y_pred.csv'), header=False, index=False)
X_test.to_csv(os.path.join(output_path, 'x_test.csv'), header=False, index=False)
if cilyndrical:

y_true.to_csv(os.path.join(output_path, 'y_true_%s_cylin.csv' % configs['model']['name']),
header=False, index=False)
y_pred.to_csv(os.path.join(output_path, 'y_pred_%s_cylin.csv' % configs['model']['name']),
header=False, index=False)
X_test.to_csv(os.path.join(output_path, 'x_test_%s_cylin.csv' % configs['model']['name']),
header=False, index=False)
else:

y_true.to_csv(os.path.join(output_path, 'y_true_%s_xyz.csv' % configs['model']['name']),
header=False, index=False)
y_pred.to_csv(os.path.join(output_path, 'y_pred_%s_xyz.csv' % configs['model']['name']),
header=False, index=False)
X_test.to_csv(os.path.join(output_path, 'x_test_%s_xyz.csv' % configs['model']['name']),
header=False, index=False)

print('[Output] Results saved at %', output_path)

if __name__=='__main__':
Expand Down
3,787 changes: 18 additions & 3,769 deletions notebooks/plot_prediction.ipynb

Large diffs are not rendered by default.

0 comments on commit 3bb42c5

Please sign in to comment.