forked from Thaonguyen3095/affordance-language
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet.py
57 lines (50 loc) · 1.95 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import csv
import json
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
'''
Generate object-embedding pairs for the given ImageNet images.

Each image is passed through a pretrained ResNet truncated after its global
average-pooling layer; the pooled activations serve as the image's embedding.
'''
#-----LOAD MODEL-----#
# ResNet-101 pretrained on ImageNet. models.resnet50 or models.resnet152 (also
# with pretrained=True) can be substituted here to try a different depth.
model = models.resnet101(pretrained=True)

# Drop the last child module (the `fc` classifier of a torchvision ResNet) so
# the truncated network ends at the average-pooling layer.
model_avgpool = nn.Sequential(*tuple(model.children())[:-1])
model_avgpool.eval()

# ImageNet-style preprocessing: resize the shorter side to 256, center-crop to
# 224x224, convert to a tensor, then normalize with the ImageNet channel
# mean/std used when the weights were trained.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
#-----USE MODEL-----#
# Use the GPU when available. The model only needs to be moved once, not once
# per image.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_avgpool.to(device)

# data/image_label.json is generated by imagenet.py. Each non-blank line is
# parsed as a JSON object mapping image file name -> label. The lines are merged
# into one dict so the image directory is scanned exactly once, instead of once
# per line. The labels are read before the output CSV is opened (and truncated).
image_label = {}
with open('data/image_label.json') as data:
    for line in data:
        if line.strip():
            image_label.update(json.loads(line))

# Substitute with your own path to the ILSVRC2012 validation image set.
image_dir = '/home/thao/Downloads/ILSVRC2012_img_val'

with open('data/object-embedding.csv', 'w', newline='') as out:
    csvwriter = csv.writer(out, delimiter=',')
    # os.listdir order is arbitrary; sort it so the output rows are reproducible.
    for f in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, f)
        try:
            # convert('RGB') normalizes grayscale/CMYK images to 3 channels.
            # Otherwise the 3-channel Normalize step (and the model's first
            # conv) would fail on them.
            with Image.open(path) as input_image:
                input_tensor = preprocess(input_image.convert('RGB'))
            # Add a batch dimension of size 1, as the model expects.
            input_batch = input_tensor.unsqueeze(0).to(device)
            with torch.no_grad():
                output = model_avgpool(input_batch)
        except (OSError, RuntimeError, ValueError) as err:
            # Report and skip this image. Falling through here used to write the
            # previous image's embedding (or hit an undefined name) under this
            # image's label.
            print(path, err)
            continue
        # Flatten everything after the batch dimension into one feature vector.
        output = torch.flatten(output, 1)
        csvwriter.writerow((image_label[f], output[0].tolist(), f))