-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathevaluate_imagenet.py
122 lines (106 loc) · 3.74 KB
/
evaluate_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import itertools
import os
import sys

import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageNet
from torchvision.models import resnet50

# mister_ed
from recoloradv.mister_ed import loss_functions as lf
from recoloradv.mister_ed import adversarial_training as advtrain
from recoloradv.mister_ed import adversarial_perturbations as ap
from recoloradv.mister_ed import adversarial_attacks as aa
from recoloradv.mister_ed import spatial_transformers as st
from recoloradv.mister_ed.utils import pytorch_utils as utils
# ReColorAdv
from recoloradv import perturbations as pt
from recoloradv import color_transformers as ct
from recoloradv import color_spaces as cs
def positive_int(text):
    """Parse a command-line value as a strictly positive integer.

    Used as an argparse ``type=`` callable so that values such as ``0`` or
    ``-1`` are rejected at parse time with a usage message, instead of
    failing later inside the DataLoader or at the final accuracy computation.

    Args:
        text: the raw string given on the command line.

    Returns:
        The parsed integer (>= 1).

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is < 1.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f'expected an integer, got {text!r}') from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f'expected a positive integer, got {value}')
    return value


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: list of argument strings (excluding the program name);
            ``None`` means ``sys.argv[1:]``, as with ``argparse``.

    Returns:
        An ``argparse.Namespace`` with ``imagenet_path``, ``batch_size``,
        ``num_batches`` (``None`` = entire validation set) and ``seed``
        (``None`` = unseeded, i.e. a different shuffle order every run).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a ResNet-50 trained on Imagenet '
                    'against ReColorAdv'
    )
    parser.add_argument('--imagenet_path', type=str, required=True,
                        help='path to ImageNet dataset')
    parser.add_argument('--batch_size', type=positive_int, default=100,
                        help='number of examples/minibatch')
    parser.add_argument('--num_batches', type=positive_int, required=False,
                        help='number of batches (default entire dataset)')
    parser.add_argument('--seed', type=int, default=None,
                        help="seed for torch's global RNG; makes the shuffled "
                             'subset reproducible (default: unseeded)')
    return parser.parse_args(argv)


def build_val_loader(imagenet_path, batch_size):
    """Return a shuffled DataLoader over the ImageNet validation split.

    Images are resized (shorter side 256 px) and then center-cropped to
    224x224. This is the preprocessing torchvision documents for its
    pretrained classifiers. Normalization is deliberately left out: the
    attack and the final evaluation apply the normalizer explicitly, so the
    loader yields unnormalized ``ToTensor`` output.

    Args:
        imagenet_path: root directory of the ImageNet dataset.
        batch_size: number of examples per minibatch.
    """
    dataset = ImageNet(
        imagenet_path,
        split='val',
        transform=transforms.Compose([
            # Without the resize, CenterCrop(224) cuts a patch straight out
            # of the full-resolution image rather than the resized view the
            # pretrained weights were evaluated on.
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]),
    )
    # Shuffle so that a --num_batches prefix mixes classes: torchvision's
    # ImageNet is an ImageFolder, whose samples are listed class by class.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def build_attack(model, normalizer):
    """Construct the PGD attack with the ReColorAdv threat model.

    The objective is a ``RegularizedLoss`` combining mister_ed's
    ``CWLossF6`` (``kappa=inf``, weight 1.0) with an L2
    ``PerturbationNormLoss`` (weight 0.05). The threat model is ReColorAdv
    with a ``FullSpatial`` color transform in CIELUV space, an ``'inf'``-norm
    bound of 0.06, x/y/z resolutions 16/32/32 and the smoothness loss
    enabled.

    Args:
        model: the classifier under attack.
        normalizer: callable that maps unnormalized images to model input.
    """
    cw_loss = lf.CWLossF6(model, normalizer, kappa=float('inf'))
    perturbation_loss = lf.PerturbationNormLoss(lp=2)
    adv_loss = lf.RegularizedLoss(
        {'cw': cw_loss, 'pert': perturbation_loss},
        {'cw': 1.0, 'pert': 0.05},
        # Passed through unchanged; see mister_ed's RegularizedLoss for the
        # sign convention it implies.
        negate=True,
    )
    return aa.PGD(
        model,
        normalizer,
        ap.ThreatModel(pt.ReColorAdv, {
            'xform_class': ct.FullSpatial,
            'cspace': cs.CIELUVColorSpace(),
            'lp_style': 'inf',
            'lp_bound': 0.06,
            'xform_params': {
                'resolution_x': 16,
                'resolution_y': 32,
                'resolution_z': 32,
            },
            'use_smooth_loss': True,
        }),
        adv_loss,
    )


def overall_accuracy(batches_correct):
    """Return the fraction of correct examples across all batches.

    Batches are concatenated before averaging, so every example counts once
    and a short final batch is not over-weighted.

    Args:
        batches_correct: list of boolean tensors, one per batch, marking
            which examples were classified correctly.

    Returns:
        Accuracy in [0, 1] as a Python float, or NaN if the list is empty
        (``torch.cat`` of an empty list would raise instead).
    """
    if not batches_correct:
        return float('nan')
    return torch.cat(batches_correct).float().mean().item()


def main(argv=None):
    """Attack ImageNet validation batches and report adversarial accuracy.

    For each batch, run the ReColorAdv PGD attack, classify the adversarial
    images, and print the batch accuracy. At the end, print the accuracy over
    all evaluated examples.

    Args:
        argv: command-line arguments (without the program name); ``None``
            means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    if args.seed is not None:
        # The DataLoader's RandomSampler draws its shuffle seed from torch's
        # global RNG, so this fixes which examples are evaluated.
        torch.manual_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = resnet50(pretrained=True, progress=True)
    model.eval()
    model.to(device)
    normalizer = utils.DifferentiableNormalize(mean=[0.485, 0.456, 0.406],
                                               std=[0.229, 0.224, 0.225])

    val_loader = build_val_loader(args.imagenet_path, args.batch_size)
    pgd_attack = build_attack(model, normalizer)

    batches_correct = []
    # islice(..., None) yields every batch. With a limit, it stops without
    # loading an extra batch just to discover the limit was reached.
    batches = itertools.islice(val_loader, args.num_batches)
    for batch_index, (inputs, labels) in enumerate(batches):
        inputs = inputs.to(device)
        labels = labels.to(device)

        adv_inputs = pgd_attack.attack(
            inputs,
            labels,
            optimizer=optim.Adam,
            optimizer_kwargs={'lr': 0.001},
            signed=False,
            verbose=False,
            num_iterations=(100, 300),
        ).adversarial_tensors()

        with torch.no_grad():
            adv_logits = model(normalizer(adv_inputs))
        # Keep the per-example results on the CPU; they are only aggregated.
        batch_correct = (adv_logits.argmax(1) == labels).cpu()
        batch_accuracy = batch_correct.float().mean().item()
        print(f'BATCH {batch_index:05d}',
              f'accuracy = {batch_accuracy * 100:.1f}',
              sep='\t')
        batches_correct.append(batch_correct)

    accuracy = overall_accuracy(batches_correct)
    print('OVERALL ',
          f'accuracy = {accuracy * 100:.1f}',
          sep='\t')


if __name__ == '__main__':
    main()