-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
41 lines (37 loc) · 1.61 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from model import ResNet18
import torch
from dataset import MyDataset
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
# Evaluate a pretrained ResNet18 regressor on the test split: report MAE and
# RMSE over every test sample and save per-sample predictions to testInfo.csv.
batch_size = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet18().to(device)
# Load the pretrained model parameters (state dict stored in 'ResNet.pth').
pretrained_dict = torch.load('ResNet.pth', map_location=device)
model.load_state_dict(pretrained_dict)
# Switch to inference behaviour. torch.no_grad() alone does not do this:
# BatchNorm/Dropout layers (if ResNet18 has them) would otherwise keep their
# training behaviour, making predictions depend on batch composition.
model.eval()
print("Pretrained model loaded.")  # Report that the pretrained weights are loaded.

dataset_test = MyDataset('G:\\keep\\test', 'G:\\keep\\label.csv')
num_test = len(dataset_test)
# shuffle=False keeps testInfo.csv in dataset order. drop_last=False makes
# sure the final partial batch is evaluated too, so every sample is scored.
# Pinned memory only helps host->GPU copies; on CPU it just emits a warning.
test_loader = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False,
                         pin_memory=(device.type == "cuda"), drop_last=False)

sq_err_sum, abs_err_sum = 0., 0.  # running error sums over every scored value
num_err = 0                       # number of scored values (denominator)
step = 0
paths, labels, predicts = [], [], []
with torch.no_grad():
    loader = tqdm(test_loader)
    for img, label, path in loader:
        paths += list(path)
        labels += torch.flatten(label).tolist()
        img, label = img.to(device), label.to(device).to(torch.float32)
        predict = model(img)
        predicts += torch.flatten(predict).tolist()
        # Flatten both sides before subtracting: if the model returns (B, 1)
        # while the label is (B,), `predict - label` would broadcast to (B, B)
        # and compare every prediction against every label.
        err = torch.flatten(predict) - torch.flatten(label)
        sq_err_sum += torch.pow(err, 2).sum().item()
        abs_err_sum += torch.abs(err).sum().item()
        num_err += err.numel()
        step += 1
        # Count samples actually seen; step * batch_size overshoots on the
        # final partial batch.
        loader.set_description('step:{} {}/{}'.format(step, len(paths), num_test))

if num_err == 0:
    raise RuntimeError('No test samples were evaluated; check the dataset paths.')
# Dataset-level metrics: average over all samples, not over batches.
# Averaging per-batch RMSEs does not give the RMSE of the whole set, and
# batch averaging would over-weight a smaller final batch.
mae = abs_err_sum / num_err
rmse = (sq_err_sum / num_err) ** 0.5
print('Test\tMAE:{}\t RMSE:{}'.format(mae, rmse))
pd.DataFrame({'file': paths, 'label': labels, 'predict': predicts}).to_csv('testInfo.csv', index=False)
#----------------------------------------------------------------------------------------------------------------