-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
96 lines (76 loc) · 2.85 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
import json
import random
import warnings
from argparse import Namespace
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
def intersperse(lst, item):
    """Return a new list with ``item`` placed before, between and after each element of ``lst``.

    ``item`` is the blank symbol to insert; e.g. ``intersperse([a, b], 0)`` gives
    ``[0, a, 0, b, 0]``. ``lst`` must support ``len()`` and iteration.
    """
    # Start with 2*n + 1 copies of the blank, then overwrite every odd slot.
    interleaved = [item] * (2 * len(lst) + 1)
    for position, element in enumerate(lst):
        interleaved[2 * position + 1] = element
    return interleaved
def parse_filelist(filelist_path, split_char="|"):
    """Parse a delimited file list into rows of fields.

    Each line of the UTF-8 file at ``filelist_path`` is stripped of surrounding
    whitespace (including the newline) and then split on ``split_char``; a blank
    line therefore yields ``['']``.

    Returns:
        list[list[str]]: one list of fields per line, in file order.
    """
    with open(filelist_path, encoding='utf-8') as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        return [entry.split(split_char) for entry in stripped_lines]
def latest_checkpoint_path(dir_path, regex="grad_*.pt"):
    """Return the path of the highest-numbered checkpoint file in ``dir_path``.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (despite the name, not a regular expression) that
            candidate file names must match.

    Returns:
        The matching path (as produced by ``glob``) whose file name contains the
        largest integer. Names without any digits rank lowest.

    Raises:
        FileNotFoundError: If no file in ``dir_path`` matches ``regex``.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise FileNotFoundError(
            f"No checkpoint matching {regex!r} found in {dir_path!r}")

    def _checkpoint_number(path):
        # Only the file name is inspected so digits in the directory path cannot
        # leak into the key; a digit-free name (e.g. "grad_best.pt") gets -1
        # instead of crashing on int("").
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list.sort(key=_checkpoint_number)
    return f_list[-1]
def load_checkpoint(logdir, model, num=None):
    """Load a ``grad_<num>.pt`` state dict from ``logdir`` into ``model``.

    Args:
        logdir: Directory containing ``grad_<num>.pt`` checkpoint files.
        model: Object whose ``load_state_dict`` receives the loaded dict
            (a ``torch.nn.Module``).
        num: Checkpoint number to load. If ``None``, the highest-numbered
            ``grad_*.pt`` file is chosen via ``latest_checkpoint_path``.

    Returns:
        The same ``model`` object, updated in place.

    Warns:
        UserWarning: If the checkpoint's keys do not exactly match the model's
            (loading is non-strict, so mismatched keys are otherwise dropped).
    """
    if num is None:
        model_path = latest_checkpoint_path(logdir, regex="grad_*.pt")
    else:
        model_path = os.path.join(logdir, f"grad_{num}.pt")
    print(f'Loading checkpoint {model_path}...')
    # The callable receives (storage, location) and returns the storage as
    # deserialized on CPU, so GPU-saved checkpoints load on CPU-only machines.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    # strict=False tolerates key mismatches (presumably to allow partially
    # compatible checkpoints), but surface them instead of dropping them silently.
    incompatible = model.load_state_dict(model_dict, strict=False)
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {model_path} does not match model: "
            f"missing keys {missing}, unexpected keys {unexpected}",
            stacklevel=2,
        )
    return model
class AttrDict(Namespace):
    """``argparse.Namespace`` built from a dict, exposing keys as attributes.

    Every ``dict`` value assigned as an attribute, whether at construction or
    later, is converted to an ``AttrDict``, so ``cfg.section.key`` access works at
    every nesting level. Dicts nested inside lists/tuples are left as-is.
    """

    def __init__(self, dictionary: dict):
        # Namespace.__init__ only sets keyword arguments as attributes, which this
        # loop does directly; dict values are wrapped by __setattr__ below.
        for key, value in dictionary.items():
            setattr(self, key, value)

    def __setattr__(self, key, value):
        # Wrap plain dicts so nested sections also support attribute access.
        value = AttrDict(value) if isinstance(value, dict) else value
        super().__setattr__(key, value)

    def to_dict(self, recursive: bool = False) -> dict:
        """Return the attributes as a dict.

        Args:
            recursive: If False (default), return ``vars(self)`` itself, i.e. the
                live attribute dict, whose nested sections are still ``AttrDict``
                instances. If True, return a new plain-``dict`` copy in which
                nested ``AttrDict`` values (also inside lists/tuples) are
                converted back to dicts, e.g. for ``json.dump``.
        """
        if not recursive:
            return vars(self)
        return {key: AttrDict._to_plain(value) for key, value in vars(self).items()}

    @staticmethod
    def _to_plain(value):
        """Recursively convert ``AttrDict`` instances inside ``value`` to plain dicts."""
        if isinstance(value, AttrDict):
            # Called via the class so a config key named "to_dict" cannot shadow it.
            return AttrDict.to_dict(value, recursive=True)
        if isinstance(value, list):
            return [AttrDict._to_plain(item) for item in value]
        if isinstance(value, tuple):
            return tuple(AttrDict._to_plain(item) for item in value)
        return value
def load_config(config_path: str) -> AttrDict:
    """Load a JSON config file into an :class:`AttrDict`.

    The file is decoded as UTF-8 rather than with the platform's locale default,
    so non-ASCII values load identically on every OS. ``utf-8-sig`` also accepts
    files that start with a UTF-8 byte-order mark, which ``json`` would reject.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The parsed JSON object wrapped in an ``AttrDict``.
    """
    with open(config_path, encoding='utf-8-sig') as f:
        return AttrDict(json.load(f))
def save_figure_to_numpy(fig):
    """Return the rendered RGB pixels of ``fig`` as a ``(height, width, 3)`` uint8 array.

    The figure must already have been drawn (``fig.canvas.draw()``); this function
    does not draw it.

    ``np.fromstring`` in binary mode is deprecated in NumPy, and
    ``FigureCanvasAgg.tostring_rgb`` was deprecated in Matplotlib 3.8 and later
    removed. The RGBA buffer is therefore read instead and the alpha channel is
    dropped. The old path is kept only for canvases without ``buffer_rgba``.
    """
    canvas = fig.canvas
    if hasattr(canvas, "buffer_rgba"):
        # The shape is taken from the buffer itself rather than computed from
        # get_width_height().
        rgba = np.asarray(canvas.buffer_rgba())
        # np.array copies, so the result is writable and does not alias the
        # canvas's internal buffer.
        return np.array(rgba[..., :3], dtype=np.uint8)
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()
def plot_tensor(tensor):
    """Render a 2-D array as a heatmap and return the image as an RGB numpy array.

    Args:
        tensor: 2-D array-like passed to ``imshow``. A ``torch.Tensor`` is first
            detached and moved to CPU, so tensors that require grad or live on a
            GPU can be passed directly.

    Returns:
        The pixel array produced by ``save_figure_to_numpy`` for the drawn figure.
    """
    if torch.is_tensor(tensor):
        tensor = tensor.detach().cpu().numpy()
    plt.style.use('default')  # resets Matplotlib's global rcParams to defaults
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Close this specific figure even on error so pyplot does not keep
        # accumulating open figures across repeated calls.
        plt.close(fig)
    return data
def save_plot(tensor, savepath):
    """Render a 2-D array as a heatmap and write the figure to ``savepath``.

    Args:
        tensor: 2-D array-like passed to ``imshow``. A ``torch.Tensor`` is first
            detached and moved to CPU, so tensors that require grad or live on a
            GPU can be passed directly.
        savepath: Destination passed to ``Figure.savefig``. The format is
            inferred by Matplotlib, typically from the file extension.
    """
    if torch.is_tensor(tensor):
        tensor = tensor.detach().cpu().numpy()
    plt.style.use('default')  # resets Matplotlib's global rcParams to defaults
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # savefig renders the figure itself, so no explicit canvas.draw() is needed.
        fig.savefig(savepath)
    finally:
        # Close this specific figure even on error so pyplot does not keep
        # accumulating open figures across repeated calls.
        plt.close(fig)