diff --git a/utils.py b/utils.py index ca13d4f1..1d7f3a9c 100644 --- a/utils.py +++ b/utils.py @@ -951,7 +951,7 @@ def hook(layer): return hook -def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = None): +def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = None, boxes = None): b,c,h,w = pred_masks.size() dev = pred_masks.get_device() @@ -997,6 +997,16 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N gt_masks[i,0,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.5 gt_masks[i,1,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.1 gt_masks[i,2,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.4 + if boxes is not None: + for i in range(b): + # the next line causes: ValueError: Tensor uint8 expected, got torch.float32 + # imgs[i, :] = torchvision.utils.draw_bounding_boxes(imgs[i, :], boxes[i]) + # until TorchVision 0.19 is released (paired with Pytorch 2.4), apply this workaround: + img255 = (imgs[i] * 255).byte() + img255 = torchvision.utils.draw_bounding_boxes(img255, boxes[i].reshape(-1, 4), colors="red") + img01 = img255 / 255 + # torchvision.utils.save_image(img01, save_path + "_boxes.png") + imgs[i, :] = img01 tup = (imgs[:row_num,:,:,:],pred_masks[:row_num,:,:,:], gt_masks[:row_num,:,:,:]) # compose = torch.cat((imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0) compose = torch.cat(tup,0)