This repository has been archived by the owner on Jan 18, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 2b3714c
Showing
11 changed files
with
893 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# PSPNet for MICCAI Automatic Prostate Gleason Grading Challenge 2019 | ||
This is the code for MICCAI Automatic Prostate Gleason Grading Challenge 2019. Check [here](https://gleason2019.grand-challenge.org/Home/) and [here](https://bmiai.ubc.ca/research/miccai-automatic-prostate-gleason-grading-challenge-2019), we took the 1st place of task 1: pixel-level Gleason grade prediction and task 2: core-level Gleason score prediction ([leaderboard](https://gleason2019.grand-challenge.org/Results/)). | ||
|
||
Task 1 is regarded as a segmentation task, and we use PSPNet for this. And for task 2, we do not train a different network, but just produce the prediction from the prediction of task 1 according to the Gleason grading system. | ||
|
||
The train script is based on reference script from torchvision 0.4.0 with minor modification. So, you need to install the latest PyTorch and torchvision >= 0.4.0. Check [requirements.txt](requirements.txt) for all packages you need. | ||
|
||
This repo use [GluonCV-Torch](https://github.com/zhanghang1989/gluoncv-torch), thanks for Hang Zhang's outstanding work! | ||
|
||
## Preprocessing | ||
Each image is annotated in detail by several expert pathologists. So how to use this annotations is important. We use STAPLE to create final annotations used in training. Check the [preprocessing.py](preprocessing.py) script for detail. | ||
|
||
 | ||
|
||
## Training | ||
|
||
 | ||
|
||
To run the training, simply run `python train.py`, check `python train_gleason.py --help` for available args. | ||
|
||
## Inference | ||
To run the inference, simply run `python inference.py`, check `python inference.py --help` for available args. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/usr/bin/env python3 | ||
import torch | ||
from torch.utils.data import Dataset | ||
import os.path as osp | ||
import os | ||
from PIL import Image | ||
|
||
|
||
class Gleason(Dataset): | ||
def __init__(self, imgdir, maskdir=None, train=True, val=False, | ||
test=False, transforms=None, transform=None, target_transform=None): | ||
super(Gleason, self).__init__() | ||
self.imgdir = imgdir | ||
self.maskdir = maskdir | ||
self.imglist = sorted(os.listdir(imgdir)) | ||
if not test: | ||
self.masklist = [item.replace('.jpg', '_classimg_nonconvex.png') for item in self.imglist] | ||
else: | ||
self.masklist = [] | ||
self.train = train | ||
self.val = val | ||
self.test = test | ||
self.transforms = transforms | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
|
||
def __len__(self): | ||
return len(self.imglist) | ||
|
||
def __getitem__(self, idx): | ||
image = Image.open(osp.join(self.imgdir, self.imglist[idx])) | ||
if not self.test: | ||
mask = Image.open(osp.join(self.maskdir, self.masklist[idx])) | ||
if self.transforms and not self.test: | ||
image, mask = self.transforms(image, mask) | ||
if self.transform: | ||
image = self.transform(image, mask) | ||
if self.target_transform and not self.test: | ||
mask = self.target_transform(mask) | ||
|
||
if self.test: | ||
return image | ||
else: | ||
return image, mask |
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#!/usr/bin/env python3 | ||
# | ||
import argparse | ||
import csv | ||
import os | ||
import os.path as osp | ||
|
||
import gluoncvth as gcv | ||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
from torch import nn | ||
from torchvision import transforms | ||
|
||
|
||
def getargs(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--inputdir', type=str, required=True, | ||
help="input data dir") | ||
parser.add_argument('--model', type=str, required=True, | ||
help="model file") | ||
parser.add_argument('--outputdir', type=str, default=None, | ||
help='output dir') | ||
parser.add_argument('--aux', action='store_true', | ||
help='use aux layer') | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def get_gleason_grade(segmentation): | ||
segmentation = segmentation.flatten() | ||
u, count = np.unique(segmentation, return_counts=True) | ||
|
||
ind = np.argsort(count) | ||
if u.size == 1: | ||
primary = u[ind][-1] | ||
result = primary*2 | ||
elif u.size == 2: | ||
primary = u[ind][-1] | ||
secondary = u[ind][-2] | ||
result = primary + secondary | ||
else: | ||
primary = u[ind][-1] | ||
result = primary + u.max() | ||
|
||
return result | ||
|
||
|
||
if __name__ == '__main__': | ||
args = getargs() | ||
os.makedirs(osp.join(args.outputdir, 'task1'), exist_ok=True) | ||
os.makedirs(osp.join(args.outputdir, 'task2'), exist_ok=True) | ||
device= torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
num_classes = 6 | ||
model = gcv.models.get_psp_resnet101_ade(pretrained=True) | ||
model.auxlayer.conv5[-1] = nn.Conv2d(256, num_classes, kernel_size=1, stride=1) | ||
model.head.conv5[-1] = nn.Conv2d(512, num_classes, kernel_size=1, stride=1) | ||
|
||
model_data = torch.load(args.model, map_location='cpu') | ||
model.load_state_dict(model_data['model']) | ||
model = model.to(device) | ||
tf = transforms.Compose([ | ||
transforms.Resize(800), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], | ||
std=[0.229, 0.224, 0.225]) | ||
]) | ||
with torch.no_grad(): | ||
model.eval() | ||
grade = {} | ||
for imgfile in os.listdir(args.inputdir): | ||
data = Image.open(osp.join(args.inputdir, imgfile)) | ||
w = data.width | ||
h = data.height | ||
data = tf(data) | ||
data = data.to(device).unsqueeze(0) | ||
y, y_aux = model(data) | ||
# y_aux takes lower weight, or even 0 if you like | ||
y = y + 0.5 * y_aux | ||
y = y.argmax(dim=1).cpu().squeeze().numpy().astype(np.uint8) | ||
|
||
y[y == 2] = 6 | ||
result = Image.fromarray(y) | ||
result = transforms.Resize((h, w))(result) | ||
result.save(osp.join(args.outputdir, 'task1', imgfile)) | ||
grade[imgfile[:-4]] = get_gleason_grade(y) | ||
|
||
with open(osp.join(args.outputdir, 'task2', 'task2.csv'), 'w') as f: | ||
writer = csv.writer(f) | ||
for key in grade.keys(): | ||
writer.writerow([key, grade[key]]) | ||
print('Done') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#!/usr/bin/env python3 | ||
# | ||
# This script use STAPLE to generate ground truth from 6 experts | ||
# and convert the label to 0-5 | ||
# original label 0, 1, 3, 4, 5, 6 -> 0, 1, 3, 4, 5, 2 | ||
# so remember to convert it back after inference | ||
# | ||
import argparse | ||
import os | ||
import os.path as osp | ||
import multiprocessing | ||
from multiprocessing import Pool | ||
|
||
import SimpleITK as sitk | ||
|
||
|
||
def staple(item, inputdirs, outputdir, undecidedlabel): | ||
print("processing {}...".format(item)) | ||
|
||
imgs = [] | ||
for p in inputdirs: | ||
if osp.isfile(osp.join(p, item)): | ||
imgs.append(sitk.ReadImage(osp.join(p, item))) | ||
result = sitk.MultiLabelSTAPLE(imgs, 255) | ||
p1_data = sitk.GetArrayFromImage(imgs[0]) | ||
result_data = sitk.GetArrayFromImage(result) | ||
if undecidedlabel: | ||
result_data[result_data == 255] = undecidedlabel | ||
result_data[result_data == 6] = 2 | ||
else: | ||
result_data[result_data == 255] = p1_data[result_data == 255] | ||
result_data[result_data == 6] = 2 | ||
result = sitk.GetImageFromArray(result_data) | ||
result.CopyInformation(imgs[0]) | ||
sitk.WriteImage(result, osp.join(outputdir, item)) | ||
|
||
|
||
def getargs(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--inputdirs', type=str, nargs='+', default='/home/hubutui/Downloads/dataset/MICCAI2019/Gleason-2019', | ||
help='input dirs of all masks') | ||
parser.add_argument('--outputdir', type=str, default='/home/hubutui/Downloads/dataset/MICCAI2019/Gleason-2019/preprocessed/finalmask-255', | ||
help='output dir') | ||
parser.add_argument('--undecidedlabel', type=int, default=None, | ||
help="label value for undecided pixels, we simply use the one expert's label value" | ||
"in the order of arg inputdirs, " | ||
"you could also use 0-6 if needed," | ||
"or just use 255, and ignore this label at training") | ||
parser.add_argument('--pool-size', type=int, default=None, | ||
help='processes to run in parallel, default value is CPU count') | ||
|
||
return parser.parse_args() | ||
|
||
|
||
if __name__ == '__main__': | ||
args = getargs() | ||
if args.undecidedlabel not in [None, 0, 1, 2, 3, 4, 5, 6, 255]: | ||
raise ValueError("unexpected label value for undecided pixels".format(args.undecidedlabel)) | ||
os.makedirs(args.outputdir, exist_ok=True) | ||
maskfiles = [] | ||
for i in args.inputdirs: | ||
maskfiles = maskfiles + os.listdir(i) | ||
maskfiles = set(maskfiles) | ||
|
||
if args.pool_size: | ||
processes = args.pool_size | ||
else: | ||
processes = multiprocessing.cpu_count() | ||
with Pool(processes=processes) as pool: | ||
results = [pool.apply_async(staple, | ||
args=(maskfile, args.inputdirs, | ||
args.outputdir, args.undecidedlabel)) | ||
for maskfile in maskfiles] | ||
_ = [_.get() for _ in results] | ||
print("Done") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
torch>=1.0 | ||
torchvision>=0.4.0 | ||
numpy | ||
SimpleITK | ||
gluoncv-torch |
Oops, something went wrong.