Skip to content

Commit

Permalink
Merge pull request #115 from dzenanz/fixValidationCrash
Browse files Browse the repository at this point in the history
Fix multiprocessing crash -> allow validation to be invoked standalone
  • Loading branch information
WuJunde authored Jun 17, 2024
2 parents ca28f30 + 24aff64 commit bad4500
Showing 1 changed file with 64 additions and 61 deletions.
125 changes: 64 additions & 61 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,66 +34,69 @@
import function


args = cfg.parse_args()
if args.dataset == 'refuge' or args.dataset == 'refuge2':
args.data_path = '../dataset'

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)

'''load pretrained model'''
assert args.weights != 0
print(f'=> resuming from {args.weights}')
assert os.path.exists(args.weights)
checkpoint_file = os.path.join(args.weights)
assert os.path.exists(checkpoint_file)
loc = 'cuda:{}'.format(args.gpu_device)
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']

state_dict = checkpoint['state_dict']
if args.distributed != 'none':
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
# name = k[7:] # remove `module.`
name = 'module.' + k
new_state_dict[name] = v
# load params
else:
new_state_dict = state_dict

net.load_state_dict(new_state_dict)

# args.path_helper = checkpoint['path_helper']
# logger = create_logger(args.path_helper['log_path'])
# print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

# args.path_helper = set_log_dir('logs', args.exp_name)
# logger = create_logger(args.path_helper['log_path'])
# logger.info(args)

args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)

'''segmentation data'''
nice_train_loader, nice_test_loader = get_dataloader(args)

'''begain valuation'''
best_acc = 0.0
best_tol = 1e4

if args.mod == 'sam_adpt':
net.eval()

if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
def main():
args = cfg.parse_args()
if args.dataset == 'refuge' or args.dataset == 'refuge2':
args.data_path = '../dataset'

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)

'''load pretrained model'''
assert args.weights != 0
print(f'=> resuming from {args.weights}')
assert os.path.exists(args.weights)
checkpoint_file = os.path.join(args.weights)
assert os.path.exists(checkpoint_file)
loc = 'cuda:{}'.format(args.gpu_device)
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']

state_dict = checkpoint['state_dict']
if args.distributed != 'none':
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
# name = k[7:] # remove `module.`
name = 'module.' + k
new_state_dict[name] = v
# load params
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {start_epoch}.')
new_state_dict = state_dict


net.load_state_dict(new_state_dict)

# args.path_helper = checkpoint['path_helper']
# logger = create_logger(args.path_helper['log_path'])
# print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

# args.path_helper = set_log_dir('logs', args.exp_name)
# logger = create_logger(args.path_helper['log_path'])
# logger.info(args)

args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)

'''segmentation data'''
nice_train_loader, nice_test_loader = get_dataloader(args)

'''begain valuation'''
best_acc = 0.0
best_tol = 1e4

if args.mod == 'sam_adpt':
net.eval()

if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {start_epoch}.')


if __name__ == '__main__':
main()

0 comments on commit bad4500

Please sign in to comment.