-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvalidation_automation_paralel.py
67 lines (56 loc) · 1.93 KB
/
validation_automation_paralel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pickle
import utils.training_functions as tf
import utils.data_exploration as de
from utils.relay_list import signals
from tqdm import tqdm
import torch
import pandas as pd
import numpy as np
import concurrent.futures
from tqdm import tqdm
import os
# --- Model hyper-parameters (LSTM) ------------------------------------------
hidden_dim = 20  # hidden dimension handed to tf.FaultDetector
n_signals = 3  # number of input signals fed to tf.FaultDetector
N = 64

# Batch sizes; each *_batch_size value corresponds to m in figure 1.
train_batch_size, dev_batch_size, test_batch_size = 64, 16, 16

# Binary classification: the tag set has a single entry.
tagset_size = 1

# --- Device, model and validation data ---------------------------------------
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = tf.FaultDetector(n_signals, hidden_dim, tagset_size).to(device)

# NOTE: pickle.load can execute arbitrary code; only open trusted dataset files.
with open("./datasets/dataset_validation.pkl", "rb") as f:
    dataset_params = pickle.load(f)
def relay_validation(
    i: int, dataset_params=dataset_params, model=model, signals=signals
) -> None:
    """Validate the stored current-based model of relay ``R{i+1}`` and save results.

    Loads ``./models/automation/R{i+1}_currents.pth`` into ``model``, builds the
    validation dataset from the relay's three signals, runs
    ``de.dataframe_creation`` and writes under ``parquet_data/automation``:

    * ``R{n}_df.parquet``    -- the dataframe returned by ``dataframe_creation``;
    * ``R{n}_CM_df.parquet`` -- the confusion matrix as a labelled table;
    * ``R{n}_10_CM.npy``     -- the same confusion matrix as a raw array.

    Args:
        i: Zero-based relay index; the relay number ``n`` is ``i + 1``.
        dataset_params: Validation dataset parameters (defaults to the object
            unpickled from ``./datasets/dataset_validation.pkl``).
        model: Model whose weights are overwritten with the relay's checkpoint.
        signals: Flat sequence of signal names, three consecutive entries per relay.
    """
    relay_number = i + 1
    print(f"ejecutando relé R{relay_number}")
    out_dir = "parquet_data/automation"

    # map_location places the checkpoint tensors on this script's ``device`` so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only host.
    state_dict = torch.load(
        f"./models/automation/R{relay_number}_currents.pth", map_location=device
    )
    model.load_state_dict(state_dict)

    # Relay n owns entries [3*(n-1), 3*(n-1) + 3) of ``signals``.
    first = (relay_number - 1) * 3
    signal_names = signals[first : first + 3]
    dataset_validation = tf.ExistingDataset(dataset_params, signal_names)
    dataset_df, conf_matrix = de.dataframe_creation(dataset_validation, model)

    # Create the output directory up front instead of failing on the first write.
    os.makedirs(out_dir, exist_ok=True)
    dataset_df.to_parquet(f"{out_dir}/R{relay_number}_df.parquet")

    # Save the confusion matrix returned by dataframe_creation. If it is a torch
    # tensor (possibly on the GPU), move it to host memory as a NumPy array so
    # both pandas and np.save receive plain data.
    # NOTE(review): assumes the matrix has exactly five columns matching the
    # labels below -- confirm against de.dataframe_creation.
    if isinstance(conf_matrix, torch.Tensor):
        conf_matrix = conf_matrix.detach().cpu().numpy()
    # NOTE(review): "TF" looks like a typo for "TN" (true negatives); kept as-is
    # because downstream readers of the parquet file may rely on this name.
    conf_matrix_df = pd.DataFrame(
        conf_matrix, columns=["TP", "FP", "TF", "FN", "TP + FN"]
    )
    conf_matrix_df.to_parquet(f"{out_dir}/R{relay_number}_CM_df.parquet")
    np.save(f"{out_dir}/R{relay_number}_10_CM.npy", conf_matrix)
if __name__ == "__main__":
    import multiprocessing

    # Relays R1..R18 are validated, one task per relay index 0..17.
    n_relays = 18

    # Leave one core for the parent process. os.cpu_count() may return None, and
    # ProcessPoolExecutor rejects max_workers < 1 (a single-core host gives 0).
    cores = max(1, (os.cpu_count() or 1) - 1)

    # The module-level ``model.to(device)`` has already initialised CUDA in this
    # process, and CUDA cannot be re-initialised in a forked child, so use the
    # "spawn" start method when running on the GPU. With "spawn" every worker
    # re-imports this script (re-running the module-level model construction and
    # dataset load) and creates its own CUDA context.
    mp_context = multiprocessing.get_context("spawn") if device == "cuda" else None

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=cores, mp_context=mp_context
    ) as executor:
        # Consuming the results through tqdm drives the progress bar; a worker
        # exception is re-raised here when its result is reached.
        list(
            tqdm(
                executor.map(relay_validation, range(n_relays)),
                total=n_relays,
                desc="Processing Files",
            )
        )