Skip to content

Commit

Permalink
Allow visualization of bounding boxes
Browse files Browse the repository at this point in the history
  • Loading branch information
dzenanz committed Jun 20, 2024
1 parent bad4500 commit 2c6c569
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def hook(layer):

return hook

def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = None):
def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = None, boxes = None):

b,c,h,w = pred_masks.size()
dev = pred_masks.get_device()
Expand Down Expand Up @@ -997,6 +997,16 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N
gt_masks[i,0,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.5
gt_masks[i,1,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.1
gt_masks[i,2,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.4
if boxes is not None:
for i in range(b):
# the next line causes: ValueError: Tensor uint8 expected, got torch.float32
# imgs[i, :] = torchvision.utils.draw_bounding_boxes(imgs[i, :], boxes[i])
# until TorchVision 0.19 is released (paired with Pytorch 2.4), apply this workaround:
img255 = (imgs[i] * 255).byte()
img255 = torchvision.utils.draw_bounding_boxes(img255, boxes[i].reshape(-1, 4), colors="red")
img01 = img255 / 255
# torchvision.utils.save_image(img01, save_path + "_boxes.png")
imgs[i, :] = img01
tup = (imgs[:row_num,:,:,:],pred_masks[:row_num,:,:,:], gt_masks[:row_num,:,:,:])
# compose = torch.cat((imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0)
compose = torch.cat(tup,0)
Expand Down

0 comments on commit 2c6c569

Please sign in to comment.