-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
92 lines (78 loc) · 2.87 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
import utils
from torchvision.transforms import ToTensor
class BananasDataset(torch.utils.data.Dataset):
    """Single-object banana detection dataset producing grid-style labels.

    The data was made by Li Mu and his collaborators, see
    https://www.bilibili.com/video/BV1Lh411Y7LX?p=3
    It can be downloaded from
    http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip

    ``dataset_path`` must contain an ``images`` directory and a ``label.csv``
    file. The ``img_name`` and ``label`` columns are read by name; every
    column from the third one onwards is taken as the corner-format box
    ``[xmin, ymin, xmax, ymax]`` in pixels.

    Each sample is ``(image, label)`` where ``label`` has shape
    ``(S, S, C + 5 * B)``. In the grid cell that contains the box centre the
    last axis is laid out as::

        [0, C)          one-hot class indicator
        C               object confidence of the first box (set to 1)
        [C + 1, C + 5)  first box as (x_cell, y_cell, width_cell, height_cell)
        [C + 5, ...)    remaining box slots, left as zeros

    All other cells are zeros.

    Args:
        dataset_path: Directory holding ``images/`` and ``label.csv``.
        S: Number of grid cells along each image side.
        B: Number of box slots per cell (only the first one is filled).
        C: Number of classes.
        transform: Optional callable applied to the loaded image.

    Raises:
        ValueError: If ``S``, ``B`` or ``C`` is smaller than 1.
    """

    def __init__(self, dataset_path, S=7, B=2, C=1, transform=None):
        super().__init__()
        if S < 1 or B < 1 or C < 1:
            raise ValueError(
                f'S, B and C must all be >= 1, got S={S}, B={B}, C={C}'
            )
        self.S = S
        self.B = B
        self.C = C
        self.transform = transform
        self.img_dir = os.path.join(dataset_path, 'images')
        self.labels = pd.read_csv(os.path.join(dataset_path, 'label.csv'))

    def __len__(self):
        """Return the number of rows in ``label.csv``."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        ``image`` is the opened image, passed through ``transform`` when one
        was given. ``label`` is the ``(S, S, C + 5 * B)`` tensor described in
        the class docstring.
        """
        # Every image contains exactly one banana, so one row == one object.
        class_label = int(self.labels.iloc[index]['label'])
        # [xmin, ymin, xmax, ymax]
        corner_box = self.labels.iloc[index, 2:].values.tolist()
        img_path = os.path.join(
            self.img_dir, self.labels.iloc[index]['img_name']
        )
        image = Image.open(img_path)

        # Normalise by the real image size instead of a hard-coded 256 so
        # that images of any resolution yield coordinates in [0, 1]. The size
        # is read before the transform because the box is expressed in the
        # original pixel grid.
        img_width, img_height = image.size
        scale = torch.tensor(
            [img_width, img_height, img_width, img_height],
            dtype=torch.float32,
        )
        # convert to [x, y, width, height]
        center_box = utils.box_corner_to_center(corner_box)
        center_box = torch.as_tensor(center_box, dtype=torch.float32) / scale

        if self.transform is not None:
            image = self.transform(image)

        label = self._encode_label(center_box.tolist(), class_label)
        return image, label

    def _encode_label(self, box, class_label):
        """Encode one normalised centre box into an (S, S, C + 5 * B) tensor.

        Args:
            box: ``(x, y, width, height)`` with all values relative to the
                image size, ``(x, y)`` being the box centre.
            class_label: Zero-based class index, must be in ``[0, C)``.

        Returns:
            A float tensor that is zero everywhere except in the grid cell
            that contains the box centre.

        Raises:
            ValueError: If ``class_label`` is outside ``[0, C)``; writing it
                anyway would overwrite the confidence or box slots.
        """
        if not 0 <= class_label < self.C:
            raise ValueError(
                f'class label {class_label} is outside [0, {self.C})'
            )

        x, y, width, height = box
        last_cell = self.S - 1
        # A centre lying exactly on the right/bottom border gives S * x == S;
        # clamp so it lands in the last cell instead of indexing out of
        # range (and so a negative value cannot wrap around).
        col = min(max(int(self.S * x), 0), last_cell)
        row = min(max(int(self.S * y), 0), last_cell)

        # Centre offset inside the cell, and box size measured in cells.
        x_cell, y_cell = self.S * x - col, self.S * y - row
        width_cell, height_cell = width * self.S, height * self.S

        label = torch.zeros((self.S, self.S, self.C + 5 * self.B))
        label[row, col, class_label] = 1
        # Object confidence sits right after the C class slots.
        label[row, col, self.C] = 1
        # Only the first box slot receives coordinates; the others stay zero.
        label[row, col, self.C + 1:self.C + 5] = torch.tensor(
            [x_cell, y_cell, width_cell, height_cell]
        )
        return label
def test_bananas_dataset_1():
    """Manual smoke test: load the dataset without a transform and inspect item 0.

    Prints the dataset length, the label tensor shape and the image object,
    then opens the image through its ``show()`` method.
    """
    dataset = BananasDataset('data/banana-detection/bananas_val/')
    print('train data num: ', len(dataset))

    # Take the first (image, label) pair of the dataset.
    first_sample = dataset[0]
    image, label = first_sample[0], first_sample[1]

    print(label.shape)
    print(image)
    image.show()
def test_bananas_dataset_2():
    """Manual smoke test: batch the dataset with a DataLoader and print shapes.

    Images are converted with ``ToTensor`` so that they can be collated into
    a batch; only the first batch is fetched.
    """
    dataset = BananasDataset(
        'data/banana-detection/bananas_val/',
        transform=ToTensor(),
    )
    loader_options = {
        'batch_size': 16,
        'num_workers': 4,
        'pin_memory': True,
        'shuffle': True,
        'drop_last': True,
    }
    loader = torch.utils.data.DataLoader(dataset=dataset, **loader_options)

    # One batch is enough to check the shapes, so stop after the first.
    for images, labels in loader:
        print('image shape: ', images.shape)
        print('label shape: ', labels.shape)
        break
# Script entry point: run the transform-free smoke test when this file is
# executed directly rather than imported.
if __name__ == '__main__':
    test_bananas_dataset_1()