-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsplit.py
32 lines (27 loc) · 896 Bytes
/
split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch.utils.data import random_split
import json
# Build reproducible train/valid/test index splits for two datasets and write
# them to split.json. The keys "mp" / "oqmd" presumably refer to Materials
# Project and OQMD (inferred from the names only -- confirm). Each *_split list
# gives subset sizes in (train, valid, test) order; their sum is the number of
# dataset indices being partitioned.

# Fixed seed. Each dataset gets a freshly seeded generator, so its split does
# not depend on how many random draws another dataset's split consumed.
SEED = 42

# Plain Python ints: random_split documents ``lengths`` as a sequence of ints
# (or fractions). Passing a tensor only worked through implicit coercion.
mp_split = [78600, 4367, 4367]
oqmd_split = [199686, 11094, 11094]

_SUBSET_NAMES = ("train", "valid", "test")


def _random_split_range(lengths, seed=SEED):
    """Randomly partition ``range(sum(lengths))`` into subsets of the given sizes.

    Args:
        lengths: Subset sizes, in output order.
        seed: Seed for a dedicated ``torch.Generator``; the same seed and
            lengths always yield the same partition.

    Returns:
        A list of ``torch.utils.data.Subset`` objects, one per entry of
        ``lengths``, whose ``.indices`` together cover every index once.
    """
    gen = torch.Generator().manual_seed(seed)
    return random_split(list(range(sum(lengths))), lengths, generator=gen)


def _indices_by_name(subsets):
    """Map each name in ``_SUBSET_NAMES`` to the matching subset's indices.

    ``int()`` keeps the values JSON-serializable regardless of the index type
    stored on the Subset.
    """
    return {
        name: [int(i) for i in subset.indices]
        for name, subset in zip(_SUBSET_NAMES, subsets)
    }


train_mp, valid_mp, test_mp = _random_split_range(mp_split)
train_oqmd, valid_oqmd, test_oqmd = _random_split_range(oqmd_split)

# Output schema: {dataset: {"train" | "valid" | "test": [index, ...]}}
split = {
    "mp": _indices_by_name((train_mp, valid_mp, test_mp)),
    "oqmd": _indices_by_name((train_oqmd, valid_oqmd, test_oqmd)),
}

with open("split.json", "w", encoding="utf-8") as fp:
    json.dump(split, fp)