Skip to content

Commit

Permalink
convert tensors from cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
diego-escobedo committed May 15, 2022
1 parent 2dd1f62 commit 0820eb1
Show file tree
Hide file tree
Showing 6 changed files with 4,843 additions and 6 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
Expand Down
11 changes: 6 additions & 5 deletions TrajectoryNet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""
import os
import matplotlib
# matplotlib.use("Agg")
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
import time
Expand Down Expand Up @@ -43,8 +45,7 @@
import dataset
from parse import parser

# matplotlib.use("Agg")
matplotlib.use('TkAgg')




Expand Down Expand Up @@ -283,7 +284,7 @@ def visualize(device, args, model, itr):
gt_samples = args.data.get_data()[idx]
gt_samples = torch.from_numpy(gt_samples).type(torch.float32).to(device)
d[i] = gt_samples
ax.hist2d(gt_samples[:, 0].numpy(), gt_samples[:, 1].numpy(), range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
ax.hist2d(gt_samples[:, 0].cpu().numpy(), gt_samples[:, 1].cpu().numpy(), range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
ax.invert_yaxis()
ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
Expand All @@ -295,7 +296,7 @@ def visualize(device, args, model, itr):
ax = plt.subplot(nrows, ncols, ncols+i+1, aspect="equal")
if i == 0:
gt_samples = d[i]
ax.hist2d(gt_samples[:, 0].numpy(), gt_samples[:, 1].numpy(), range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
ax.hist2d(gt_samples[:, 0].cpu().numpy(), gt_samples[:, 1].cpu().numpy(), range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
ax.invert_yaxis()
ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
Expand All @@ -305,7 +306,7 @@ def visualize(device, args, model, itr):
integration_times = torch.tensor([itp - args.time_scale, itp])
integration_times = integration_times.type(torch.float32).to(device)
advected_samples, = model(gt_samples, integration_times=integration_times) #add comma cuz unpacking tuple
ax.hist2d(advected_samples[:, 0].numpy(), advected_samples[:, 1].numpy(), range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
ax.hist2d(advected_samples[:, 0].cpu().numpy(), advected_samples[:, 1].cpu().numpy(), range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
ax.invert_yaxis()
ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
Expand Down
Binary file added results/diego_newloss/checkpt.pth
Binary file not shown.
Binary file added results/diego_newloss/emds_v2.npy
Binary file not shown.
Loading

0 comments on commit 0820eb1

Please sign in to comment.