Skip to content
This repository has been archived by the owner on Jan 18, 2022. It is now read-only.

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
hubutui committed Oct 11, 2019
0 parents commit 2b3714c
Show file tree
Hide file tree
Showing 11 changed files with 893 additions and 0 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# PSPNet for MICCAI Automatic Prostate Gleason Grading Challenge 2019
This is the code for MICCAI Automatic Prostate Gleason Grading Challenge 2019. Check [here](https://gleason2019.grand-challenge.org/Home/) and [here](https://bmiai.ubc.ca/research/miccai-automatic-prostate-gleason-grading-challenge-2019), we took the 1st place of task 1: pixel-level Gleason grade prediction and task 2: core-level Gleason score prediction ([leaderboard](https://gleason2019.grand-challenge.org/Results/)).

Task 1 is regarded as a segmentation task, and we use PSPNet for this. And for task 2, we do not train a different network, but just produce the prediction from the prediction of task 1 according to the Gleason grading system.

The train script is based on reference script from torchvision 0.4.0 with minor modification. So, you need to install the latest PyTorch and torchvision >= 0.4.0. Check [requirements.txt](requirements.txt) for all packages you need.

This repo use [GluonCV-Torch](https://github.com/zhanghang1989/gluoncv-torch), thanks for Hang Zhang's outstanding work!

## Preprocessing
Each image is annotated in detail by several expert pathologists. So how to use this annotations is important. We use STAPLE to create final annotations used in training. Check the [preprocessing.py](preprocessing.py) script for detail.

![preprocessing](./images/preprocessing.png)

## Training

![PSPNet](./images/PSPNet.png)

To run the training, simply run `python train.py`, check `python train_gleason.py --help` for available args.

## Inference
To run the inference, simply run `python inference.py`, check `python inference.py --help` for available args.
44 changes: 44 additions & 0 deletions dataset/Gleason.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import torch
from torch.utils.data import Dataset
import os.path as osp
import os
from PIL import Image


class Gleason(Dataset):
def __init__(self, imgdir, maskdir=None, train=True, val=False,
test=False, transforms=None, transform=None, target_transform=None):
super(Gleason, self).__init__()
self.imgdir = imgdir
self.maskdir = maskdir
self.imglist = sorted(os.listdir(imgdir))
if not test:
self.masklist = [item.replace('.jpg', '_classimg_nonconvex.png') for item in self.imglist]
else:
self.masklist = []
self.train = train
self.val = val
self.test = test
self.transforms = transforms
self.transform = transform
self.target_transform = target_transform

def __len__(self):
return len(self.imglist)

def __getitem__(self, idx):
image = Image.open(osp.join(self.imgdir, self.imglist[idx]))
if not self.test:
mask = Image.open(osp.join(self.maskdir, self.masklist[idx]))
if self.transforms and not self.test:
image, mask = self.transforms(image, mask)
if self.transform:
image = self.transform(image, mask)
if self.target_transform and not self.test:
mask = self.target_transform(mask)

if self.test:
return image
else:
return image, mask
Empty file added dataset/__init__.py
Empty file.
Binary file added images/PSPNet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/preprocessing.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 92 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
#
import argparse
import csv
import os
import os.path as osp

import gluoncvth as gcv
import numpy as np
import torch
from PIL import Image
from torch import nn
from torchvision import transforms


def getargs():
parser = argparse.ArgumentParser()
parser.add_argument('--inputdir', type=str, required=True,
help="input data dir")
parser.add_argument('--model', type=str, required=True,
help="model file")
parser.add_argument('--outputdir', type=str, default=None,
help='output dir')
parser.add_argument('--aux', action='store_true',
help='use aux layer')

return parser.parse_args()


def get_gleason_grade(segmentation):
segmentation = segmentation.flatten()
u, count = np.unique(segmentation, return_counts=True)

ind = np.argsort(count)
if u.size == 1:
primary = u[ind][-1]
result = primary*2
elif u.size == 2:
primary = u[ind][-1]
secondary = u[ind][-2]
result = primary + secondary
else:
primary = u[ind][-1]
result = primary + u.max()

return result


if __name__ == '__main__':
args = getargs()
os.makedirs(osp.join(args.outputdir, 'task1'), exist_ok=True)
os.makedirs(osp.join(args.outputdir, 'task2'), exist_ok=True)
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 6
model = gcv.models.get_psp_resnet101_ade(pretrained=True)
model.auxlayer.conv5[-1] = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
model.head.conv5[-1] = nn.Conv2d(512, num_classes, kernel_size=1, stride=1)

model_data = torch.load(args.model, map_location='cpu')
model.load_state_dict(model_data['model'])
model = model.to(device)
tf = transforms.Compose([
transforms.Resize(800),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
with torch.no_grad():
model.eval()
grade = {}
for imgfile in os.listdir(args.inputdir):
data = Image.open(osp.join(args.inputdir, imgfile))
w = data.width
h = data.height
data = tf(data)
data = data.to(device).unsqueeze(0)
y, y_aux = model(data)
# y_aux takes lower weight, or even 0 if you like
y = y + 0.5 * y_aux
y = y.argmax(dim=1).cpu().squeeze().numpy().astype(np.uint8)

y[y == 2] = 6
result = Image.fromarray(y)
result = transforms.Resize((h, w))(result)
result.save(osp.join(args.outputdir, 'task1', imgfile))
grade[imgfile[:-4]] = get_gleason_grade(y)

with open(osp.join(args.outputdir, 'task2', 'task2.csv'), 'w') as f:
writer = csv.writer(f)
for key in grade.keys():
writer.writerow([key, grade[key]])
print('Done')
75 changes: 75 additions & 0 deletions preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
#
# This script use STAPLE to generate ground truth from 6 experts
# and convert the label to 0-5
# original label 0, 1, 3, 4, 5, 6 -> 0, 1, 3, 4, 5, 2
# so remember to convert it back after inference
#
import argparse
import os
import os.path as osp
import multiprocessing
from multiprocessing import Pool

import SimpleITK as sitk


def staple(item, inputdirs, outputdir, undecidedlabel):
print("processing {}...".format(item))

imgs = []
for p in inputdirs:
if osp.isfile(osp.join(p, item)):
imgs.append(sitk.ReadImage(osp.join(p, item)))
result = sitk.MultiLabelSTAPLE(imgs, 255)
p1_data = sitk.GetArrayFromImage(imgs[0])
result_data = sitk.GetArrayFromImage(result)
if undecidedlabel:
result_data[result_data == 255] = undecidedlabel
result_data[result_data == 6] = 2
else:
result_data[result_data == 255] = p1_data[result_data == 255]
result_data[result_data == 6] = 2
result = sitk.GetImageFromArray(result_data)
result.CopyInformation(imgs[0])
sitk.WriteImage(result, osp.join(outputdir, item))


def getargs():
parser = argparse.ArgumentParser()
parser.add_argument('--inputdirs', type=str, nargs='+', default='/home/hubutui/Downloads/dataset/MICCAI2019/Gleason-2019',
help='input dirs of all masks')
parser.add_argument('--outputdir', type=str, default='/home/hubutui/Downloads/dataset/MICCAI2019/Gleason-2019/preprocessed/finalmask-255',
help='output dir')
parser.add_argument('--undecidedlabel', type=int, default=None,
help="label value for undecided pixels, we simply use the one expert's label value"
"in the order of arg inputdirs, "
"you could also use 0-6 if needed,"
"or just use 255, and ignore this label at training")
parser.add_argument('--pool-size', type=int, default=None,
help='processes to run in parallel, default value is CPU count')

return parser.parse_args()


if __name__ == '__main__':
args = getargs()
if args.undecidedlabel not in [None, 0, 1, 2, 3, 4, 5, 6, 255]:
raise ValueError("unexpected label value for undecided pixels".format(args.undecidedlabel))
os.makedirs(args.outputdir, exist_ok=True)
maskfiles = []
for i in args.inputdirs:
maskfiles = maskfiles + os.listdir(i)
maskfiles = set(maskfiles)

if args.pool_size:
processes = args.pool_size
else:
processes = multiprocessing.cpu_count()
with Pool(processes=processes) as pool:
results = [pool.apply_async(staple,
args=(maskfile, args.inputdirs,
args.outputdir, args.undecidedlabel))
for maskfile in maskfiles]
_ = [_.get() for _ in results]
print("Done")
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
torch>=1.0
torchvision>=0.4.0
numpy
SimpleITK
gluoncv-torch
Loading

0 comments on commit 2b3714c

Please sign in to comment.