diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..62e43e1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+#common
+**/*.DS_Store
+**/*.ipynb_checkpoints/
+**/__pycache__
+out/
+
+#ZINC dataset
+data/molecules/*.pkl
+data/molecules/*.pickle
+data/molecules/*.zip
+data/molecules/zinc_full/*.pkl
+data/molecules/zinc_full/*.pickle
+
+
+#OGB
+dataset/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..44754d3
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Vijay Prakash Dwivedi, Anh Tuan Luu, Thomas Laurent, Yoshua Bengio and Xavier Bresson
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e288bfb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,53 @@
+
+
+# Graph Neural Networks with Learnable Structural and Positional Representations
+
+
+
+Source code for the paper "**Graph Neural Networks with Learnable Structural and Positional Representations**" by Vijay Prakash Dwivedi, Anh Tuan Luu, Thomas Laurent, Yoshua Bengio and Xavier Bresson.
+
+We propose a novel GNN architecture in which the structural and positional representations are decoupled and learnt separately, so that each of these two essential properties is captured in its own representation. The architecture, named **MPGNNs-LSPE** (MPGNNs with **L**earnable **S**tructural and **P**ositional **E**ncodings), is generic enough to be applied to any GNN model of interest that fits into the popular 'message-passing framework', including Transformers.
+
+![MPGNNs-LSPE](./docs/gnn-lspe.png)
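+
+For intuition, the following is a minimal, illustrative sketch of the LSPE idea (not the repository's actual implementation): the structural representation `h` and the positional representation `p` are updated by separate aggregations, with `p` feeding into the update of `h`. The dense adjacency `A` and the weight matrices `W_h`, `W_p` are placeholders.
+
+```
+import torch
+
+def lspe_layer(A, h, p, W_h, W_p):
+    # structural update: aggregate the neighbours' concatenated [h || p] features
+    m_h = A @ torch.cat([h, p], dim=1) @ W_h
+    # positional update: a separate aggregation that only sees p
+    m_p = A @ p @ W_p
+    return torch.relu(h + m_h), torch.tanh(p + m_p)
+
+# toy usage: 5 nodes, 8-dim structural and 4-dim positional features
+N, d, k = 5, 8, 4
+A = (torch.rand(N, N) > 0.5).float()
+h, p = torch.randn(N, d), torch.randn(N, k)
+h, p = lspe_layer(A, h, p, torch.randn(d + k, d), torch.randn(k, k))
+```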
+
+
+
+
+## 1. Repo installation
+
+[Follow these instructions](./docs/01_repo_installation.md) to install the repo and setup the environment.
+
+
+
+
+## 2. Download datasets
+
+[Proceed as follows](./docs/02_download_datasets.md) to download the benchmark datasets.
+
+
+
+
+## 3. Reproducibility
+
+[Use this page](./docs/03_run_codes.md) to run the codes and reproduce the published results.
+
+
+
+
+## 4. Reference
+
+TODO
+[ArXiv paper](https://arxiv.org/pdf/2110.xxxxx.pdf)
+```
+@article{dwivedi2021graph,
+ title={Graph Neural Networks with Learnable Structural and Positional Representations},
+ author={Dwivedi, Vijay Prakash and Luu, Anh Tuan and Laurent, Thomas and Bengio, Yoshua and Bresson, Xavier},
+ journal={arXiv preprint arXiv:2110.xxxxx},
+ year={2021}
+}
+```
+
+
+
+
+
diff --git a/configs/GatedGCN_MOLPCBA_LSPE.json b/configs/GatedGCN_MOLPCBA_LSPE.json
new file mode 100644
index 0000000..0e79994
--- /dev/null
+++ b/configs/GatedGCN_MOLPCBA_LSPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "OGBG-MOLPCBA",
+
+ "out_dir": "out/GatedGCN_MOLPCBA_LSPE_noLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-4,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 1,
+ "max_time": 96
+ },
+
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 118,
+ "out_dim": 118,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.1,
+ "batch_norm": true,
+ "pos_enc_dim": 20,
+ "pe_init": "rand_walk",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-3,
+ "lambda_loss": 100
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_MOLPCBA_LapPE.json b/configs/GatedGCN_MOLPCBA_LapPE.json
new file mode 100644
index 0000000..70a9a0b
--- /dev/null
+++ b/configs/GatedGCN_MOLPCBA_LapPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "OGBG-MOLPCBA",
+
+ "out_dir": "out/GatedGCN_MOLPCBA_LapPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-4,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 1,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 154,
+ "out_dim": 154,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.1,
+ "batch_norm": true,
+ "pos_enc_dim": 3,
+ "pe_init": "lap_pe",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-3,
+ "lambda_loss": 100
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_MOLPCBA_NoPE.json b/configs/GatedGCN_MOLPCBA_NoPE.json
new file mode 100644
index 0000000..141092e
--- /dev/null
+++ b/configs/GatedGCN_MOLPCBA_NoPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "OGBG-MOLPCBA",
+
+ "out_dir": "out/GatedGCN_MOLPCBA_NoPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-4,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 1,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 154,
+ "out_dim": 154,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.1,
+ "batch_norm": true,
+ "pos_enc_dim": 3,
+ "pe_init": "no_pe",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-3,
+ "lambda_loss": 100
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_MOLTOX21_LSPE.json b/configs/GatedGCN_MOLTOX21_LSPE.json
new file mode 100644
index 0000000..456bc36
--- /dev/null
+++ b/configs/GatedGCN_MOLTOX21_LSPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "OGBG-MOLTOX21",
+
+ "out_dir": "out/GatedGCN_MOLTOX21_LSPE_noLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-5,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 1,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 118,
+ "out_dim": 118,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "batch_norm": true,
+ "pos_enc_dim": 16,
+ "pe_init": "rand_walk",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-3,
+ "lambda_loss": 100
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_MOLTOX21_LapPE.json b/configs/GatedGCN_MOLTOX21_LapPE.json
new file mode 100644
index 0000000..4b06e81
--- /dev/null
+++ b/configs/GatedGCN_MOLTOX21_LapPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "OGBG-MOLTOX21",
+
+ "out_dir": "out/GatedGCN_MOLTOX21_LapPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-5,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 1,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 154,
+ "out_dim": 154,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "batch_norm": true,
+ "pos_enc_dim": 3,
+ "pe_init": "lap_pe",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-3,
+ "lambda_loss": 100
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_MOLTOX21_NoPE.json b/configs/GatedGCN_MOLTOX21_NoPE.json
new file mode 100644
index 0000000..e019c71
--- /dev/null
+++ b/configs/GatedGCN_MOLTOX21_NoPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "OGBG-MOLTOX21",
+
+ "out_dir": "out/GatedGCN_MOLTOX21_NoPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-5,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 1,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 154,
+ "out_dim": 154,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "batch_norm": true,
+ "pos_enc_dim": 3,
+ "pe_init": "no_pe",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-3,
+ "lambda_loss": 100
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_ZINC_LSPE.json b/configs/GatedGCN_ZINC_LSPE.json
new file mode 100644
index 0000000..4875bd6
--- /dev/null
+++ b/configs/GatedGCN_ZINC_LSPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "ZINC",
+
+ "out_dir": "out/GatedGCN_ZINC_LSPE_noLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 12
+ },
+
+ "net_params": {
+ "L": 16,
+ "hidden_dim": 59,
+ "out_dim": 59,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "batch_norm": true,
+ "pos_enc_dim": 20,
+ "pe_init": "rand_walk",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json b/configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json
new file mode 100644
index 0000000..c5f6122
--- /dev/null
+++ b/configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "ZINC",
+
+ "out_dir": "out/GatedGCN_ZINC_LSPE_withLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 12
+ },
+
+ "net_params": {
+ "L": 16,
+ "hidden_dim": 59,
+ "out_dim": 59,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "batch_norm": true,
+ "pos_enc_dim": 20,
+ "pe_init": "rand_walk",
+ "use_lapeig_loss": true,
+ "alpha_loss": 1,
+ "lambda_loss": 1e-1
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_ZINC_LapPE.json b/configs/GatedGCN_ZINC_LapPE.json
new file mode 100644
index 0000000..4f82e16
--- /dev/null
+++ b/configs/GatedGCN_ZINC_LapPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "ZINC",
+
+ "out_dir": "out/GatedGCN_ZINC_LapPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 12
+ },
+
+ "net_params": {
+ "L": 16,
+ "hidden_dim": 78,
+ "out_dim": 78,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "batch_norm": true,
+ "pe_init": "lap_pe",
+ "pos_enc_dim": 8,
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1
+ }
+}
\ No newline at end of file
diff --git a/configs/GatedGCN_ZINC_NoPE.json b/configs/GatedGCN_ZINC_NoPE.json
new file mode 100644
index 0000000..cdd8cfc
--- /dev/null
+++ b/configs/GatedGCN_ZINC_NoPE.json
@@ -0,0 +1,41 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GatedGCN",
+ "dataset": "ZINC",
+
+ "out_dir": "out/GatedGCN_ZINC_NoPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 12
+ },
+
+ "net_params": {
+ "L": 16,
+ "hidden_dim": 78,
+ "out_dim": 78,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "mean",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "batch_norm": true,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16,
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1
+ }
+}
\ No newline at end of file
diff --git a/configs/GraphiT_MOLTOX21_LSPE.json b/configs/GraphiT_MOLTOX21_LSPE.json
new file mode 100644
index 0000000..5fc8be5
--- /dev/null
+++ b/configs/GraphiT_MOLTOX21_LSPE.json
@@ -0,0 +1,47 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "GraphiT",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/GraphiT_MOLTOX21_LSPE_noLapEigLoss/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.0007,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "full_graph": true,
+
+ "L": 10,
+ "hidden_dim": 64,
+ "out_dim": 64,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.5,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16,
+ "adaptive_edge_PE": true,
+ "p_steps": 16,
+ "gamma": 0.25
+ }
+}
\ No newline at end of file
diff --git a/configs/GraphiT_MOLTOX21_NoPE.json b/configs/GraphiT_MOLTOX21_NoPE.json
new file mode 100644
index 0000000..4bb2c58
--- /dev/null
+++ b/configs/GraphiT_MOLTOX21_NoPE.json
@@ -0,0 +1,47 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "GraphiT",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/GraphiT_MOLTOX21_NoPE/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.0007,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "full_graph": true,
+
+ "L": 10,
+ "hidden_dim": 88,
+ "out_dim": 88,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.5,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 12,
+ "adaptive_edge_PE": true,
+ "p_steps": 16,
+ "gamma": 0.25
+ }
+}
\ No newline at end of file
diff --git a/configs/GraphiT_ZINC_LSPE.json b/configs/GraphiT_ZINC_LSPE.json
new file mode 100644
index 0000000..cb0235d
--- /dev/null
+++ b/configs/GraphiT_ZINC_LSPE.json
@@ -0,0 +1,49 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GraphiT",
+ "dataset": "ZINC",
+
+ "out_dir": "out/GraphiT_ZINC_LSPE_noLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 32,
+ "init_lr": 0.0007,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "full_graph": true,
+
+ "L": 10,
+ "hidden_dim": 48,
+ "out_dim": 48,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16,
+ "adaptive_edge_PE": true,
+ "p_steps": 16,
+ "gamma": 0.25
+ }
+}
\ No newline at end of file
diff --git a/configs/GraphiT_ZINC_NoPE.json b/configs/GraphiT_ZINC_NoPE.json
new file mode 100644
index 0000000..744ab5f
--- /dev/null
+++ b/configs/GraphiT_ZINC_NoPE.json
@@ -0,0 +1,49 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "GraphiT",
+ "dataset": "ZINC",
+
+ "out_dir": "out/GraphiT_ZINC_NoPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 32,
+ "init_lr": 0.0003,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "full_graph": true,
+
+ "L": 10,
+ "hidden_dim": 64,
+ "out_dim": 64,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16,
+ "adaptive_edge_PE": true,
+ "p_steps": 16,
+ "gamma": 0.25
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_MOLPCBA_LSPE.json b/configs/PNA_MOLPCBA_LSPE.json
new file mode 100644
index 0000000..0043783
--- /dev/null
+++ b/configs/PNA_MOLPCBA_LSPE.json
@@ -0,0 +1,47 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "PNA",
+ "dataset": "OGBG-MOLPCBA",
+ "out_dir":"out/PNA_MOLPCBA_LSPE_noLapEigLoss/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 512,
+ "init_lr": 0.0005,
+ "lr_reduce_factor": 0.8,
+ "lr_schedule_patience": 10,
+ "min_lr": 2e-5,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "L": 4,
+ "hidden_dim": 322,
+ "out_dim": 322,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "dropout_2": 0.1,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean sum max",
+ "scalers": "identity",
+ "gru": false,
+ "edge_dim": 16,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-2,
+ "lambda_loss": 1e-2,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_MOLPCBA_NoPE.json b/configs/PNA_MOLPCBA_NoPE.json
new file mode 100644
index 0000000..d58c8d2
--- /dev/null
+++ b/configs/PNA_MOLPCBA_NoPE.json
@@ -0,0 +1,47 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "PNA",
+ "dataset": "OGBG-MOLPCBA",
+ "out_dir":"out/PNA_MOLPCBA_NoPE/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 512,
+ "init_lr": 0.0005,
+ "lr_reduce_factor": 0.8,
+ "lr_schedule_patience": 4,
+ "min_lr": 2e-5,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "L": 4,
+ "hidden_dim": 510,
+ "out_dim": 510,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.2,
+ "dropout_2": 0.0,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean sum max",
+ "scalers": "identity",
+ "gru": false,
+ "edge_dim": 16,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-2,
+ "lambda_loss": 1e-2,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_MOLTOX21_LSPE.json b/configs/PNA_MOLTOX21_LSPE.json
new file mode 100644
index 0000000..8a7e9f9
--- /dev/null
+++ b/configs/PNA_MOLTOX21_LSPE.json
@@ -0,0 +1,48 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "PNA",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/PNA_MOLTOX21_LSPE_noLapEigLoss/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.0005,
+ "lr_reduce_factor": 0.8,
+ "lr_schedule_patience": 10,
+ "min_lr": 2e-5,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 140,
+ "out_dim": 140,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "dropout_2": 0.1,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean max min std",
+ "scalers": "identity amplification attenuation",
+ "gru": false,
+ "edge_dim": 50,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "lpe_variant": "native_lpe",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-2,
+ "lambda_loss": 1e-2,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json b/configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json
new file mode 100644
index 0000000..58ef016
--- /dev/null
+++ b/configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json
@@ -0,0 +1,48 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "PNA",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/PNA_MOLTOX21_LSPE_withLapEigLoss/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.0005,
+ "lr_reduce_factor": 0.8,
+ "lr_schedule_patience": 10,
+ "min_lr": 2e-5,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 140,
+ "out_dim": 140,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "dropout_2": 0.1,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean max min std",
+ "scalers": "identity amplification attenuation",
+ "gru": false,
+ "edge_dim": 50,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "lpe_variant": "native_lpe",
+ "use_lapeig_loss": true,
+ "alpha_loss": 1e-1,
+ "lambda_loss": 100,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_MOLTOX21_NoPE.json b/configs/PNA_MOLTOX21_NoPE.json
new file mode 100644
index 0000000..8619000
--- /dev/null
+++ b/configs/PNA_MOLTOX21_NoPE.json
@@ -0,0 +1,48 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "PNA",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/PNA_MOLTOX21_NoPE/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 256,
+ "init_lr": 0.0005,
+ "lr_reduce_factor": 0.8,
+ "lr_schedule_patience": 10,
+ "min_lr": 2e-5,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "L": 8,
+ "hidden_dim": 206,
+ "out_dim": 206,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.4,
+ "dropout_2": 0.1,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean max min std",
+ "scalers": "identity amplification attenuation",
+ "gru": false,
+ "edge_dim": 50,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "lpe_variant": "native_lpe",
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-2,
+ "lambda_loss": 1e-2,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_ZINC_LSPE.json b/configs/PNA_ZINC_LSPE.json
new file mode 100644
index 0000000..9a4a674
--- /dev/null
+++ b/configs/PNA_ZINC_LSPE.json
@@ -0,0 +1,51 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "PNA",
+ "dataset": "ZINC",
+
+ "out_dir": "out/PNA_ZINC_LSPE_noLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 16,
+ "hidden_dim": 55,
+ "out_dim": 55,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean max min std",
+ "scalers": "identity amplification attenuation",
+ "towers": 5,
+ "divide_input_first": true,
+ "divide_input_last": true,
+ "gru": false,
+ "edge_dim": 40,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/PNA_ZINC_NoPE.json b/configs/PNA_ZINC_NoPE.json
new file mode 100644
index 0000000..67ef1ab
--- /dev/null
+++ b/configs/PNA_ZINC_NoPE.json
@@ -0,0 +1,51 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "PNA",
+ "dataset": "ZINC",
+
+ "out_dir": "out/PNA_ZINC_NoPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.001,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 3e-6,
+ "print_epoch_interval": 5,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "L": 16,
+ "hidden_dim": 70,
+ "out_dim": 70,
+ "residual": true,
+ "edge_feat": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "graph_norm": true,
+ "batch_norm": true,
+ "aggregators": "mean max min std",
+ "scalers": "identity amplification attenuation",
+ "towers": 5,
+ "divide_input_first": true,
+ "divide_input_last": true,
+ "gru": false,
+ "edge_dim": 40,
+ "pretrans_layers" : 1,
+ "posttrans_layers" : 1,
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/SAN_MOLTOX21_LSPE.json b/configs/SAN_MOLTOX21_LSPE.json
new file mode 100644
index 0000000..0c15005
--- /dev/null
+++ b/configs/SAN_MOLTOX21_LSPE.json
@@ -0,0 +1,45 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "SAN",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/SAN_MOLTOX21_LSPE_noLapEigLoss/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.0007,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "full_graph": true,
+ "init_gamma": 0.1,
+
+ "L": 10,
+ "hidden_dim": 64,
+ "out_dim": 64,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.5,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 12
+ }
+}
\ No newline at end of file
diff --git a/configs/SAN_MOLTOX21_NoPE.json b/configs/SAN_MOLTOX21_NoPE.json
new file mode 100644
index 0000000..9922214
--- /dev/null
+++ b/configs/SAN_MOLTOX21_NoPE.json
@@ -0,0 +1,45 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+ "model": "SAN",
+ "dataset": "OGBG-MOLTOX21",
+ "out_dir":"out/SAN_MOLTOX21_NoPE/",
+
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 128,
+ "init_lr": 0.0007,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 96
+ },
+ "net_params": {
+ "full_graph": true,
+ "init_gamma": 0.1,
+
+ "L": 10,
+ "hidden_dim": 88,
+ "out_dim": 88,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.5,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/SAN_ZINC_LSPE.json b/configs/SAN_ZINC_LSPE.json
new file mode 100644
index 0000000..9e8dd59
--- /dev/null
+++ b/configs/SAN_ZINC_LSPE.json
@@ -0,0 +1,47 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "SAN",
+ "dataset": "ZINC",
+
+ "out_dir": "out/SAN_ZINC_LSPE_noLapEigLoss/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 32,
+ "init_lr": 0.0007,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "full_graph": true,
+ "init_gamma": 0.1,
+
+ "L": 10,
+ "hidden_dim": 48,
+ "out_dim": 48,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "rand_walk",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/configs/SAN_ZINC_NoPE.json b/configs/SAN_ZINC_NoPE.json
new file mode 100644
index 0000000..2dd8619
--- /dev/null
+++ b/configs/SAN_ZINC_NoPE.json
@@ -0,0 +1,47 @@
+{
+ "gpu": {
+ "use": true,
+ "id": 0
+ },
+
+ "model": "SAN",
+ "dataset": "ZINC",
+
+ "out_dir": "out/SAN_ZINC_NoPE/",
+
+ "params": {
+ "seed": 41,
+ "epochs": 1000,
+ "batch_size": 32,
+ "init_lr": 0.0003,
+ "lr_reduce_factor": 0.5,
+ "lr_schedule_patience": 25,
+ "min_lr": 1e-6,
+ "weight_decay": 0.0,
+ "print_epoch_interval": 5,
+ "max_time": 48
+ },
+
+ "net_params": {
+ "full_graph": true,
+ "init_gamma": 0.1,
+
+ "L": 10,
+ "hidden_dim": 64,
+ "out_dim": 64,
+ "n_heads": 8,
+
+ "residual": true,
+ "readout": "sum",
+ "in_feat_dropout": 0.0,
+ "dropout": 0.0,
+ "layer_norm": false,
+ "batch_norm": true,
+
+ "use_lapeig_loss": false,
+ "alpha_loss": 1e-4,
+ "lambda_loss": 1000,
+ "pe_init": "no_pe",
+ "pos_enc_dim": 16
+ }
+}
\ No newline at end of file
diff --git a/data/data.py b/data/data.py
new file mode 100644
index 0000000..3871142
--- /dev/null
+++ b/data/data.py
@@ -0,0 +1,21 @@
+"""
+    File to load the dataset selected by the user in the main file
+"""
+
+from data.molecules import MoleculeDataset
+from data.ogb_mol import OGBMOLDataset
+
+def LoadData(DATASET_NAME):
+ """
+        This function is called from the main.py file.
+        returns:
+            the dataset object for the requested DATASET_NAME
+ """
+
+ # handling for (ZINC) molecule dataset
+ if DATASET_NAME == 'ZINC' or DATASET_NAME == 'ZINC-full':
+ return MoleculeDataset(DATASET_NAME)
+
+    # handling for the MOLPCBA and MOLTOX21 datasets
+ if DATASET_NAME in ['OGBG-MOLPCBA', 'OGBG-MOLTOX21']:
+ return OGBMOLDataset(DATASET_NAME)
\ No newline at end of file
diff --git a/data/molecules.py b/data/molecules.py
new file mode 100644
index 0000000..a80eb40
--- /dev/null
+++ b/data/molecules.py
@@ -0,0 +1,331 @@
+import torch
+import torch.nn.functional as F
+import pickle
+import torch.utils.data
+import time
+import os
+import numpy as np
+
+import csv
+
+import dgl
+
+from scipy import sparse as sp
+import networkx as nx
+
+# The dataset pickle and index files are in ./data/molecules/ dir
+# [.pickle and .index; for split 'train', 'val' and 'test']
+
+
+
+
+class MoleculeDGL(torch.utils.data.Dataset):
+ def __init__(self, data_dir, split, num_graphs=None):
+ self.data_dir = data_dir
+ self.split = split
+ self.num_graphs = num_graphs
+
+ with open(data_dir + "/%s.pickle" % self.split,"rb") as f:
+ self.data = pickle.load(f)
+
+ if self.num_graphs in [10000, 1000]:
+ # loading the sampled indices from file ./zinc_molecules/.index
+ with open(data_dir + "/%s.index" % self.split,"r") as f:
+ data_idx = [list(map(int, idx)) for idx in csv.reader(f)]
+ self.data = [ self.data[i] for i in data_idx[0] ]
+
+ assert len(self.data)==num_graphs, "Sample num_graphs again; available idx: train/val/test => 10k/1k/1k"
+
+ """
+ data is a list of Molecule dict objects with following attributes
+
+ molecule = data[idx]
+ ; molecule['num_atom'] : nb of atoms, an integer (N)
+ ; molecule['atom_type'] : tensor of size N, each element is an atom type, an integer between 0 and num_atom_type
+ ; molecule['bond_type'] : tensor of size N x N, each element is a bond type, an integer between 0 and num_bond_type
+ ; molecule['logP_SA_cycle_normalized'] : the chemical property to regress, a float variable
+ """
+
+ self.graph_lists = []
+ self.graph_labels = []
+ self.n_samples = len(self.data)
+ self._prepare()
+
+ def _prepare(self):
+ print("preparing %d graphs for the %s set..." % (self.num_graphs, self.split.upper()))
+
+ for molecule in self.data:
+ node_features = molecule['atom_type'].long()
+
+ adj = molecule['bond_type']
+ edge_list = (adj != 0).nonzero() # converting adj matrix to edge_list
+
+ edge_idxs_in_adj = edge_list.split(1, dim=1)
+ edge_features = adj[edge_idxs_in_adj].reshape(-1).long()
+
+ # Create the DGL Graph
+ g = dgl.DGLGraph()
+ g.add_nodes(molecule['num_atom'])
+ g.ndata['feat'] = node_features
+
+ for src, dst in edge_list:
+ g.add_edges(src.item(), dst.item())
+ g.edata['feat'] = edge_features
+
+ self.graph_lists.append(g)
+ self.graph_labels.append(molecule['logP_SA_cycle_normalized'])
+
+ def __len__(self):
+ """Return the number of graphs in the dataset."""
+ return self.n_samples
+
+ def __getitem__(self, idx):
+ """
+ Get the idx^th sample.
+ Parameters
+ ---------
+ idx : int
+ The sample index.
+ Returns
+ -------
+ (dgl.DGLGraph, int)
+ DGLGraph with node feature stored in `feat` field
+ And its label.
+ """
+ return self.graph_lists[idx], self.graph_labels[idx]
+
+
+class MoleculeDatasetDGL(torch.utils.data.Dataset):
+ def __init__(self, name='Zinc'):
+ t0 = time.time()
+ self.name = name
+
+ self.num_atom_type = 28 # known meta-info about the zinc dataset; can be calculated as well
+ self.num_bond_type = 4 # known meta-info about the zinc dataset; can be calculated as well
+
+ data_dir='./data/molecules'
+
+ if self.name == 'ZINC-full':
+ data_dir='./data/molecules/zinc_full'
+ self.train = MoleculeDGL(data_dir, 'train', num_graphs=220011)
+ self.val = MoleculeDGL(data_dir, 'val', num_graphs=24445)
+ self.test = MoleculeDGL(data_dir, 'test', num_graphs=5000)
+ else:
+ self.train = MoleculeDGL(data_dir, 'train', num_graphs=10000)
+ self.val = MoleculeDGL(data_dir, 'val', num_graphs=1000)
+ self.test = MoleculeDGL(data_dir, 'test', num_graphs=1000)
+ print("Time taken: {:.4f}s".format(time.time()-t0))
+
+
+
+def add_eig_vec(g, pos_enc_dim):
+ """
+        Graph positional encoding via Laplacian eigenvectors.
+        This function is used for eigenvector visualization; it runs the same code as
+        lap_positional_encoding(), but stores the values under a different key, 'eigvec'.
+ """
+
+ # Laplacian
+ A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ L = sp.eye(g.number_of_nodes()) - N * A * N
+
+ # Eigenvectors with numpy
+ EigVal, EigVec = np.linalg.eig(L.toarray())
+ idx = EigVal.argsort() # increasing order
+ EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx])
+ g.ndata['eigvec'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float()
+
+    # zero padding at the end if n <= pos_enc_dim (only n-1 non-trivial eigenvectors are available)
+ n = g.number_of_nodes()
+ if n <= pos_enc_dim:
+ g.ndata['eigvec'] = F.pad(g.ndata['eigvec'], (0, pos_enc_dim - n + 1), value=float('0'))
+
+ return g
+
+
+def lap_positional_encoding(g, pos_enc_dim):
+ """
+        Graph positional encoding via Laplacian eigenvectors
+ """
+
+ # Laplacian
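+    # (symmetric normalised Laplacian: L = I - D^{-1/2} A D^{-1/2}, built below from the
+    #  adjacency matrix A and the inverse square-root degree matrix N = D^{-1/2})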
+ A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ L = sp.eye(g.number_of_nodes()) - N * A * N
+
+ # Eigenvectors with numpy
+ EigVal, EigVec = np.linalg.eig(L.toarray())
+ idx = EigVal.argsort() # increasing order
+ EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx])
+ g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float()
+ g.ndata['eigvec'] = g.ndata['pos_enc']
+
+ # # Eigenvectors with scipy
+ # EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR')
+ # EigVec = EigVec[:, EigVal.argsort()] # increasing order
+ # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float()
+
+ return g
+
+
+def init_positional_encoding(g, pos_enc_dim, type_init):
+ """
+ Initializing positional encoding with RWPE
+ """
+
+ n = g.number_of_nodes()
+
+ if type_init == 'rand_walk':
+ # Geometric diffusion features with Random Walk
+ A = g.adjacency_matrix(scipy_fmt="csr")
+ Dinv = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -1.0, dtype=float) # D^-1
+ RW = A * Dinv
+ M = RW
+
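+        # RWPE: the loop below collects the k-step return probabilities of a random walk,
+        # PE(i) = [RW_ii, (RW^2)_ii, ..., (RW^pos_enc_dim)_ii], with RW = A D^-1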
+ # Iterate
+ nb_pos_enc = pos_enc_dim
+ PE = [torch.from_numpy(M.diagonal()).float()]
+ M_power = M
+ for _ in range(nb_pos_enc-1):
+ M_power = M_power * M
+ PE.append(torch.from_numpy(M_power.diagonal()).float())
+ PE = torch.stack(PE,dim=-1)
+ g.ndata['pos_enc'] = PE
+
+ return g
+
+
+def make_full_graph(g, adaptive_weighting=None):
+
+ full_g = dgl.from_networkx(nx.complete_graph(g.number_of_nodes()))
+
+    # Here we copy over the node feature data and the Laplacian encodings
+ full_g.ndata['feat'] = g.ndata['feat']
+
+ try:
+ full_g.ndata['pos_enc'] = g.ndata['pos_enc']
+ except:
+ pass
+
+ try:
+ full_g.ndata['eigvec'] = g.ndata['eigvec']
+ except:
+ pass
+
+    # Populate edge features with zeros
+ full_g.edata['feat']=torch.zeros(full_g.number_of_edges(), dtype=torch.long)
+ full_g.edata['real']=torch.zeros(full_g.number_of_edges(), dtype=torch.long)
+
+    # Copy the real edge data over
+ full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['feat'] = g.edata['feat']
+ full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['real'] = torch.ones(g.edata['feat'].shape[0], dtype=torch.long)
+
+
+    # This code section only applies to GraphiT --------------------------------------------
+ if adaptive_weighting is not None:
+ p_steps, gamma = adaptive_weighting
+
+ n = g.number_of_nodes()
+ A = g.adjacency_matrix(scipy_fmt="csr")
+
+ # Adaptive weighting k_ij for each edge
+ if p_steps == "qtr_num_nodes":
+ p_steps = int(0.25*n)
+ elif p_steps == "half_num_nodes":
+ p_steps = int(0.5*n)
+ elif p_steps == "num_nodes":
+ p_steps = int(n)
+ elif p_steps == "twice_num_nodes":
+ p_steps = int(2*n)
+
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ I = sp.eye(n)
+ L = I - N * A * N
+
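+        # GraphiT-style diffusion kernel (computed below by repeated dot products):
+        # k_RW = (I - gamma * L)^p_steps, one scalar weight per (i, j) node pair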
+ k_RW = I - gamma*L
+ k_RW_power = k_RW
+ for _ in range(p_steps - 1):
+ k_RW_power = k_RW_power.dot(k_RW)
+
+ k_RW_power = torch.from_numpy(k_RW_power.toarray())
+
+ # Assigning edge features k_RW_eij for adaptive weighting during attention
+ full_edge_u, full_edge_v = full_g.edges()
+ num_edges = full_g.number_of_edges()
+
+ k_RW_e_ij = []
+ for edge in range(num_edges):
+ k_RW_e_ij.append(k_RW_power[full_edge_u[edge], full_edge_v[edge]])
+
+ full_g.edata['k_RW'] = torch.stack(k_RW_e_ij,dim=-1).unsqueeze(-1).float()
+ # --------------------------------------------------------------------------------------
+
+ return full_g
+
+
+class MoleculeDataset(torch.utils.data.Dataset):
+
+ def __init__(self, name):
+ """
+ Loading ZINC datasets
+ """
+ start = time.time()
+ print("[I] Loading dataset %s..." % (name))
+ self.name = name
+ data_dir = 'data/molecules/'
+        with open(data_dir+name+'.pkl',"rb") as f:
+            data = pickle.load(f)
+            self.train = data[0]
+            self.val = data[1]
+            self.test = data[2]
+            self.num_atom_type = data[3]
+            self.num_bond_type = data[4]
+ print('train, test, val sizes :',len(self.train),len(self.test),len(self.val))
+ print("[I] Finished loading.")
+ print("[I] Data load time: {:.4f}s".format(time.time()-start))
+
+
+ # form a mini batch from a given list of samples = [(graph, label) pairs]
+ def collate(self, samples):
+ # The input samples is a list of pairs (graph, label).
+ graphs, labels = map(list, zip(*samples))
+ labels = torch.tensor(np.array(labels)).unsqueeze(1)
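+        # per-node graph-size normalisation (used by some layers):
+        # snorm_n[v] = 1 / sqrt(|V_g|) for every node v of graph g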
+ tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))]
+ tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ]
+ snorm_n = torch.cat(tab_snorm_n).sqrt()
+ batched_graph = dgl.batch(graphs)
+
+ return batched_graph, labels, snorm_n
+
+
+ def _add_lap_positional_encodings(self, pos_enc_dim):
+
+        # Graph positional encoding via Laplacian eigenvectors
+ self.train.graph_lists = [lap_positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists]
+ self.val.graph_lists = [lap_positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists]
+ self.test.graph_lists = [lap_positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists]
+
+ def _add_eig_vecs(self, pos_enc_dim):
+
+ # This is used if we visualize the eigvecs
+ self.train.graph_lists = [add_eig_vec(g, pos_enc_dim) for g in self.train.graph_lists]
+ self.val.graph_lists = [add_eig_vec(g, pos_enc_dim) for g in self.val.graph_lists]
+ self.test.graph_lists = [add_eig_vec(g, pos_enc_dim) for g in self.test.graph_lists]
+
+ def _init_positional_encodings(self, pos_enc_dim, type_init):
+
+        # Initializing positional encodings according to type_init (e.g. RWPE for 'rand_walk')
+ self.train.graph_lists = [init_positional_encoding(g, pos_enc_dim, type_init) for g in self.train.graph_lists]
+ self.val.graph_lists = [init_positional_encoding(g, pos_enc_dim, type_init) for g in self.val.graph_lists]
+ self.test.graph_lists = [init_positional_encoding(g, pos_enc_dim, type_init) for g in self.test.graph_lists]
+
+ def _make_full_graph(self, adaptive_weighting=None):
+ self.train.graph_lists = [make_full_graph(g, adaptive_weighting) for g in self.train.graph_lists]
+ self.val.graph_lists = [make_full_graph(g, adaptive_weighting) for g in self.val.graph_lists]
+ self.test.graph_lists = [make_full_graph(g, adaptive_weighting) for g in self.test.graph_lists]
+
+
+
+
diff --git a/data/molecules/test.index b/data/molecules/test.index
new file mode 100644
index 0000000..428bc26
--- /dev/null
+++ b/data/molecules/test.index
@@ -0,0 +1 @@
+912,204,2253,2006,1828,1143,839,4467,712,4837,3456,260,244,767,1791,1905,4139,4931,217,4597,1628,4464,3436,1805,3679,4827,2278,53,1307,3462,2787,2276,1273,1763,2757,837,759,3112,792,2940,2817,4945,2166,355,3763,4392,1022,3100,645,4522,2401,2962,4729,1575,569,375,1866,2370,653,1907,827,3113,2277,3714,2988,1332,3032,2910,1716,2187,584,4990,1401,4375,2005,1338,3786,3108,2211,4562,1799,2656,458,1876,262,2584,3286,2193,542,1728,4646,2577,1741,4089,3241,3758,1170,2169,2020,4598,4415,2152,4788,3509,4780,3271,2965,1796,1133,4174,4042,744,385,898,1252,1310,3458,4885,520,3152,3126,4881,3834,4334,2059,4532,94,938,4398,2185,2786,913,2404,3561,1295,3716,26,2157,4100,1463,4158,871,2444,4988,1629,3063,1323,4418,4344,4,4906,2655,4002,159,916,2973,2519,1961,474,1973,4647,701,3981,566,4363,1030,1051,3893,4503,1352,2171,4322,4969,3466,1735,4417,1647,2553,3268,3059,3588,4239,3698,991,2030,1840,524,2769,172,4819,4537,1885,4820,1804,58,581,482,1875,552,257,2706,580,4211,1949,2281,3976,1755,1083,4677,4720,3872,1990,3874,3334,1559,772,794,3531,2902,3469,3367,3825,443,806,496,3298,2779,895,2036,1569,1558,4393,3675,1148,1503,3789,2046,617,3630,4508,802,414,4428,120,764,1936,1362,3329,3978,3943,1751,3285,480,1348,3104,17,3198,2172,3727,2336,3465,4552,3986,1268,1555,2430,1783,479,4744,4441,499,2569,468,410,4785,3905,4119,4350,1289,465,4160,656,1522,561,4874,556,1926,3307,982,4666,2016,4742,4870,325,671,3434,4781,4630,4282,2591,2136,1673,2573,1955,2175,3242,1072,2457,3745,2590,594,76,3754,4612,819,600,4404,1746,4144,1085,2859,563,2001,3027,2334,1292,3589,4450,2478,4333,64,4543,2452,848,1100,945,876,2231,2308,4954,1725,2808,1667,2162,4140,2057,416,756,2266,361,29,2732,1071,2145,3619,4519,3503,4594,79,616,1221,4469,295,3024,4771,4526,1213,3520,1044,342,2525,2987,326,2931,1720,2044,842,2897,4586,1266,1939,1331,1450,3377,203,1469,2721,3372,2032,1304,885,3133,317,3855,1822,1634,3770,2864,2500,1864,1826,193,1582,3264,2689,2282,568,2286,2876,4173,3274,2712,226,944,2139,1462,4756,2174,313,888,4887,3559,2831,3574,4966,4189,947,3155,4723,1557,2086,363,3572,13,4259,4410,1614,2983,3533,573,2704,2571,1020,2460,4154,2533,3345,2672,3296,2422,4541,1042,1571,3444,3105,1425,4662,2465,3326,4488,3,2489,2350,1721,3521,4751,2639,3809,3622,1750,4187,3876,1390,694,2324,4222,2745,765,1924,2542,1631,1207,200,378,3892,596,3730,3395,4715,1592,3145,4049,3273,1998,1208,45,873,3482,1792,1440,4243,3805,411,4566,2041,994,3739,1092,3806,4351,4578,4877,2599,3625,4135,3495,3652,1303,3888,3686,2123,2025,2271,4270,3969,1959,2249,3603,634,2340,1920,2225,2751,2619,4424,660,1235,1894,3137,1251,1752,526,3398,3339,2710,4445,3816,3406,510,1694,3441,3190,4784,160,4716,3116,3907,48,2881,2446,3194,3432,4409,4473,4941,1806,3999,1797,2235,3570,237,3185,2753,3312,3828,1045,220,3227,4848,4623,222,687,3511,1111,3782,1488,2131,2681,1733,3724,2677,2764,2279,3453,2066,670,3852,158,426,2866,1836,562,329,254,1633,166,1248,1954,1034,3879,937,4620,1785,2099,3021,1374,4963,4974,1341,2548,4740,210,2555,3074,3249,1624,622,4850,1989,834,2470,4918,4636,336,2844,4364,3035,564,2795,103,4015,864,3551,2967,3766,1253,3567,1442,4274,2212,4408,3960,3808,3568,4853,2198,2640,2011,709,2284,3692,1997,4668,4999,2755,235,2662,1489,3994,1737,2906,2116,2788,2290,4883,2263,4553,83,4232,1565,1977,4548,1968,3900,4020,3671,141,762,2410,1815,1993,2508,4764,3023,4534,4349,2816,3485,2709,2882,3717,2219,2511,1888,988,1577,979,4389,1516,1772,3966,2265,4829,4297,4888,2318,823,1590,2426,1863,2956,2476,115,1036,2247,372,446,4533,2393,4021,840,100,4702,2329,3845,3921,3608,2791,1510,420,2068,3913,934,535,3282
,4028,606,4726,439,1242,1222,4610,697,2033,970,4571,3409,1848,4280,3690,3626,2435,4821,3512,2501,4658,493,4994,812,1702,2167,665,1286,1964,1423,4521,614,1282,21,3346,4864,3849,2385,267,1896,2360,2316,3719,583,1912,4831,1620,940,4461,1841,1220,2176,1165,488,1359,2364,3597,1018,3839,2491,3297,2230,4099,4423,4045,3586,658,4899,3539,2051,211,748,4712,4809,169,2207,1435,3854,4251,1486,4795,747,3850,2850,2730,2630,856,1317,2701,4058,2361,3280,4506,300,3725,721,2576,2067,2648,949,3311,4215,9,4444,3784,3385,444,1536,4247,2963,4083,3621,422,4499,1073,2359,3970,236,4987,1960,1297,2545,4512,112,4524,3342,763,929,3780,962,1261,4082,2390,4168,2239,3403,3952,3868,1996,3741,4515,1184,3142,1561,4910,4163,1118,571,3399,2784,4159,2188,2317,2445,4808,4750,4011,1217,3658,4412,3967,2827,2723,4451,3090,2636,1545,1956,4684,1913,3365,357,2606,3123,3162,1246,4057,303,4114,4834,2719,822,3606,816,4308,3743,125,1180,3358,1264,612,3846,2773,3256,657,2691,4371,2594,3997,4432,294,560,1923,2354,740,3555,3634,2453,376,2657,459,2403,2936,3070,3528,1192,2000,3375,1475,1392,1434,646,4992,1972,4076,4777,1172,1902,3777,2080,3764,2091,3811,2356,4477,1294,605,3618,2830,4813,2450,3475,2048,3742,2474,3151,3958,1943,3124,4685,4708,2423,2418,179,2248
diff --git a/data/molecules/train.index b/data/molecules/train.index
new file mode 100644
index 0000000..c4b804d
--- /dev/null
+++ b/data/molecules/train.index
@@ -0,0 +1 @@
+167621,29184,6556,194393,72097,64196,58513,36579,193061,26868,177392,194161,142964,22790,154794,110604,8331,7811,24561,57314,60990,132475,157815,6956,147127,52124,187700,170363,183848,142853,109974,57787,117757,154472,72926,212187,1703,198916,211240,41853,183013,110785,89194,72842,40758,56443,200145,88236,26793,24312,99595,25353,94104,90165,158263,69342,211583,11390,191294,120435,140568,32722,99230,20656,144714,76854,217423,164794,162141,94800,151349,50407,184699,18233,12012,173346,59742,202655,75861,20916,61024,26476,99647,72869,118858,166640,218657,95638,42638,97040,93132,54921,175682,69986,183977,179187,169878,18717,159680,166455,44862,140021,191136,64175,42834,121178,99471,70765,167772,180397,146001,57570,179467,85008,201408,203423,14663,60043,215430,8414,211037,82694,105162,70186,17350,55307,148682,188196,82490,55738,171819,130870,103712,168519,120285,37452,69436,36603,64651,195294,147159,141289,68876,195825,153245,112311,152969,104700,94895,57493,36262,133569,129372,23831,198123,12351,28743,40066,164481,41938,207638,178384,110666,156345,16653,100864,100039,156208,122696,138704,65906,145024,3009,178332,188932,30029,178706,140763,196838,69946,201483,168024,89174,29242,76939,113971,41460,118940,850,189292,188659,69045,131225,199743,46832,133085,27894,163918,78235,167496,133080,159637,52143,40065,98019,199887,42349,141394,204112,139029,149,157009,84975,128085,5105,29325,95153,218016,211491,80612,62770,15184,63143,148729,20645,22453,191865,127399,213915,18143,199387,139645,200758,32967,33657,172949,124592,144127,43287,69483,138326,159014,110923,55521,141373,197988,191347,180844,52730,186895,81714,104593,176078,170361,97889,114845,135679,118354,31720,64986,58903,16784,88627,5514,154221,145207,60323,154256,57728,1885,18610,185556,165439,15433,60015,17668,8234,86619,18574,134782,62391,73001,175368,127248,56160,141356,34684,189622,149695,151050,123907,63700,205683,123987,211680,106708,49914,24726,25409,172748,112997,92876,111038,107767,122427,191122,14200,176518,171299,169392,25799,15889,105544,190896,88946,209870,28644,65183,50224,49862,140584,117601,36747,110593,48100,73018,121275,65485,19761,116164,211818,144264,25666,13261,170955,141711,219159,3868,24448,197542,61965,43597,106539,127307,126185,56032,105130,15370,43158,99345,565,102346,69521,205539,205816,119277,74776,110888,182607,191497,205353,145691,173505,188326,127577,40579,49780,77780,57068,15331,151828,192869,142133,15979,196077,82209,14985,13144,153138,124987,131819,139231,41270,14910,133124,21000,48712,17962,155984,17815,177002,61657,105847,31427,149336,64543,151760,155849,10418,162367,21491,109897,172326,153006,148170,137044,82934,68358,53545,175564,187745,82361,62570,69629,103752,34308,176079,169214,78642,119858,82883,197096,19016,2441,120136,162833,147585,26209,19204,140937,55877,132614,69520,34722,91490,18033,64037,96869,74707,41352,114867,218561,142401,184428,79303,160347,211576,171435,138658,2050,175076,214198,145385,78480,173903,27154,35203,69328,30258,28058,194620,40749,71394,73860,158552,55215,188117,89884,53371,180223,166261,69201,132489,128065,65829,13316,24195,166273,111037,217408,72530,11557,929,87439,202144,34293,167015,68670,42357,194309,115824,144619,184986,112115,147038,2534,29327,19724,181146,39073,143023,9444,218784,96787,152701,144841,38821,112666,33409,10965,80808,95591,208698,10458,93797,55070,178799,65412,174832,26946,92714,204502,146770,106529,162702,196471,40515,62059,42598,209685,212538,46413,108080,6497,47018,193085,87080,205097,107928,210301,175612,192690,212533,65055,69941,41732,206405,183835,28336,100
281,10151,123388,58309,52316,214063,120665,91660,80003,215098,208495,59662,58438,6203,173023,50627,104455,86051,73034,18198,202722,73170,92050,168160,133537,104773,178131,140565,86808,7235,30236,68475,46810,152198,69590,10028,28417,156387,113918,90619,190983,206157,82228,114398,158914,134066,30315,100976,151149,49828,66773,11635,185803,114309,443,136293,211421,141151,180055,188594,194497,193210,175801,51651,95478,113062,18343,174125,86559,163356,82291,32669,188679,78727,132939,81077,174821,107057,85506,105486,182769,77504,145335,33367,50289,110217,174307,99390,177554,196118,45620,161353,149187,78892,106450,143638,218556,106,79658,75212,55098,112692,205981,152039,159032,171627,84475,121893,115811,115909,177111,56020,134001,124042,208072,208673,192927,44483,172713,22228,74392,135122,174025,165921,162335,87867,24480,214544,196906,61569,176369,81374,58888,211435,52200,38627,6402,12114,64184,124554,160241,201454,19091,119384,108643,165089,150908,50970,188310,182545,100657,129598,104766,63959,38684,171981,180256,1453,196860,201862,27941,204058,111449,57367,46107,210792,182429,135779,121778,13164,146120,65325,31813,119658,34954,210086,121803,175001,139233,146518,156094,83177,197984,116017,160603,213649,188553,132324,111867,217728,143621,116893,41722,194944,124433,117983,67945,197073,64812,167159,72695,200753,203862,136655,127034,164298,62716,71984,115309,20311,187051,74901,61471,71228,88040,83809,141597,21122,36273,39539,60623,100410,181914,40056,185183,56086,16837,108755,106849,86738,142242,122139,108992,16322,54220,218337,110138,102099,201797,153112,182327,5120,200696,150913,99714,125037,1545,92211,78279,197518,102232,219081,109843,141091,195956,192579,143165,209680,158139,57812,127986,57520,71548,114251,127308,7608,101936,88114,175341,178032,209228,105989,189839,43265,122523,33456,163120,140017,7069,103290,155161,147951,173801,7104,22006,168492,112358,35572,121031,47639,13181,68198,99379,85813,55485,119196,85680,88473,199551,99385,72943,197134,218083,110510,66131,218866,21471,123288,5081,196352,141405,13653,91740,58778,170431,17987,204795,170853,10553,197717,8134,64823,52261,219998,5342,162879,39946,62533,33088,124141,175494,29986,147841,57138,121905,183360,67172,201037,96703,43984,158831,159186,196064,188313,30024,203893,214774,42930,81536,28337,151698,6731,81777,150941,177562,98384,103980,187436,51990,19922,155214,181040,217729,164427,63661,26712,182764,202501,79058,179374,157394,211164,31733,208726,148355,205163,10766,91017,139655,112296,173413,97142,18076,132634,169749,89451,3316,110115,215569,128503,27666,113645,94945,166614,217240,120517,185416,40105,114160,46173,192360,136772,170513,70800,161457,211864,141078,203067,126745,121863,114183,216461,191634,155313,70358,84490,64355,217771,22718,73119,118175,63927,196732,121820,149382,159994,175161,99349,88184,7523,129579,85200,47668,127808,55605,93015,209146,67725,89217,73310,156278,183811,72422,145697,2661,135430,50084,22442,63270,188763,106542,128077,145536,198748,62999,181039,124804,169319,186605,128665,117486,207862,4520,24393,77132,58090,106011,181347,63780,80270,174053,152451,96736,124062,145089,139177,90113,111543,195542,144281,86714,92225,184249,118946,71019,80379,65903,60434,31629,189080,50484,82719,31340,194741,140473,199803,180922,48535,50211,56723,193620,126929,72482,189945,154556,199283,137530,156445,74186,26352,218268,50886,77659,59633,94602,47039,79237,3708,185602,140020,33182,71909,11931,14293,145059,76581,182823,33103,167213,197339,128680,26892,3215,150487,74537,123049,125492,115466,89314,48329,13468,66185,125233,29907,215512,1
7129,105043,128908,19420,151262,165005,179949,14053,39774,39111,212636,147545,79648,22329,65061,31051,146295,200394,109097,158942,156271,207287,162114,59162,203343,136989,99713,118099,116056,77949,154278,112415,80053,149062,162798,15789,159811,194009,26013,199946,54470,163975,55318,69375,173127,21282,41171,62879,45564,144701,19677,41034,701,107090,118096,180714,155664,123184,76351,8556,60680,75524,185324,74112,184283,119021,18659,180193,61190,69351,206524,207420,163855,154609,173325,210739,51857,111447,30087,142753,58932,169773,39056,69633,216696,37286,18719,15633,43495,207782,80638,155987,196334,216069,149214,75657,115121,32598,122866,180532,79715,183430,105515,71367,131195,141552,129445,114755,21087,156771,10449,113253,192553,84494,158259,65632,6780,23940,60011,176763,219141,150784,153910,5438,200477,176234,215330,70650,151059,10546,200044,198251,45925,123338,136043,170789,115927,72917,47576,153440,114264,166405,213449,128902,23918,123210,91215,107046,87374,84163,175671,27420,42159,86456,107910,181842,129884,75554,173692,104976,213271,199360,144204,9619,119229,23084,82448,66164,84744,30388,202518,105952,134898,216249,301,172401,142237,121104,108330,14209,49173,135905,94838,163219,198297,130676,163947,115878,199226,13529,53361,70007,143974,34344,75505,114849,183042,127064,31831,7567,165156,159612,209535,62730,186065,41517,81461,144399,3584,144769,106952,24434,29743,120965,30793,169793,218141,40362,130646,187853,76509,133397,184919,71676,108918,218817,126484,123789,63892,119743,144510,37917,100554,49967,157123,133232,195636,35784,18300,72416,202552,207095,108774,89096,206495,133100,70037,215102,672,74144,190329,78264,219539,153862,152023,172999,128356,38953,117059,141185,126967,90472,87147,144681,199982,142456,98883,119365,84352,49454,182845,62601,149893,100393,61226,203304,107682,11441,83408,195219,184871,212705,99937,101208,173982,207721,215154,170922,39873,129847,9704,33094,131672,154712,87030,26309,115423,26138,137874,119780,4023,189384,37789,107473,171646,40464,19615,123074,204872,69474,88751,163383,181588,104197,170350,21054,86131,176778,139885,99617,83010,164311,188407,199072,127912,141824,9410,161863,17936,61543,165455,179411,75334,59634,195760,23692,113763,25806,199326,166133,184520,26347,116307,43611,181928,78503,7588,12056,85032,208704,14711,76904,93969,98262,112901,38160,64014,139242,108016,148354,178729,207754,47200,44560,45893,20701,159774,162453,179079,63132,130460,152871,37517,60866,120887,167222,66578,120473,66932,174808,2463,210929,121963,75400,177631,143287,41412,19362,115796,90587,154028,78421,167493,111230,180961,65561,119756,79199,52223,100845,126670,27958,62182,99972,149926,94098,150683,77561,183311,77392,5751,217589,172550,103758,71953,2122,179778,204029,195210,12856,158965,195342,130214,218323,75024,203355,209416,60324,159138,210221,92358,57410,166885,49841,162762,65700,177671,198070,188987,201187,172800,178482,219765,35855,164691,25470,164484,169434,10334,80984,206729,115559,8746,151931,95646,191983,34448,23627,77361,85647,195947,108921,46042,52637,34644,206179,141402,95907,139159,131542,71441,217703,43134,67363,216187,126313,211416,77369,195706,88792,210824,30191,122770,19739,36898,197696,59143,177294,189849,176791,104179,210917,146096,95885,23672,207272,103435,3648,69326,140659,32398,119219,96622,176377,196348,176250,68724,153238,99883,215591,167391,97379,28402,176904,61295,123594,6560,162410,147169,85984,159929,58030,169778,16571,166567,215968,121704,183783,79217,170194,107032,30585,36641,11887,9754,79787,129138,30440,25478,61551,140914,35563,101881,118919,97258
,175763,194809,182579,141608,109871,153967,194581,190473,40507,108759,171683,25957,218547,128279,161389,106985,73332,8576,180952,97132,56955,116230,116574,61895,95075,26045,179746,96294,142728,169037,94024,15872,104369,72320,49757,32023,216056,119201,24031,173740,55602,168218,167639,156538,5598,13258,206253,87427,63850,33010,206301,148000,53793,17985,217531,200581,145282,54305,153716,56610,213136,61079,86129,202996,38680,206738,156233,743,72685,37929,34076,141614,65707,209325,45743,28816,173292,6758,34551,3895,93911,207089,206682,62372,154364,84874,4136,45677,69564,13736,33228,194436,110352,137910,29792,195471,16661,124845,117512,203952,94906,134542,155625,28587,118490,132078,58077,161246,11367,190642,205318,172727,136695,79070,120073,168647,8165,15945,125562,105281,111745,179856,28301,128521,186751,116277,19265,21178,84439,159461,38884,17218,33080,72093,163661,165957,153449,143748,186686,85245,99851,156602,139082,77305,118938,132527,158709,112774,25999,207905,183967,29992,171629,170633,201578,144529,188963,56367,112740,118372,59898,108478,88848,216902,118882,104524,109049,191263,24926,81932,111873,81926,174354,66817,98120,40013,180046,124326,17598,23914,218044,22377,24439,113213,25313,195189,193670,97687,212800,34108,145849,15723,153738,147216,147241,86414,175639,32042,107691,92693,174415,196682,110870,189021,13486,157393,81906,92181,27158,151497,133015,55768,40561,172159,126403,58784,28368,91780,145821,96353,30116,199912,73024,150496,59286,211608,112489,147053,201177,214545,162981,160844,176941,168479,145945,6882,159644,172446,217438,181789,70109,7589,47294,71636,184208,199861,80998,89082,92018,1600,47554,37552,148457,172318,105063,18241,37191,194240,165982,8036,24053,195588,139062,56395,98617,110056,118920,89363,41263,97007,81693,189162,85035,203642,148791,156296,22270,13791,40784,41264,197770,161962,13044,176676,21386,71330,116159,173523,111152,127313,159141,115879,108575,71609,56510,197903,134308,29836,90484,112699,29069,74251,177790,177861,155486,127567,138132,174884,80850,11905,57807,103615,157109,14360,2015,53588,79015,55373,201156,35975,200320,66982,75876,86015,31453,2026,130386,112893,46057,33865,99669,139613,184463,60322,131140,146466,218458,175124,211340,92842,18894,104089,194521,11089,114352,4912,120521,20411,82066,150931,112534,150313,106017,185990,167809,109495,75894,30192,106176,5467,85136,45055,210061,162037,120611,218110,180777,94857,23102,114481,27751,63783,114209,154369,104975,137353,20616,103747,81336,195508,89001,58083,87311,204241,44074,20022,133820,165996,29891,139094,133680,50830,203483,91586,92038,190678,214682,169199,213657,38716,61936,26949,38389,67100,51713,45481,157915,40073,199285,199007,171837,19753,46437,202597,164613,129529,121622,197773,147779,199197,151904,117677,178555,147978,168556,166539,163714,84727,164421,82873,39572,115281,17900,122924,115922,165513,79382,208753,72004,155037,14725,92259,132995,19448,81371,121043,118466,9858,14913,96662,218024,75255,20114,168995,23670,161220,155734,132924,100789,121297,152129,145310,207570,193703,10753,117894,212288,149856,170875,49345,84286,158598,124683,131430,39547,16219,118101,27132,212590,219513,90030,187253,22106,132264,169378,45235,10260,64944,185498,115191,137381,137035,159856,41614,95393,97740,74158,101558,107156,203021,88700,177968,156732,13727,206891,165383,169687,87726,17278,86426,24774,146266,177833,101335,74487,66059,189722,172114,157994,39419,87388,21365,152761,173961,37074,91703,81325,171980,183102,173574,102749,33803,155986,185816,22208,81155,98744,168687,207529,86108,213171,33500,175666,184242,21730
6,193665,179677,137990,24498,169351,175756,110993,133297,94851,4779,95052,80967,47254,56149,89583,200843,127473,50333,59386,36061,40617,20228,77540,206600,26513,133087,202080,141497,218921,193631,138027,9895,173520,88280,200882,162013,34353,156564,98754,40436,42532,47407,217932,181625,202138,163771,212085,43395,189053,114733,11453,107704,95498,177304,188635,62267,116444,160063,74692,197234,196224,205468,117683,61342,139987,62696,81125,211975,205617,122949,218919,50860,96427,177814,149532,115499,121026,201703,73864,203980,100105,131833,138270,109749,42482,214143,52349,210045,158508,36280,65536,13678,168081,125983,97311,145369,26893,186489,135242,32683,74727,21987,200046,42024,71510,117786,134572,38626,217853,114680,24042,58205,214093,118260,91641,6991,108762,13960,103928,131598,98031,61807,101231,21392,98281,58861,7389,83541,25974,219985,187355,170383,87894,207455,38285,36070,10040,75203,217517,123851,182471,217870,36383,198925,184952,122974,117598,161309,1376,20768,4981,67088,56532,218880,39195,143865,190640,159614,138352,110953,29146,203468,62295,78944,31941,12517,62505,110054,167497,208115,163266,119829,16464,29060,219332,131027,156431,140523,4308,165663,135102,150745,63422,188362,37637,76355,112522,410,161168,92488,63062,149596,109200,49103,174160,175168,22443,137228,94530,17741,137903,142627,132992,206153,133070,145267,5330,102355,123243,11412,166637,101418,97858,66492,195904,4257,93610,206705,17710,90387,63209,192136,172255,164693,27174,202244,152542,192724,198358,87167,34972,11622,92353,143134,88748,213256,168505,45898,217680,204247,179485,121798,182292,125444,165605,47783,212729,35350,16541,187706,203473,119977,9695,76924,52845,11483,207384,52289,10985,82731,81284,135104,104392,213753,142352,124098,66404,9599,197470,169581,50096,75002,93578,204692,12531,171896,87012,71622,32630,209542,96474,114544,104863,194853,115274,101356,88876,48973,130076,181421,130422,96297,209005,136125,69919,210016,21656,190600,111280,20690,112881,157929,215571,47305,142992,77009,84203,26895,20989,85962,173256,77500,80346,116894,157973,188131,111703,43684,180817,116367,92155,117212,11103,190530,92421,161179,114017,71973,167633,207481,15036,19671,175946,167120,106485,95239,134502,210002,196581,178103,41939,8163,37428,159271,177756,204816,114864,9122,33090,17624,61864,204083,169125,96030,94935,100362,148699,8489,158621,40222,178012,117944,97246,97521,116376,200038,20155,150450,36095,138794,96193,104342,82414,170278,73056,65434,29739,6793,192852,48771,130821,135723,101446,147239,30870,68623,203152,68233,184559,116987,56248,160399,74862,181922,128765,52475,32148,35580,19412,118500,45260,187005,116683,23013,212471,178763,83778,175070,91075,186001,17023,144178,142126,76116,78631,41289,186596,186007,183558,167169,45636,208062,94763,133367,58783,31814,52685,207635,36405,62058,207191,129504,6890,94606,145233,150019,96707,122527,210675,144610,34018,160418,22600,17235,81078,104375,187939,188447,125479,137807,107687,201631,107330,215833,150706,19333,32853,83106,168385,19441,117958,122115,178286,135622,90312,33642,217331,204541,144520,167718,154084,47700,201269,33825,113408,131822,14526,217944,32548,135817,40112,79724,43163,42458,84583,185947,59105,90707,136025,74461,20672,65694,51459,166476,71945,32812,163845,79383,161068,139848,24491,131746,168031,44206,155139,152212,40421,44879,172682,163703,188909,158633,88504,147749,10798,216330,7437,21292,11917,168153,202118,151197,69377,170689,55262,201053,149982,109205,162033,167507,7939,130557,164407,143019,75935,168272,79192,126601,64193,210944,211641,179566,106454,77970,118893,
19129,180515,15705,41439,115239,108952,126943,121750,53483,89179,37661,81941,188314,83706,192466,90536,104521,34279,199403,97134,134998,147233,27835,83654,63370,122282,32118,70116,117858,64990,36930,25378,13265,76077,100702,161317,109611,65082,41864,213248,85835,151437,189039,81976,49763,200047,41770,130617,134932,122372,130746,80882,130382,6075,23605,103086,132499,119860,63105,56398,152920,92509,12755,13248,73755,129733,156635,171510,176297,123361,74907,140690,2143,28179,112965,35110,69329,190711,95886,200240,105710,95928,11860,105010,13410,149495,147276,51027,95033,75655,19275,101279,132187,118057,200383,144141,73327,216350,163379,178362,160181,31145,33728,25344,103259,97831,208377,88889,146232,95841,197909,37818,52181,157932,133463,105221,131092,10552,11876,10201,35923,187063,87304,210754,124189,136139,119796,39043,158927,135136,36642,85974,160712,83500,103010,161611,193808,78442,155641,88178,132989,217131,133618,139549,128410,185791,147567,78588,124490,213843,4361,96562,86832,176634,28732,109193,153014,80655,208864,190139,180366,164956,6985,156476,124410,69607,171859,205233,203466,151643,151423,59649,189110,13462,152967,125946,44698,137424,164895,162460,203058,99664,38738,215393,178448,63490,8276,150095,183637,28808,50012,4964,115615,82232,109780,39696,108255,181010,53448,107577,131548,203154,160295,123628,192895,190684,16326,185014,36197,135969,54353,147033,85209,173598,125402,137750,98782,82248,45398,120498,139756,89738,143202,92874,177174,202372,188986,178453,168639,210672,181747,69300,159927,126924,50400,64623,73139,146299,78249,58947,78017,202158,75760,184785,54336,180942,184808,128189,83110,125777,91436,146928,208510,189179,71685,75442,31948,150333,177372,142394,99534,103424,214532,90474,202654,210925,38397,76140,11027,187207,20717,90797,115966,171959,67225,195910,125587,56095,52970,217042,141184,71075,147329,182497,71203,35985,28633,161396,194220,153834,62726,63558,13307,175592,139209,59158,167233,61036,13760,26266,108344,86584,188051,123837,26357,178570,202051,36036,1358,144271,41403,106668,171151,124701,125160,170534,52286,198157,75299,84192,74819,169318,15509,201516,23451,170996,150505,60875,140286,193642,189560,9777,45892,109546,219573,46081,9526,219205,104069,206559,129861,48840,196346,75882,9821,2429,78221,148887,158015,28128,87829,74596,119104,168116,142422,137446,129476,35179,132176,122688,71522,50544,212025,29529,86686,42596,191658,120301,169944,67440,188376,48810,3692,193099,88371,207294,77338,148840,176790,198517,50636,46004,160158,167438,106296,216790,112119,135077,85964,22737,105076,175441,25004,48439,36860,125161,84830,65024,1790,68363,100551,61740,117109,197592,69990,86409,79153,152688,189499,150176,2997,68564,171359,94254,181605,61945,16285,175056,30991,122000,80427,41868,106277,179850,131732,184433,201809,81521,180955,30779,167434,77342,96349,161120,57894,57443,35014,125352,40156,119269,196014,158790,97943,109011,184009,143978,123424,198534,140907,210238,174088,216427,57255,199716,64871,178347,197850,156189,206404,21461,137758,117112,138457,184488,94839,20474,147893,29414,16189,217258,143562,132510,52994,150130,140682,39282,43132,86014,136404,115825,30471,178335,53844,187935,152868,128143,23839,133791,116780,211993,14584,118864,34588,134544,108901,119746,147825,15135,146485,121209,176362,211528,80784,189603,5706,103795,66704,214028,790,57121,151600,19144,11872,111069,90316,183601,16714,141862,15795,18092,123750,8319,75259,107182,47181,201581,35501,200757,168171,191178,169313,110195,98162,100264,117581,98958,98473,21037,178973,173484,141374,34862,171320,91165,3109
8,46785,140839,103037,138632,33383,190803,58376,218843,910,198227,5959,78214,121373,176553,188506,142723,111182,139443,99425,215838,60205,64883,120750,90743,40664,72259,49409,190060,200668,29581,8445,212199,173066,109901,161198,200816,4101,63089,54023,17639,26449,155698,8801,117031,156705,176166,184362,12760,64155,194128,11596,105451,115071,61430,141494,197757,203440,14796,36735,132093,75828,61384,213853,191762,151023,83598,151381,156740,202516,176300,214534,84060,62102,79105,37554,173070,136642,57945,108412,78774,72129,145940,155101,192731,45910,164100,177930,111988,145712,129924,12290,90286,168641,175988,99917,205993,137484,83554,182533,109137,107022,39116,78613,98686,48202,197765,141014,124130,63168,59072,78836,185487,37840,211229,121379,15105,147426,108180,109185,146089,138934,35151,101821,63663,66845,53303,86576,169530,20727,117908,97275,24238,140413,49933,13510,70375,98872,176599,158589,158052,10367,19079,49325,210934,199674,154031,189872,176016,146499,56889,125670,54759,87284,79478,4019,55524,49813,194373,30753,195789,197743,125546,63594,182354,158440,53611,103974,62739,144724,84309,203296,74202,99788,122224,139893,170058,94177,80851,68566,94278,134431,130358,122160,25849,210453,189376,199656,83947,53248,97178,82030,108509,11984,147462,58005,194314,38252,4251,68360,145047,153100,151818,188746,109465,77382,39953,51430,86472,60298,99572,149366,218461,64416,130888,144347,171601,179528,88490,67456,200110,215375,128089,189676,168216,193880,128606,120803,44101,192091,208398,92320,44371,36689,188825,143273,128009,48247,142117,168246,15396,137348,8866,219956,19485,214432,12717,200366,1673,108106,36042,165334,60668,17817,185380,39591,2396,57336,132574,119540,97832,15846,161961,167176,174785,161616,126669,172827,127888,4123,1733,139516,107829,3106,4384,138783,189318,72021,140440,75153,4493,131652,213164,182939,176693,112872,211332,46993,28053,25248,137464,39032,50313,162471,138007,66169,213703,92914,70077,208387,104044,20714,97793,106437,120323,148111,63755,182776,59207,78636,179317,217071,21133,171372,170365,199120,8305,24510,106241,99421,144887,124781,14710,166919,2477,184246,44927,21675,131034,113746,169058,205619,86840,148335,24998,138453,11154,60231,55421,181479,148579,124649,71065,12205,19489,179418,73498,147720,172519,8625,47016,82440,4104,153887,37798,197198,215139,186833,215194,104410,20168,78453,42585,147838,63038,148555,218798,102015,177480,141705,86960,100622,198078,193919,36687,207224,180698,189835,20559,131279,195416,90664,14076,25607,114580,60855,219853,19927,89284,158641,201590,160921,156135,103993,202782,85557,7766,166231,71651,205402,118109,128608,59805,93273,144975,98415,113206,48690,178177,153583,173191,99633,22464,202419,162166,77431,209205,64605,186937,19269,21505,70247,40428,99614,186561,205492,166173,40442,193896,102137,83015,94555,27931,23991,1361,80810,116584,94350,199582,70717,26799,34735,22884,49123,112978,117661,146010,145421,134333,106966,27077,6884,23493,92776,145285,24396,156480,156773,205635,85047,100987,3235,76506,108477,101659,204254,22203,190013,146679,63633,149935,136604,44455,179651,99806,44433,36396,70495,78939,70361,129269,38255,16585,43832,113947,72341,110389,78578,126938,205171,20164,94666,65981,64642,188979,163848,129971,155662,161656,51261,120025,28042,35517,79864,1774,103728,87088,162579,99585,210783,86387,115401,87749,112883,213772,215884,170665,155816,35684,78757,158128,182266,52652,125569,46566,104396,83660,76401,192642,182179,165911,128714,150963,204989,63550,85429,98580,73479,214039,206010,103005,95686,29855,147810,52376,155248,143033,47166,201
226,144205,7055,190634,121011,185694,54958,114938,209108,76370,217915,181622,18129,214686,208576,208416,107254,176417,130889,36484,166025,79385,63444,66238,172664,40407,187494,110920,206851,98825,19389,117720,156744,125628,152631,105068,140081,132805,181327,109655,142513,9637,207282,94248,183929,212938,140746,155666,167295,22277,28670,201576,65072,172386,174245,93063,43604,169607,160609,11612,147644,169091,177595,169539,104757,197311,86926,208727,112841,27772,2765,67826,58196,133870,195783,135999,146274,152179,180267,150794,57942,116858,98073,102533,121389,202464,178141,154295,180993,131194,39516,90431,6438,126555,27406,77445,109045,22506,30426,217924,191404,37139,91312,81760,89932,119362,205591,54229,136607,126596,91304,124823,25591,114862,189266,182652,118606,83600,17609,78712,11713,212385,186326,30168,5920,89935,169861,28668,177508,209728,43084,194114,63940,135264,45731,144679,41978,86766,146761,111770,121124,60779,209126,106180,165087,48352,48860,167897,172309,113281,103699,7652,193301,160984,51727,117853,155242,112604,101930,1166,184742,56307,53945,72966,197217,184486,212003,208668,16400,151348,26673,209584,140742,49010,95890,85454,51661,119944,29872,68774,175820,128315,138337,167538,82157,156492,101879,160250,102903,153879,29945,91206,92270,161837,45227,213316,176577,185070,209875,77973,161083,154709,22217,176189,35113,82099,30896,62540,30657,47984,97859,181780,37056,133909,101833,109638,156349,36141,150681,100587,110897,48694,166592,140870,181779,169151,45412,145467,43742,75716,36777,49046,82465,118227,162374,94119,2512,127037,35791,51151,100648,147215,132497,171198,129753,107252,179489,108955,186159,116353,128109,43844,21671,148215,8089,208635,199506,57810,76569,8570,72051,58941,140930,75673,44094,204666,119884,148281,194646,200781,129843,144107,133904,29677,150133,29938,70225,203191,142372,215818,96170,142093,215642,198386,10907,199230,188476,115701,143223,57289,111451,26847,192935,217598,171356,196960,64888,78042,8449,68895,91377,214100,22802,115099,30772,203327,212831,62113,55018,212520,193360,154120,180652,91915,186319,159976,165877,112241,43511,161520,36499,205743,54164,213261,54248,211141,15560,149441,91964,139534,73684,156462,141234,44139,84843,184563,76712,75846,150139,70082,134850,178074,213694,24962,35543,196632,205964,107897,15541,72698,171939,33262,183874,34784,65451,38336,186487,84435,217181,64894,199615,177270,102944,128308,37269,151368,165780,70442,164577,108586,98349,118486,19796,166101,208299,201558,24404,105604,135164,196556,72935,182112,96749,119195,127937,85796,152556,648,202934,192079,24416,191609,120179,166419,174101,182790,93905,16601,208582,140096,104258,57102,112628,216228,55535,129686,70395,84637,217145,74342,88540,143474,151501,34234,148410,127228,206649,90109,178926,199820,12690,12043,26102,164149,205491,219336,120476,4268,32568,41414,115680,119673,177,112145,53036,180510,34580,170392,79132,41674,71832,24468,171596,94485,65638,21606,97476,174196,170108,43391,13575,103898,163998,80087,190313,183354,197979,61379,112520,171831,23907,184470,24951,284,55947,124992,20431,34786,155629,59294,136169,178253,116671,2434,2113,182346,90065,214200,31612,110707,181936,34648,125226,18567,60017,100419,23417,191565,211276,26850,27053,82143,96354,78447,35947,100328,202758,213219,197435,34903,168185,177507,37448,17933,139069,147663,2221,159991,170398,43022,115369,92061,189381,55851,164647,196154,39681,107927,161500,179979,115968,56949,22724,26336,36703,197249,32284,154553,191367,100671,92152,112561,82445,207376,36655,64955,72605,168983,22085,145069,157548,157288,189052,15927
6,74690,204202,181004,7031,172597,79075,53856,135724,159102,133568,49452,196042,102695,77474,169389,14304,208084,62741,129731,101157,29555,63198,130962,167269,155695,18615,138487,3163,94659,82961,34598,101407,219931,148997,110326,95790,143224,179029,45290,199944,123473,201947,19134,4375,153973,17703,3615,68730,56555,10440,15851,206907,104326,132804,75101,164776,186811,131346,201093,108903,183975,105454,21607,167913,140657,140943,161493,40024,72736,21914,81240,20715,134035,53267,211828,40731,140611,85655,102427,153730,166276,200074,218766,170351,178534,168599,17331,81309,183153,114684,191354,62557,15179,64088,113897,30405,118848,160509,159601,14425,174608,194430,172504,194159,46044,31311,3320,185571,35969,183893,2975,42947,130131,90886,137102,135586,205255,215615,188944,67960,43768,97483,33274,196549,204766,70454,192134,30947,202225,7695,87821,211578,112544,71298,137325,16743,68930,186731,151165,164503,20203,130062,119327,133898,94404,14533,130969,148605,43746,96387,41237,66877,202723,26975,150447,176716,189472,30234,59874,194191,133385,673,11544,3221,64107,12053,123670,95414,100395,39343,46913,9270,144930,210531,187176,171888,109935,58855,84358,65080,109119,188819,83869,71108,216530,20417,149824,97729,31300,131607,176508,13807,47014,58506,210852,135207,12039,105100,18068,121921,74005,202775,81447,85799,22645,144989,119733,2089,96709,52604,76439,147798,205653,79668,193875,163241,63591,121615,96540,155108,129331,202535,203393,51228,192010,143277,199355,208659,64258,39470,1671,107488,6267,61078,141962,90812,166708,181786,2515,87822,383,198948,172308,98638,193483,190635,27249,53790,138887,62434,110052,128982,15701,37322,186555,73354,24386,11508,60679,136251,108280,183722,97646,120182,196057,22256,152245,25054,132494,34999,167042,214948,103376,19616,148850,16327,114028,173655,34264,62068,76461,68417,82577,209929,183691,218077,192768,85524,83376,118687,71291,61239,19578,53139,35680,204455,201889,153277,28566,40739,27697,43089,117851,122173,82321,106564,31650,140489,93922,201939,54430,118467,80821,121463,69470,31434,23703,41317,211706,179845,79748,216632,184734,182682,158152,10572,56366,219132,85929,38882,23910,186425,65069,93255,103873,134600,12832,177519,78339,68201,214981,45578,8147,106008,118541,145739,196104,144210,65497,24831,121121,26714,210736,208092,36645,31838,3134,16166,217677,204723,58271,34266,51939,215921,97580,179905,165228,169000,22030,152652,153647,67848,213056,19657,5714,16900,51839,169341,116242,212152,33761,24111,217410,212044,86913,32119,11192,120847,13274,44037,150440,113549,215399,189905,103546,130056,7728,100330,179100,111847,45300,92784,56286,49060,72084,73362,116919,38962,9083,162313,161736,64431,169680,77360,130781,108001,144359,125575,15985,22995,73341,100405,109899,168467,208980,137839,65312,165348,210283,142516,208364,5119,100269,204974,94139,125746,142742,204848,127055,90186,148282,131699,84306,101811,70893,47137,6832,83591,156530,57414,7681,212517,204089,73644,15359,211983,123649,138787,93698,201023,153089,61017,41844,26249,64974,172193,63216,69868,139903,213984,207418,193215,14656,200291,57517,150914,209031,101822,93009,219997,45537,46163,62010,155325,83362,211343,193489,182878,94052,155074,7394,184405,183647,92191,148806,147746,37251,147615,49344,219335,212441,129034,142114,81016,46021,128564,9951,23660,14724,60980,157722,57359,5224,138005,125301,179,86962,160991,52812,211404,34207,89327,187940,46520,210350,84826,15612,5744,38721,153776,184731,37133,202995,29077,137965,95653,19013,97934,184690,173575,103703,153620,26484,88227,84568,35769,41288,191981,114522,207576
,167579,127716,170178,83173,45666,185383,147318,181812,160880,206305,93539,58520,174465,154389,213504,45967,99191,80431,192412,180485,77205,46585,191058,471,184027,150165,102655,202068,148762,8488,47731,157912,83418,212023,160269,57906,167679,148255,27291,130071,36868,86813,195387,20396,62389,90914,83759,44010,166772,23258,185851,175642,167536,191177,88014,117081,2987,69523,54967,65415,181711,17771,91876,67281,206312,28320,188852,335,12621,101171,114876,192042,110009,108532,129258,99277,92073,143136,98395,26588,217266,125695,206760,151533,170647,198768,180134,59135,42639,118302,19322,209504,8758,77308,5263,83549,68453,27453,19451,89602,44592,98424,42180,191885,19661,143404,24406,88591,155721,160863,126546,184681,8124,113068,171346,43253,159204,113823,41183,14039,26544,86794,53883,49803,107290,182765,145257,189277,191358,141590,69537,173171,73799,77961,62271,25143,13011,102996,149886,144090,128699,40262,14288,94589,1000,112094,22971,77214,215979,173544,163863,157342,125915,52343,25424,6746,54251,44782,167297,76496,21479,124143,30279,80832,205750,104193,123455,128213,173287,69469,23890,168261,142639,101587,48465,97260,99993,96836,48911,116857,11696,69021,115275,121912,69048,59308,219435,70401,147913,16158,39617,197193,182575,174972,25350,22467,175149,90122,211859,141937,108524,198300,154593,59641,145123,15986,101477,214540,136991,109741,140381,180110,124472,61772,124849,78342,20697,103495,187250,9103,131622,149741,135480,209833,149811,177611,162600,38652,31531,198492,117713,45319,43821,55545,211933,50521,33078,11187,110985,20543,179077,114135,52931,215630,164423,38765,218656,155988,67603,83820,190605,211640,17566,22813,101737,145847,104164,144717,85492,72404,136233,120103,3241,206189,183608,163330,154029,136986,204914,109994,29496,39227,38849,148529,154123,152745,198804,26267,27691,147504,204352,26474,205978,74997,90143,107712,69241,100412,171648,127293,149952,159635,124006,45141,72198,105637,37710,162069,159555,178471,181839,105308,10166,103754,86899,181014,62279,12921,195493,124575,70407,97633,5303,88569,214272,208051,79444,81513,73491,128875,182500,179529,25747,59757,35228,78769,194962,116158,202907,216484,84113,70626,191050,109254,158591,169228,23406,49266,115793,55414,215998,106737,194234,126998,199096,135397,98018,215224,14734,135074,207484,41673,17201,81559,186009,132853,105744,35559,204607,137438,149074,45544,50531,218347,52619,219408,15106,65402,8669,119655,13253,94470,184715,52195,71821,96561,215877,121692,132415,104056,166568,32472,178565,8011,97639,128132,155724,116943,46552,124408,153847,144528,91687,90242,42864,219585,67806,196335,217807,180578,23695,74729,7195,13015,42705,150014,204427,56182,58025,169706,59011,177653,54193,70538,170809,107349,134553,5256,204551,201214,178834,2332,124281,34685,169704,45403,157585,209941,2218,58659,66162,159473,211048,80074,186840,185956,168490,72188,111098,98690,91412,57202,122342,79205,177522,136482,216088,163474,104111,153219,26820,1926,134090,173632,98102,147182,169696,155653,159658,74102,78243,179414,28338,126284,16734,88154,72478,167928,85232,72583,68458,171997,186408,170778,78326,49377,39303,135605,200664,62993,14883,211138,210334,212344,157295,105931,173154,83777,178456,36681,6672,171809,130383,77490,68007,109724,106063,101783,193129,9258,152836,185616,149957,213913,49785,89959,190460,58291,179914,140151,167330,205756,124711,176126,93977,78694,172389,146341,47483,77013,25586,123412,33539,194227,67804,190103,211902,146546,199994,47990,178182,87984,171556,23409,58262,93135,58544,195090,80500,110224,206593,86852,97434,68304,153503,77083,1215
72,30875,123348,14071,166303,152980,152125,162700,18668,50998,131197,31409,178440,100052,138893,78212,108293,14172,39525,36009,52506,89828,107267,149265,119949,37545,82938,187608,47279,204194,21184,127816,87196,165421,46353,82159,166330,15207,1235,118503,71703,54888,199542,202779,44198,151068,200404,41532,128339,205838,200165,204407,23079,33922,161082,113694,166064,112567,105747,112466,125147,100484,1223,9889,139927,52275,191537,97806,3690,84466,198403,218425,139207,49958,4696,177405,1081,164282,192309,65227,59314,181803,90269,81813,33798,26924,100545,131946,154316,80246,43783,17114,215965,78171,76238,119518,190796,136390,156334,137394,88561,113682,180035,35072,89845,128497,94175,196638,49394,43223,212863,105746,5013,59878,58804,190622,34394,55362,192433,154584,132286,44298,32286,96099,186413,170535,207651,9794,98354,167856,67053,202178,140177,163351,12736,155693,12572,173531,28341,171328,219183,5175,12644,180813,30353,109255,116474,99681,32399,143466,66253,124360,183990,40141,53869,182240,179762,165232,2439,78935,110328,173451,26471,173665,136395,70772,163262,158763,35956,109792,217750,27502,134108,188420,161925,30863,74109,29887,205720,27856,130479,52074,159646,52363,203937,68942,217972,137139,52867,93827,184401,107415,77424,42040,10329,143691,130116,54445,180408,126003,86665,61474,717,3342,176064,23331,28948,218951,150510,175380,129526,39097,23614,204348,134140,19673,189677,189681,26552,67080,168909,60597,119734,77600,69238,121707,13128,26164,213169,45797,10340,76170,94428,210133,177410,81974,111672,193603,30382,24903,198758,12029,2535,36279,173894,169659,86156,93723,114744,163002,70335,192169,24366,97680,213637,89842,47938,29840,105061,104923,119678,71259,219075,101113,180853,125369,209688,110556,170091,208905,43720,29637,34737,182496,190424,15373,41784,27187,109059,154861,126530,210308,149145,177836,114861,47516,160254,99549,204253,94004,160263,7554,179428,219817,190651,34250,126636,127958,29605,107815,115523,12051,16603,68580,171697,82375,3240,184698,177714,137577,201203,194936,189732,148114,58716,175595,88683,136078,137128,180418,183652,162067,24967,113468,173464,190487,65162,126334,90636,173987,177286,168140,100893,38968,219947,216645,146366,160671,3052,168663,41137,132531,123021,129980,43359,19625,196982,128882,84052,63148,86742,13199,132888,58381,144309,99257,105716,61603,20664,120348,116441,150217,117466,23073,130667,117196,82853,31118,195359,171277,4658,26907,105759,108092,9393,146048,143942,44,216507,23827,207561,158388,162284,167126,79848,133257,147412,54224,171598,176022,119592,89352,95248,14227,58436,119133,87024,146840,201173,162465,157273,124029,187867,172272,95235,187672,118628,29890,25542,199051,172729,58129,575,89484,93097,217549,166110,209692,76541,143573,136835,98013,184050,27871,12000,43032,129085,36817,215440,187919,194644,189152,107272,28039,66660,160131,51359,49163,31652,102995,56515,119811,216151,204844,49959,187901,88587,28250,198482,108650,12189,166000,219536,176197,154353,32301,202413,118097,119901,172610,154485,132172,35294,130764,1137,139031,192994,12732,113458,219890,153385,126597,133256,204755,46452,181355,151664,45521,191494,213711,33654,27146,101556,172365,157429,158901,84369,132149,100837,109750,181538,199851,65074,71721,104299,89878,77115,118710,210972,34962,36223,108198,157948,183504,196195,134294,78358,183892,143604,82289,186613,193152,145118,165827,56975,52661,54974,163621,187271,76173,179991,90516,176258,34779,186226,170416,46612,132092,142210,182009,166953,84510,186214,30252,199432,92337,147398,125926,152779,149589,174891,186154,109439,176748,1960
49,139959,74954,110726,211892,3347,137030,27910,6624,99842,38096,13910,175703,14838,52403,196656,205922,42188,75134,179146,65683,38385,16261,181035,205614,53477,197550,141865,8468,94026,213392,118644,26864,161509,177085,193853,147294,58800,194619,169638,147133,101944,62284,133935,205601,208885,209406,182485,75402,181649,174193,12315,100029,101595,206598,179607,81960,144736,14103,3398,194948,65887,53238,162676,204714,116146,217575,59126,166143,184714,186003,169772,96401,197946,154241,143291,193580,158174,51503,203462,190031,50724,77116,117654,46511,178954,19521,49014,45559,192696,134701,31287,98880,10652,112330,73167,145908,69268,34096,42186,153088,66850,1479,86655,122245,186040,40581,10470,40292,84575,156468,13331,202810,163936,161378,173510,78945,128594,207069,150123,144221,93869,19105,205529,185514,83324,137729,57661,48200,136473,17872,206112,132419,41778,109033,142365,140780,105753,24102,91593,57985,55929,173889,87001,86787,94580,76958,57264,197304,162078,139067,124480,147150,215924,197167,170210,2307,30073,175358,90703,116828,63005,168365,160878,162825,63917,10453,84131,99974,208720,29720,101155,67263,141619,189785,74141,6476,135493,215510,201980,98293,134823,133361,116800,127612,11495,77708,170794,50465,83968,134714,21586,25154,43881,202720,141026,139776,1921,164346,17163,55255,175724,165023,202789,55882,179674,111203,219064,27243,53891,140969,217251,184219,194347,203215,114094,169492,214107,18905,195792,179842,40305,6226,120383,187922,86957,9931,22943,19201,14226,46659,208198,66036,146298,18465,150722,215382,59675,67803,106858,101720,118259,165207,106023,113514,82943,4427,102105,176501,31013,140203,1546,168145,188818,162096,16577,198170,195546,153513,183732,19028,93768,131815,27822,75978,206278,179798,78997,159007,22676,73798,200612,114734,99985,103576,215514,173204,8079,124253,40979,139845,57636,35751,195757,102496,208868,209055,143058,77326,166211,37768,78133,177547,208408,191188,169853,97813,2922,145103,145508,37060,32305,10950,1318,150134,156409,103306,140811,143530,22315,79872,55064,196161,194361,190300,90138,55505,108174,169327,134362,37343,42193,48232,57878,161712,218824,66254,50941,30047,47541,175892,209626,148712,14257,196881,143878,120917,168997,144091,19174,76192,17658,65716,25743,51751,190208,152199,127714,86730,93621,33610,175976,171828,61481,26226,74746,162402,18299,83305,127365,118134,86073,161570,172863,80130,181785,153136,39949,148256,148599,95549,211733,82924,112278,44207,1266,82107,64533,205956,58071,186698,193938,167991,114327,72529,94237,35207,178147,192144,86917,125817,121332,116747,94041,203100,80077,127656,144617,28529,46077,169227,159023,21488,73561,216647,36700,141195,51493,69005,192294,164264,20044,190690,19765,4883,132953,165577,8177,213029,152539,151755,106203,201641,154674,55773,9609,144597,68319,142467,141473,123759,171418,203213,38963,95897,102891,58538,154560,192211,76283,35096,121818,131366,133801,23738,104014,95385,131640,2776,173717,62621,211429,160965,179348,43474,198497,134777,36797,116974,41739,46961,148467,148324,172861,38325,192916,191890,126747,94056,10595,57362,127832,60751,16478,69299,96845,61346,11221,175265,54204,135760,207476,94699,101736,123122,119182,12173,10460,170006,84953,27059,184480,136674,166085,217487,71695,218856,194487,69026,151684,142698,149464,45573,98739,171357,99226,95760,154056,171268,219013,20404,135440,65951,98652,58431,135363,106389,92510,214233,90448,127627,207790,126382,158390,1729,131553,36429,115316,43412,60290,20208,134883,168683,73612,55447,39384,48996,41250,94954,157816,188036,214159,31019,173305,60383,108688,212
644,86566,188810,29711,126226,177007,125518,125572,55234,154872,42269,169675,107714,5884,63876,11310,33524,150336,155280,42239,35317,165030,134386,10273,150016,38499,43605,67701,47195,135434,106287,157117,147745,167148,4881,188625,190834,173472,215299,74305,22478,56778,117755,158997,123578,126266,44713,57504,217674,107680,157506,37411,208303,152737,215076,201192,47679,163371,172922,43967,189841,179723,89829,118874,156377,154821,18046,60244,99145,101046,34658,28461,4341,53044,137025,132214,179402,110046,216665,39206,158731,25244,46510,124746,161152,157237,10233,205486,139311,131944,197232,30806,24429,126416,172855,34164,158197,141370,216929,314,203127,177880,170192,113114,136850,110279,97092,11387,182882,138888,112448,60158,210318,126385,206977,201502,99938,91886,154415,210926,194351,199632,219464,182637,151026,25615,111539,66720,125079,200765,60951,21972,74256,156838,142535,107124,67582,49550,1433,185899,194942,206942,1864,151515,136980,30888,136252,98311,171585,53849,180293,210839,82811,107176,94966,27940,40574,111692,170806,153451,66488,68498,140536,111034,148024,95851,160400,72236,178416,104190,197911,53397,108325,144730,205465,146161,126476,46850,97049,146563,145466,126814,66526,80387,151427,90270,202950,184941,104367,192880,36193,31217,196035,60233,117072,205287,183320,42413,26287,56218,133259,105004,206066,88719,209939,144330,160963,182729,179454,185280,13645,42136,102001,166691,1386,210860,23878,215811,123645,157350,114131,23105,16798,68967,58135,97212,20981,159653,9897,189921,100748,105026,160947,172012,136578,75096,186543,156623,15063,202796,215018,45946,167848,129595,115073,2807,52876,172198,210740,135830,74974,51051,8810,205510,157159,133837,135217,183094,60731,161268,44428,176691,9453,94303,62719,171709,75,41273,22765,72335,204386,104709,92570,175547,103584,106286,153637,52510,4028,193180,184450,23636,144072,74578,190352,146568,152812,91251,91088,213782,72484,59322,158954,50993,119212,105894,150708,1219,68447,45884,106727,132707,27909,146441,18975,131902,77136,89120,30673,210486,134555,67552,141008,213161,207982,150704,65573,181978,116069,98849,147631,126780,201842,63772,198567,89040,203318,120284,178675,120775,153833,7169,126345,30479,67585,17708,100315,206741,57771,177448,183809,60274,209757,82766,108114,92562,135398,140417,2548,183715,218150,69035,65079,75311,24381,104577,176605,196290,215239,89651,19930,145372,46788,124926,82279,92677,57295,150395,70536,47915,195600,179072,49851,130809,193229,52341,61567,192788,63453,69574,54205,60085,53154,218874,65456,219240,162161,93330,40893,27967,31425,20575,177160,173912,129149,160911,1736,200342,141944,11079,126429,97306,162704,82684,129064,150606,72019,128661,200346,133990,73470,104301,58220,196652,151169,33618,14194,137818,209523,110933,125388,141759,8536,131720,91039,103210,179919,203398,45822,119851,31553,183040,135844,120604,82746,180672,2206,3098,176712,55247,99464,152307,26410,85311,94791,160636,68078,143802,68904,62164,70623,123560,160716,89940,95924,48211,130749,111134,55189,109862,183687,197183,97598,95703,105012,73303,118155,47156,149791,34377,150966,179833,59917,213435,26180,188957,210786,61814,145579,113303,99846,51396,38232,209725,193604,39203,123990,211201,2553,46296,110499,130966,35250,165321,135530,139980,188890,205870,127564,78652,65791,147361,148854,12506,104438,25945,181021,42819,15738,56483,213337,124921,165114,117926,195190,87728,68096,94624,94452,146892,146143,212340,67558,90214,5199,107940,31781,92036,178149,158748,52155,151218,150967,170358,170165,141148,45754,204553,74608,93089,85220,176473,133809,127793,2030
94,193075,117160,20554,99895,160868,75110,121569,165599,37971,43928,167924,211738,88444,110576,165484,160513,149851,106108,18702,20738,42467,90079,59825,82177,83067,171183,76377,186235,158698,168989,103666,71106,116627,204727,95201,196056,153857,136560,118040,43861,48204,6055,32778,190781,61506,173545,65967,195002,183683,196539,20988,55200,43754,104083,27003,25048,16722,143768,14110,6911,202095,104074,137358,200539,21823,27244,68800,162031,40704,22101,171660,102195,81099,60124,64853,193355,74494,116171,164228,33138,34413,138082,219812,44723,7489,9215,173945,93543,85249,124821,122471,141710,71411,143797,116613,35419,182877,141371,46409,154882,156653,119680,165425,196322,166684,147845,114026,137221,167716,167648,139928,81175,91008,126262,141470,59540,210949,23476,115831,80705,95887,108056,68106,40820,79579,2767,24,137184,213746,198782,35730,90147,75173,122997,140272,2180,180294,128956,198042,125060,78541,2394,200768,112836,209181,76787,156500,60611,208867,1892,143710,97948,49432,107934,146945,114584,103481,65508,42800,187416,193231,169342,102163,217167,99184,57826,68938,20764,110687,163924,62799,135426,200419,209704,160750,211700,181908,179474,71267,74626,169989,194753,139445,212024,70470,108240,49283,17947,47816,35815,75915,28929,166957,161934,118690,71338,183624,163251,128973,49227,84054,5084,39885,13010,140239,5801,48742,26810,73954,70567,187484,114928,190606,2150,61029,23644,39500,3860,129388,56934,89630,110164,81677,205973,209574,130651,97359,121552,182697,207467,86572,21030,196757,38296,14690,69862,203332,102871,18713,123246,47606,50382,65556,214237,100028,217121,3307,31971,201364,187349,65263,100872,202172,116119,64792,193362,8784,52213,120805,192760,202111,24631,136613,174589,54643,123948,102006,178636,81817,84689,41579,164931,190902,11058,195225,144069,136557,32018,126151,167384,78332,52094,94483,13654,174770,57091,29603,61857,165913,113812,85069,3949,43872,208244,72377,153274,31365,206869,63724,2340,11008,152325,205383,181076,124589,180461,156293,40630,202059,23456,65201,41339,131050,23960,179816,162174,88715,199114,163874,15742,15669,93196,168971,40348,140601,207478,28152,32984,127885,60994,83031,146368,206716,168990,198363,91047,211163,115007,192201,50584,5597,180682,99029,93891,189896,35319,65718,10351,127219,185237,114688,146544,77257,142274,182724,128474,49748,148557,48923,142506,91138,84850,191994,122786,72174,217933,156940,145364,42033,18055,139234,40638,117803,77508,214805,141978,96236,119034,188,203402,167254,132633,99363,192323,48636,4907,162459,95440,203428,162581,62596,84068,189006,16604,172566,141284,214708,101578,49773,179103,130429,200732,144563,40404,160704,74220,84567,107427,168009,84416,28957,25233,137153,172724,17479,11475,218940,34019,46829,8198,52973,28063,127236,50231,25887,8340,149687,168231,219230,101868,172470,128426,137840,163570,183845,154743,67092,12867,96037,57861,158847,188178,18579,3430,54719,171900,128819,218124,20960,122698,209982,167488,15912,94614,146281,194898,109039,215489,208207,185711,15283,215913,161725,69681,138679,10656,192492,18451,218747,79666,20430,56596,136179,25459,26535,190099,60559,139371,101012,49665,50625,137938,172196,65211,149429,163925,46897,146387,196396,146313,108838,195587,27659,200876,167411,151142,93837,176814,120250,184169,174389,111955,27289,43728,207739,47289,144170,105265,132302,88476,172368,41465,139958,157264,116132,92359,82615,144200,85106,97444,108093,200326,112876,216571,190281,84790,159491,63103,147625,143032,56079,98320,200725,144241,195186,66607,94374,118594,117309,104094,51,190398,9086,74575,156175,4273,135688,3
5411,79125,1721,78730,58166,56314,82550,5688,202250,215466,180805,215570,72340,145139,134026,145836,149597,149844,40557,24716,9892,120683,93233,40922,36282,69465,96228,152608,65498,212618,78722,24092,149549,104804,149176,56999,192483,20237,189682,85116,207855,126493,45325,127998,216453,180773,184247,151055,214831,24999,156036,168363,78746,9796,191572,211901,151069,100020,20572,204504,14659,202087,87658,91207,214208,113850,109188,31598,201953,59432,195999,171122,157889,166169,203795,165725,100266,14189,79897,137463,10795,141539,148088,39211,85676,131701,182915,57880,40159,13243,176589,191778,102991,116815,197352,156010,150799,132311,58218,113229,75964,206141,103369,38800,34652,88124,21638,124742,149799,195379,12926,143694,104645,68827,22302,104471,131549,113899,202853,182942,34177,199762,35355,126070,146616,51932,127889,103630,170983,149770,51406,147106,161479,192850,137428,64551,63521,18781,69797,203626,112739,84780,179499,150999,23486,29499,176094,108418,111028,95752,91306,49935,84079,90315,155681,171811,6345,201767,79906,44369,91676,161538,215373,210448,163363,38225,115464,11764,214478,61207,52346,179662,97994,40892,198635,190968,22495,29457,193876,14296,63713,162792,38712,218220,214993,75414,3600,89986,181874,210837,29170,79984,117980,123519,5963,86999,54790,193989,141034,55541,53425,143241,137171,196701,67805,143864,126622,163326,191345,37053,182208,183174,175180,4820,98725,152831,173939,138097,11398,66966,44648,15757,74452,52611,206881,182590,161995,71114,13915,121886,171640,2184,17458,101880,173112,71122,173796,173041,94429,28283,5001,76675,170239,101324,60599,99798,133310,26110,127051,156527,139345,23074,48868,52133,131242,120058,146980,186769,5633,10420,64852,79934,87443,128432,103127,195137,41157,4816,90906,179388,92251,93160,163900,79253,61919,131400,72319,212832,16255,164047,67549,101494,213610,7724,145474,73451,39769,89597,102982,172210,35982,23014,191336,100193,136503,177732,160549,163977,179967,59618,109401,63009,150181,40042,111795,81889,158693,86320,117747,142143,34543,54200,171041,36327,152313,141912,179319,209742,53034,36204,81002,180494,184847,137312,35455,112738,201863,38050,172027,148070,41431,83229,17777,146103,139712,188597,31034,48269,83256,38147,190599,200654,83542,108538,107807,79520,116130,195457,215403,166945,193551,71774,39834,204238,49265,27989,34483,48430,6867,199871,160427,161728,69956,158755,148646,160527,60532,156332,60865,180191,6837,94178,20834,180392,204426,124938,35150,131681,158521,122883,186651,170280,48228,101570,128829,124216,36048,42935,50661,59316,8448,107109,3935,70741,108277,58341,194606,175987,53171,198646,53624,17337,46419,115155,161758,128435,180319,85169,119942,64692,112081,106133,203480,10964,159365,192762,216092,99712,188796,82024,103324,160302,194633,149929,140550,137686,168137,184578,206286,10095,136485,28656,118386,119773,216795,11119,64414,135974,140999,103707,192720,10289,146197,101554,135299,200582,165320,12922,49768,76330,115226,104343,79837,202814,216955,74950,120939,43643,110745,77180,160729,14299,180867,135520,37287,121524,217551,169371,118409,23708,189437,131244,29667,47125,204139,153471,182482,143237,35469,103040,84793,118197,6530,46891,118929,154747,186812,52935,119192,23640,65557,84954,181329,170020,186507,208794,42990,188317,192220,32386,23905,29927,185133,117807,166218,95143,22362,137315,185408,100686,190545,207373,134130,160451,41916,99268,144269,99773,11509,201394,44796,87868,148240,172891,39345,42499,37907,46888,201629,186658,13450,109990,80756,81138,191407,201491,116983,75702,164220,47368,30816,34490,53176,926,163200,20
0643,66351,3200,120230,130450,130716,142475,211669,1507,159379,129359,62358,56852,154366,172876,124705,118944,107518,115834,20815,157605,157313,1723,64150,143482,72905,162467,164793,146158,214989,216334,12823,48454,38864,14434,171252,91493,9626,38109,8501,174849,151543,219680,158147,78234,86798,22990,12853,28544,109302,46726,42293,176184,74548,160193,173257,181276,183460,155967,150020,206387,147371,46509,119364,196433,151772,94421,9395,24412,201460,78458,84625,96618,99336,215027,17738,198821,112020,144302,17019,76764,123421,35270,31987,182090,71937,86300,138163,108464,137532,55009,182407,54801,160446,15252,28652,149780,191293,164882,138660,22982,126789,189811,180526,158538,206617,185284,75284,111949,159604,82965,68068,137728,150770,92760,34003,161032,147491,130127,136237,191151,165637,129107,211807,47161,159752,61006,185751,76149,50739,414,30987,6287,121866,163960,208858,30913,193322,54350,209600,161299,108693,11925,87501,91658,146909,211299,167117,55663,2961,167934,52740,180883,73826,66694,207145,17782,80680,78717,132825,144177,84914,201072,1607,29841,179787,87368,89485,103821,149139,161478,7632,218115,46040,86642,218582,88588,176284,125595,124387,208995,9321,169253,35587,126221,40733,164416,120748,140762,87846,121233,205518,213428,108515,194664,14685,218917,136596,78471,148096,190045,177785,82707,140479,4906,54594,151361,179697,86958,209593,208609,92692,17357,39139,73520,8171,88906,98367,164236,140056,43507,149679,158777,102122,49120,156025,30630,205086,197011,80814,184881,82444,186080,8291,71171,108493,22753,130299,62296,97408,118464,217437,79332,56659,137651,79250,45878,29420,98098,205768,9802,52978,24275,116332,132643,6536,209459,126527,131086,40390,89159,126739,192005,137265,81358,203301,85809,37563,201714,25756,215518,135304,177397,34791,53788,189941,9582,120420,101603,187086,136929,30365,1077,11542,180106,155788,22826,87369,122388,12095,27027,136466,200849,195442,81914,5893,151403,198834,81826,167470,84116,52574,50474,75121,72522,139532,83663,202581,115635,175931,30841,71474,39143,41286,71732,30005,3999,163301,146702,134757,26837,54153,142220,91833,144012,83488,71054,37108,184642,210764,150492,64179,44898,18675,81625,160763,88798,141709,116299,13297,34098,7600,123400,177077,13226,27768,100558,22575,102738,46919,184727,156269,155387,165011,219198,180912,183334,116873,161563,93224,81904,87118,214287,51943,186882,198810,175864,212153,139849,20590,198113,200390,122079,148842,34549,24357,213084,210331,183133,122755,25496,119428,81367,179431,168930,186178,200931,123350,60737,58666,37675,161727,67646,95276,102780,48064,183399,174373,86593,206773,85134,118825,127796,7933,17135,124606,29592,92128,125475,120911,44774,73992,201311,91387,158450,150648,33395,133155,26983,106348,214316,187242,208865,28044,85424,145791,170249,18592,207731,79198,139656,120841,105071,154805,108381,62541,162758,63882,143641,41686,126451,100615,17317,89316,54588,213363,181943,154919,113022,205603,114125,148691,206392,89705,142744,181962,61906,45373,173699,175532,20419,193820,186622,150520,219348,29969,155232,156257,121348,85295,145512,25599,215488,116710,46182,80749,106376,18247,156830,28506,179701,213402,156344,202203,214922,99329,85828,129576,169080,173058,16117,23732,57045,214180,36416,152244,197915,166758,212294,48,198463,38463,173246,118744,139265,76415,102231,161368,117220,51292,53912,178655,82994,78573,149494,208233,178142,79835,113961,188164,90246,152758,2128,160298,133804,105957,33437,77118,16360,96789,143886,13927,221,184380,146808,215155,94173,28756,73979,8174,61667,113171,184083,154707,103355,163896,83395,
186587,100346,106439,100024,114082,142899,55369,173667,45458,124484,36456,49342,145600,100819,94447,109375,40440,75620,175653,32911,35200,3041,18807,169971,4117,37747,189094,127585,172471,114487,35592,22194,35971,168360,211272,109014,119545,206812,158447,111595,163380,47132,119116,133540,111181,144716,214976,117976,100850,93141,59176,48704,111351,158485,1947,151503,130588,191604,62062,186680,131525,70582,107698,143174,151087,96198,8771,9926,148937,126685,184780,96620,29665,152113,126230,181009,37247,191413,191704,137994,196760,204176,73248,121880,135515,20508,137194,159926,23185,64285,123850,165324,194428,173994,175300,20468,179325,166117,141347,209216,151946,81919,13834,90058,140255,204086,2826,122911,119022,134391,123006,150461,192365,86904,110366,92514,209249,77232,75814,43911,9260,182203,156666,196457,79710,190571,205788,147772,68273,16100,164967,124036,151556,44323,150171,13531,172050,76481,212047,193299,163660,7156,75187,127076,21629,33245,132538,20148,205458,72836,66073,34744,161427,73655,72146,67359,3723,158288,45773,155128,193715,39619,172793,117087,96932,211770,45979,68485,205134,183703,61917,127118,47398,209550,138251,8674,89875,22125,38029,156118,215409,40409,219575,130009,197331,108251,100444,62376,173189,115391,190316,167136,43852,55274,151455,192508,90459,14241,168364,69309,198045,214835,30407,13768,43935,121638,104528,124841,79001,124850,195598,197244,139597,61252,99562,179178,184970,102254,68845,139052,113962,3109,75705,209009,117655,100140,177862,82755,139002,179011,10752,25523,139275,167096,92753,113190,160035,170935,71457,87809,70864,70015,210322,25762,66936,171532,102538,120876,202676,177023,196369,63385,192495,176131,154330,201981,51098,136302,88449,177659,47164,82121,143553,2772,176370,218022,17813,3135,107806,25611,137694,27238,179699,116901,92147,57518,190120,74264,66033,50614,8923,6212,162371,24101,14911,155749,172125,219010,56839,157125,31926,185617,31451,158514,34054,98555,172069,169844,148391,210554,48435,29057,182361,48303,66410,77219,122409,173140,19848,167452,118867,72452,201682,111931,181461,201819,201882,84778,67459,204664,19397,72689,203713,16354,100348,101866,87203,15744,67695,213070,69380,46643,115494,18192,128901,159986,102926,46777,169209,128989,28424,37715,141723,201512,113050,57248,27852,115757,187537,67894,27076,9138,19115,18458,149774,190721,65692,203748,116885,4919,130469,39021,171144,167860,227,9918,174222,102362,69844,39742,183515,86463,102584,150638,970,185494,202110,128628,125341,105883,190779,204717,212933,205883,18376,141591,59204,867,100774,180685,214427,126370,31772,140326,94088,87409,178739,79384,105018,14468,165746,169626,23767,95775,140149,89145,21091,120491,199348,200935,76010,169814,53246,46299,145352,10310,100706,173207,50070,78347,182421,102263,27988,165312,180575,127533,22246,165882,114814,49666,36685,9358,50618,6649,188675,155437,62517,128249,101913,16513,207811,141702,168665,80669,130989,139969,218693,212802,117783,4303,9710,88941,14694,88530,206511,189008,11223,106301,30982,55490,211577,171402,174932,137132,120159,177754,140656,39047,194262,151872,61075,18883,9667,192796,37117,128160,171398,4997,123527,162102,12006,63617,183840,170075,34219,3765,45154,165848,13096,80566,215151,25504,119586,160229,44745,218982,154938,16766,214381,214294,27874,29931,130323,155969,209498,161383,57821,16543,210882,53121,111728,88432,212735,219125,193443,177837,105384,84996,31667,179329,150986,143100,36000,84302,219264,207906,55897,200337,11950,155473,45468,87950,16573,71799,37024,198587,108830,44053,182529,96974,145072,43255,114992,63869,152564,1066
8,26205,207369,165726,160547,98600,203859,41354,4097,127655,21160,1664,1893,47032,23367,185267,155642,74378,109110,181797,211658,207681,202522,124783,181497,33264,54043,37509,167726,71027,137124,118074,108011,62718,45756,160048,181318,9862,156309,47186,188009,131042,160445,164414,216657,205149,78957,91187,109782,1198,146200,61771,50306,30803,138988,25082,162246,94870,199525,141485,133048,57481,121691,201937,164223,172838,54554,72463,13911,50782,29538,208827,179632,124364,130420,46222,41880,175673,11302,14163,169046,95471,54942,113946,65898,87633,161695,211066,198812,118789,11355,27373,194095,13751,140865,41430,205912,7619,127155,134978,30425,18575,61682,189111,116607,155374,151981,85451,109855,52101,195844,177432,85376,60309,204203,28870,85493,59711,25706,62159,193396,108159,52417,98443,57845,123467,153378,126984,14244,114259,37995,56666,159759,119273,46677,101536,97906,35757,119993,188696,50896,165680,94002,70698,53008,88051,33347,8413,181458,149072,15637,73558,35783,29126,122037,120213,109634,199318,90066,57633,77374,158239,41101,206974,208290,167247,160308,44547,151551,144901,60215,22668,107923,136778,128244,56739,155079,124122,219411,121310,73182,3802,191369,112226,192984,64658,64967,85798,113835,151903,58302,155129,39051,150392,1331,109516,11593,10501,123378,192276,201485,132177,8813,111594,139230,103368,125637,212262,209195,182623,215296,25689,118604,129836,63048,21781,146404,197415,3256,181707,191833,106592,57149,34789,93953,77539,62668,169774,23784,4175,125772,175383,79911,42415,184367,120398,149521,195896,153002,6399,119233,202331,36339,184322,106262,48482,207056,97103,164314,121712,184523,166042,357,164444,18533,87269,8575,77749,111647,197365,70723,182990,160563,142196,150792,208267,54302,206488,179480,21222,162648,88000,192875,14368,134935,176856,95417,13516,29089,100389,95837,121989,162302,10714,95832,96058,45945,169914,146483,184778,17565,146999,185948,202788,53664,87799,62556,141461,29768,170141,125434,197204,152074,54996,197439,195301,199461,59097,84007,92461,70086,132006,178317,110868,72651,14276,16611,97730,90499,118596,127461,131452,138066,217464,32167,123250,186076,139435,216368,139684,76615,118483,162778,96507,209689,176503,181724,212020,17203,49375,120028,22887,21016,102415,46874,75595,20171,28669,43192,72336,29366,182735,38785,130109,39452,53544,136495,164436,5948,76852,109192,178887,141933,34610,124231,90885,110134,116355,193386,5631,76419,49561,16501,5776,210236,60903,218876,136215,168996,88048,107922,151703,89860,171903,169358,206261,85539,127440,155545,174549,182678,138979,143777,116721,120063,79713,148701,34591,206922,80345,45019,206901,68159,59449,126941,97491,75535,201785,12404,38886,92568,74549,162170,42150,186597,140330,74775,45969,185627,181628,174681,208262,40235,42271,62365,30649,10787,83001,150608,33154,115576,68264,96907,48364,53780,147452,185371,9093,174470,83084,90137,38782,97757,117552,141522,42623,139891,79052,31847,6533,106053,130919,201462,38111,120072,83943,40908,1125,193592,143168,133014,86843,169671,184929,77478,124232,75976,197698,150399,174820,38557,209604,148172,115895,88784,180655,29442,191795,10506,36548,177735,144430,30638,80761,6500,105284,161879,57767,116909,93688,24243,85555,14426,159423,51157,152943,138001,58233,204808,201089,72856,51305,5910,21912,141141,116774,177488,176044,22905,134162,103395,50973,128572,83298,158117,158617,134101,14692,174788,94650,104217,38783,187685,112103,96896,101217,97735,3195,74619,90418,92105,80762,116805,147071,22963,56543,190802,116449,194387,108949,210681,115999,75663,42936,211202,190841,111105,190999,108
578,36586,219024,153925,156745,22564,37654,54469,79490,124416,137704,87927,154842,198210,98506,23186,1735,107337,167459,108662,45847,145703,121141,186186,148875,197982,35894,52234,48760,218178,100078,75881,145192,128604,11071,215691,216596,79165,218781,153408,79669,129525,146612,193948,121491,58856,7035,207413,641,40376,68859,78687,167917,218691,166870,123873,32552,171393,174565,75415,109355,47021,128945,180313,169026,64159,7456,31288,200466,72670,129250,17491,156995,195445,94833,152322,125956,148261,114312,67609,164625,179807,124791,181992,214507,31348,134355,192140,150801,4599,68036,42689,152477,15338,38308,83802,120859,181739,35519,93686,42810,13864,71478,94676,109973,166788,93319,12770,75984,76457,138834,167495,20139,134707,50807,200918,164634,48037,200458,156961,197297,43813,142530,84569,82892,216746,179560,209683,78924,71161,179684,182967,180445,25122,140095,152055,92177,215827,37674,3381,35900,135835,73771,97783,204159,127113,89211,42737,26848,57114,131958,57963,217992,49392,208432,146448,160806,180788,217950,137796,139485,186789,144470,184094,198261,210391,29875,109404,120715,20027,30556,214655
diff --git a/data/molecules/val.index b/data/molecules/val.index
new file mode 100644
index 0000000..78db245
--- /dev/null
+++ b/data/molecules/val.index
@@ -0,0 +1 @@
+20952,3648,819,24299,9012,8024,7314,4572,24132,3358,22174,24270,17870,2848,19349,13825,1041,976,3070,7164,7623,16559,19726,869,18390,6515,23462,21295,22981,17856,13746,7223,14719,19309,9115,212,5231,22876,13848,11149,9105,5094,7055,11029,3349,3039,12449,3169,11763,11270,19782,8667,1423,23911,15054,17571,4090,12403,2582,18089,9606,20599,20267,11850,18918,6300,23087,2279,1501,21668,7467,9482,2614,7628,3309,12455,9108,14857,20830,11954,5329,12130,11641,6865,21960,8748,22997,22398,21234,2339,19960,20806,5607,17502,23892,8021,5354,15147,12433,8845,20971,22549,18250,7196,22433,10626,1832,7505,1051,10336,13145,8773,2168,6913,18585,23524,10311,6967,21477,16358,12964,21064,15035,4681,8679,4575,8081,24411,18394,17661,8609,19155,14038,19121,13087,11861,7186,4532,16696,16171,2978,1543,3592,5008,20560,5242,22298,13833,19543,2081,12608,12504,19526,15337,17338,8238,18128,376,22291,23616,3753,22338,17595,8743,21003,11146,3655,9617,14246,5182,14867,106,23661,23582,8630,16403,5854,16635,3486,20489,9779,20937,19954,6517,12252,5293,17674,17378,18,19626,10621,16010,638,3665,11894,10076,7846,1898,7892,18591,2580,2806,23983,15924,2267,17455,4120,4207,21618,15574,18015,5410,8685,17290,19876,13865,6940,17671,23918,22605,6591,23361,10214,13074,22009,12236,14355,16959,14794,3965,8123,7362,2098,11078,689,19277,18150,7540,19282,7216,235,2326,23194,20679,1929,7501,2208,1029,10827,2321,16847,7798,9125,21921,15906,7020,17669,4335,23702,18711,18881,15488,7962,15498,13338,6239,3090,3176,21593,14124,11609,13879,13470,15303,23890,1775,22064,21412,21174,3224,1986,13193,23862,11118,3580,8147,6278,6232,17573,14700,4593,13824,6012,9127,15159,8185,2470,14520,18033,3208,1657,21369,17713,483,3056,7745,5449,13317,15913,15773,7004,13141,1921,5394,12418,70,12793,8690,14909,9347,13861,22825,23937,18211,21688,23540,15947,5072,6222,9722,7133,1916,18978,24108,17766,1997,10276,1873,1643,19142,15623,16477,17403,5158,1863,16640,2625,6089,2245,19498,2226,22125,7707,13230,3928,18667,8067,18970,19481,1302,20295,2686,13737,21540,19125,18521,17130,10366,8544,6693,21945,23468,10295,7821,8703,12969,4288,21151,9830,14982,10360,2377,305,15017,20354,18448,3276,2400,17617,6984,16576,4340,11436,2254,8004,12108,9338,5169,14358,17800,23053,9912,20043,21429,17332,256,21884,18173,9810,21737,3394,4400,8666,3782,3507,24327,5093,8924,9232,19819,6901,23514,11235,6671,22527,20782,8650,16561,16008,8228,1664,3024,20784,9066,1444,116,10929,4286,20876,8583,5294,24288,14478,18077,23123,14014,18379,316,2465,22643,4884,17877,1180,12098,19087,18105,4852,14083,4176,1370,10101,11948,1307,11724,6883,22349,8176,21854,3368,11589,18346,13316,20337,5064,7757,5324,5801,13510,812,5877,24135,10885,13491,21951,24086,8131,8742,5216,22979,3542,12535,1268,15423,7288,6539,15083,11457,10000,7457,7304,775,21627,6328,13056,10756,9129,2274,9146,11506,21020,16692,13096,22266,17570,10851,904,3779,8559,5851,19024,8698,1253,3552,19548,14239,11327,23872,10278,14299,19864,16758,3789,12622,18893,6228,8346,1454,23225,14288,55,17036,17643,22506,23574,24312,24151,21975,6456,11934,14132,2292,21765,10819,20419,10286,4083,23584,9840,16617,10134,21852,13382,10688,13185,22846,9688,18166,4170,6286,13777,21788,12423,22194,5702,20169,18648,9861,13306,17954,13,9957,9401,6887,14086,19004,19879,21453,10559,15236,14476,14488,22138,7002,16750,15505,24115,5560,21589,2778,9299,16890,21753,20740,20291,10983,3060,7696,22046,10171,7361,6525,4828,800,1514,8023,15569,20030,2386,14923,13580,20636,18863,6371,23538,22818,12582,16199,13095,7994,4835,21497,22532,181,3492,13931,7170,5763,22803,16972,15222,1645,18265,8165,39
76,14957,4369,15225,21875,17404,18314,19511,10397,14502,20075,23569,16540,13983,17952,14611,5215,24368,15554,14747,8493,8101,20894,9086,17081,15879,20537,7839,8998,14413,2538,23381,9362,7683,8903,11005,10476,17699,2640,4534,4942,7577,12551,22739,5007,23147,7010,2104,13594,13356,10842,17780,15267,13624,2040,6777,13767,12762,19139,22790,640,18864,12464,15629,193,11526,9784,12779,13730,17636,24072,17895,19767,7226,15998,7190,8943,14281,951,12742,11014,21917,22254,13248,23729,5408,15315,4182,20390,883,12911,19395,18493,21725,888,2750,21061,14044,4446,15128,5954,1647,8524,12422,10726,6935,14899,10710,11059,9117,13813,8266,2683,15411,635,17675,1706,11467,7347,21303,2248,21356,1319,1016,8102,6532,667,20359,4993,7816,4136,15517,21936,3748,18480,7142,15238,22920,8396,12087,5498,19853,19898,23539,5366,10192,18962,841,10222,18867,22195,12298,12997,23429,6498,2490,19401,22630,20553,7957,3339,22845,9882,22421,19674,3966,18544,1345,11377,17456,14037,21676,12142,2259,16579,21218,11181,414,13764,16062,3458,14205,11868,20826,15064,23177,5013,14270,5771,24045,17096,21314,8850,20182,17634,15843,15232,14272,23954,19414,8794,10561,8044,2839,9139,14771,7990,15227,18672,19999,21895,11023,940,16197,10650,5958,15976,6950,11626,8465,11152,9163,19534,22976,9052,18212,332,16928,6260,2805,7908,23595,16009,18192,7874,22629,15600,21164,23325,16083,14685,565,3049,9641,7261,13251,22668,7972,10033,21756,19056,12092,15507,18136,17397,11264,13942,24442,18035,10839,11528,23031,14868,8877,10047,8237,7554,3953,23635,6310,10339,3917,24342,17559,22615,6066,6276,7090,24202,15866,9060,23743,19319,17191,19555,9273,3294,6360,9707,7454,11825,5879,9904,463,23200,4147,8988,1491,1786,18132,9572,22852,4137,20901,16085,3361,401,18810,9317,15381,15686,14433,11164,6041,1683,8273,15654,3738,2141,13130,16113,2427,18907,20625,22493,1756,4971,4888,18443,9956,2791,8132,3881,18286,13637,19867,19533,20264,7395,17123,14762,14507,9743,19284,14051,10006,18632,20349,1973,19976,24251,3251,6808,20496,6914,8671,21640
diff --git a/data/ogb_mol.py b/data/ogb_mol.py
new file mode 100644
index 0000000..c20ec4c
--- /dev/null
+++ b/data/ogb_mol.py
@@ -0,0 +1,254 @@
+import time
+import dgl
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+
+from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
+
+from scipy import sparse as sp
+import numpy as np
+import networkx as nx
+from tqdm import tqdm
+
+class OGBMOLDGL(torch.utils.data.Dataset):
+ def __init__(self, data, split):
+ self.split = split
+ self.data = [g for g in data[self.split]]
+ self.graph_lists = []
+ self.graph_labels = []
+ for g in self.data:
+ if g[0].number_of_nodes() > 5:
+ self.graph_lists.append(g[0])
+ self.graph_labels.append(g[1])
+ self.n_samples = len(self.graph_lists)
+
+ def __len__(self):
+ """Return the number of graphs in the dataset."""
+ return self.n_samples
+
+ def __getitem__(self, idx):
+ """
+ Get the idx^th sample.
+ Parameters
+        ----------
+ idx : int
+ The sample index.
+ Returns
+ -------
+ (dgl.DGLGraph, int)
+ DGLGraph with node feature stored in `feat` field
+ And its label.
+ """
+ return self.graph_lists[idx], self.graph_labels[idx]
+
+def add_eig_vec(g, pos_enc_dim):
+ """
+    Graph positional encoding with Laplacian eigenvectors.
+    This func is for eigvec visualization; it is the same code as lap_positional_encoding(),
+    but stores the values under a different key, 'eigvec'.
+ """
+
+ # Laplacian
+ A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ L = sp.eye(g.number_of_nodes()) - N * A * N
+
+ # Eigenvectors with numpy
+ EigVal, EigVec = np.linalg.eig(L.toarray())
+ idx = EigVal.argsort() # increasing order
+ EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx])
+ g.ndata['eigvec'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float()
+
+    # zero-pad at the end if the graph has too few nodes (n <= pos_enc_dim)
+    n = g.number_of_nodes()
+    if n <= pos_enc_dim:
+        g.ndata['eigvec'] = F.pad(g.ndata['eigvec'], (0, pos_enc_dim - n + 1), value=0.0)
+
+ return g
+
+
+def lap_positional_encoding(g, pos_enc_dim):
+ """
+    Graph positional encoding with Laplacian eigenvectors
+ """
+
+ # Laplacian
+ A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ L = sp.eye(g.number_of_nodes()) - N * A * N
+
+ # Eigenvectors with numpy
+ EigVal, EigVec = np.linalg.eig(L.toarray())
+ idx = EigVal.argsort() # increasing order
+ EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx])
+ g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float()
+
+ return g
+
+
+def init_positional_encoding(g, pos_enc_dim, type_init):
+ """
+ Initializing positional encoding with RWPE
+ """
+
+ n = g.number_of_nodes()
+
+ if type_init == 'rand_walk':
+ # Geometric diffusion features with Random Walk
+ A = g.adjacency_matrix(scipy_fmt="csr")
+ Dinv = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -1.0, dtype=float) # D^-1
+ RW = A * Dinv
+ M = RW
+
+ # Iterate
+ nb_pos_enc = pos_enc_dim
+ PE = [torch.from_numpy(M.diagonal()).float()]
+ M_power = M
+ for _ in range(nb_pos_enc-1):
+ M_power = M_power * M
+ PE.append(torch.from_numpy(M_power.diagonal()).float())
+ PE = torch.stack(PE,dim=-1)
+ g.ndata['pos_enc'] = PE
+
+ return g
+
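+# Note (illustrative, not from the original code): for type_init='rand_walk', column k of
+# g.ndata['pos_enc'] holds each node's (k+1)-step random-walk return probability, i.e. the
+# diagonal of (A D^-1)^(k+1). A minimal usage sketch:
+#
+#   g = dgl.rand_graph(10, 30)   # hypothetical toy graph
+#   g = init_positional_encoding(g, pos_enc_dim=16, type_init='rand_walk')
+#   g.ndata['pos_enc'].shape     # -> torch.Size([10, 16])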
+
+def make_full_graph(graph, adaptive_weighting=None):
+ g, label = graph
+
+ full_g = dgl.from_networkx(nx.complete_graph(g.number_of_nodes()))
+
+ # Copy over the node feature data and laplace eigvecs
+ full_g.ndata['feat'] = g.ndata['feat']
+
+    try:
+        full_g.ndata['pos_enc'] = g.ndata['pos_enc']
+    except:
+        # no positional encodings were precomputed for this graph; skip
+        pass
+
+    try:
+        full_g.ndata['eigvec'] = g.ndata['eigvec']
+    except:
+        # no eigenvectors were precomputed for this graph; skip
+        pass
+
+    # Initialize fake edge features w/ 0s
+ full_g.edata['feat'] = torch.zeros(full_g.number_of_edges(), 3, dtype=torch.long)
+ full_g.edata['real'] = torch.zeros(full_g.number_of_edges(), dtype=torch.long)
+
+ # Copy real edge data over, and identify real edges!
+ full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['feat'] = g.edata['feat']
+ full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['real'] = torch.ones(
+ g.edata['feat'].shape[0], dtype=torch.long) # This indicates real edges
+
+    # This code section only applies to GraphiT --------------------------------------------
+ if adaptive_weighting is not None:
+ p_steps, gamma = adaptive_weighting
+
+ n = g.number_of_nodes()
+ A = g.adjacency_matrix(scipy_fmt="csr")
+
+ # Adaptive weighting k_ij for each edge
+ if p_steps == "qtr_num_nodes":
+ p_steps = int(0.25*n)
+ elif p_steps == "half_num_nodes":
+ p_steps = int(0.5*n)
+ elif p_steps == "num_nodes":
+ p_steps = int(n)
+ elif p_steps == "twice_num_nodes":
+ p_steps = int(2*n)
+
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ I = sp.eye(n)
+ L = I - N * A * N
+
+ k_RW = I - gamma*L
+ k_RW_power = k_RW
+ for _ in range(p_steps - 1):
+ k_RW_power = k_RW_power.dot(k_RW)
+
+ k_RW_power = torch.from_numpy(k_RW_power.toarray())
+
+ # Assigning edge features k_RW_eij for adaptive weighting during attention
+ full_edge_u, full_edge_v = full_g.edges()
+ num_edges = full_g.number_of_edges()
+
+ k_RW_e_ij = []
+ for edge in range(num_edges):
+ k_RW_e_ij.append(k_RW_power[full_edge_u[edge], full_edge_v[edge]])
+
+ full_g.edata['k_RW'] = torch.stack(k_RW_e_ij,dim=-1).unsqueeze(-1).float()
+ # --------------------------------------------------------------------------------------
+
+ return full_g, label
+
+class OGBMOLDataset(Dataset):
+ def __init__(self, name, features='full'):
+
+ start = time.time()
+ print("[I] Loading dataset %s..." % (name))
+ self.name = name.lower()
+
+ self.dataset = DglGraphPropPredDataset(name=self.name)
+
+ if features == 'full':
+ pass
+ elif features == 'simple':
+ print("[I] Retaining only simple features...")
+            # retain only the first two node/edge feature columns
+ for g in self.dataset.graphs:
+ g.ndata['feat'] = g.ndata['feat'][:, :2]
+ g.edata['feat'] = g.edata['feat'][:, :2]
+
+ split_idx = self.dataset.get_idx_split()
+
+ self.train = OGBMOLDGL(self.dataset, split_idx['train'])
+ self.val = OGBMOLDGL(self.dataset, split_idx['valid'])
+ self.test = OGBMOLDGL(self.dataset, split_idx['test'])
+
+ self.evaluator = Evaluator(name=self.name)
+
+ print("[I] Finished loading.")
+ print("[I] Data load time: {:.4f}s".format(time.time()-start))
+
+ # form a mini batch from a given list of samples = [(graph, label) pairs]
+ def collate(self, samples):
+ # The input samples is a list of pairs (graph, label).
+ graphs, labels = map(list, zip(*samples))
+ batched_graph = dgl.batch(graphs)
+ labels = torch.stack(labels)
+ tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))]
+ tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ]
+ snorm_n = torch.cat(tab_snorm_n).sqrt()
+
+ return batched_graph, labels, snorm_n
+
+ def _add_lap_positional_encodings(self, pos_enc_dim):
+
+        # Graph positional encoding with Laplacian eigenvectors
+ self.train = [(lap_positional_encoding(g, pos_enc_dim), label) for g, label in self.train]
+ self.val = [(lap_positional_encoding(g, pos_enc_dim), label) for g, label in self.val]
+ self.test = [(lap_positional_encoding(g, pos_enc_dim), label) for g, label in self.test]
+
+ def _add_eig_vecs(self, pos_enc_dim):
+
+        # Graph positional encoding with Laplacian eigenvectors
+ self.train = [(add_eig_vec(g, pos_enc_dim), label) for g, label in self.train]
+ self.val = [(add_eig_vec(g, pos_enc_dim), label) for g, label in self.val]
+ self.test = [(add_eig_vec(g, pos_enc_dim), label) for g, label in self.test]
+
+
+ def _init_positional_encodings(self, pos_enc_dim, type_init):
+
+        # Initializing positional encodings with the chosen scheme (e.g., RWPE when type_init='rand_walk')
+ self.train = [(init_positional_encoding(g, pos_enc_dim, type_init), label) for g, label in self.train]
+ self.val = [(init_positional_encoding(g, pos_enc_dim, type_init), label) for g, label in self.val]
+ self.test = [(init_positional_encoding(g, pos_enc_dim, type_init), label) for g, label in self.test]
+
+ def _make_full_graph(self, adaptive_weighting=None):
+ self.train = [make_full_graph(graph, adaptive_weighting) for graph in self.train]
+ self.val = [make_full_graph(graph, adaptive_weighting) for graph in self.val]
+ self.test = [make_full_graph(graph, adaptive_weighting) for graph in self.test]
+
+
+
\ No newline at end of file
diff --git a/data/script_download_ZINC.sh b/data/script_download_ZINC.sh
new file mode 100644
index 0000000..33ba6d2
--- /dev/null
+++ b/data/script_download_ZINC.sh
@@ -0,0 +1,21 @@
+
+
+# Command to download dataset:
+# bash script_download_ZINC.sh
+
+
+DIR=molecules/
+cd $DIR
+
+
+FILE=ZINC.pkl
+if test -f "$FILE"; then
+ echo -e "$FILE already downloaded."
+else
+ echo -e "\ndownloading $FILE..."
+ curl https://data.dgl.ai/dataset/benchmarking-gnns/ZINC.pkl -o ZINC.pkl -J -L -k
+fi
+
+
+
+
diff --git a/docs/01_repo_installation.md b/docs/01_repo_installation.md
new file mode 100644
index 0000000..2749acc
--- /dev/null
+++ b/docs/01_repo_installation.md
@@ -0,0 +1,84 @@
+# Repo installation
+
+
+
+
+
+## 1. Setup Conda
+
+```
+# Conda installation
+
+# For Linux
+curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
+
+# For OSX
+curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
+
+chmod +x ~/miniconda.sh
+bash ~/miniconda.sh
+
+source ~/.bashrc # For Linux
+source ~/.bash_profile # For OSX
+```
+
+
+
+
+## 2. Setup Python environment for CPU
+
+```
+# Clone GitHub repo
+conda install git
+git clone https://github.com/vijaydwivedi75/gnn-lspe.git
+cd gnn-lspe
+
+# Install python environment
+conda env create -f environment_cpu.yml
+
+# Activate environment
+conda activate gnn_lspe
+```
+
+
+
+
+
+## 3. Setup Python environment for GPU
+
+The DGL 0.6.1 GPU build pinned in [`environment_gpu.yml`](../environment_gpu.yml) is built against CUDA **10.2**.
+
+For Ubuntu **18.04**
+
+```
+# Setup CUDA 10.2 on Ubuntu 18.04
+sudo apt-get --purge remove "*cublas*" "cuda*"
+sudo apt --purge remove "nvidia*"
+sudo apt autoremove
+wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-repo-ubuntu1804_10.2.89-1_amd64.deb
+sudo dpkg -i cuda-repo-ubuntu1804_10.2.89-1_amd64.deb
+sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
+sudo apt update
+sudo apt install -y cuda-10-2
+sudo reboot
+cat /usr/local/cuda/version.txt # Check CUDA version is 10.2
+
+# Clone GitHub repo
+conda install git
+git clone https://github.com/vijaydwivedi75/gnn-lspe.git
+cd gnn-lspe
+
+# Install python environment
+conda env create -f environment_gpu.yml
+
+# Activate environment
+conda activate gnn_lspe
+```
+
+
+
+
+
+
+
+
diff --git a/docs/02_download_datasets.md b/docs/02_download_datasets.md
new file mode 100644
index 0000000..e120e4e
--- /dev/null
+++ b/docs/02_download_datasets.md
@@ -0,0 +1,18 @@
+# Download datasets
+
+OGBG-MOL* datasets are automatically downloaded from OGB. For ZINC, use the following script.
+
+
+
+## 1. ZINC molecular dataset
+The ZINC dataset download is 58.9MB.
+
+```
+# At the root of the project
+cd data/
+bash script_download_ZINC.sh
+```
+The download script [script_download_ZINC.sh](../data/script_download_ZINC.sh) is located in the [`data/`](../data) directory.
+
+
+
diff --git a/docs/03_run_codes.md b/docs/03_run_codes.md
new file mode 100644
index 0000000..ea22d44
--- /dev/null
+++ b/docs/03_run_codes.md
@@ -0,0 +1,87 @@
+# Reproducibility
+
+
+
+
+## 1. Usage
+
+
+
+
+### In terminal
+
+```
+# Run the main file (at the root of the project)
+python main_ZINC_graph_regression.py --config 'configs/GatedGCN_ZINC_LSPE.json' # for CPU
+python main_ZINC_graph_regression.py --gpu_id 0 --config 'configs/GatedGCN_ZINC_LSPE.json' # for GPU
+```
+The training and network parameters for each experiment are stored in a JSON file in the [`configs/`](../configs) directory, as illustrated below.
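+
+For illustration, a config roughly follows the structure sketched below (the keys shown are a trimmed, hypothetical example, not the exact contents of any shipped file; refer to the actual JSON files in [`configs/`](../configs) for the full set of options):
+
+```
+{
+    "out_dir": "out/GatedGCN_ZINC_LSPE/",
+    "params":     { "epochs": 1000, "batch_size": 128, "init_lr": 0.001 },
+    "net_params": { "L": 16, "hidden_dim": 64, "pos_enc_dim": 8, "pe_init": "rand_walk" }
+}
+```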
+
+
+
+
+
+
+## 2. Output, checkpoints and visualizations
+
+Output results are located in the folder defined by the variable `out_dir` in the corresponding config file (e.g., the [`configs/GatedGCN_ZINC_LSPE.json`](../configs/GatedGCN_ZINC_LSPE.json) file).
+
+If `out_dir = 'out/GatedGCN_ZINC_LSPE_noLapEigLoss/'`, then
+
+#### 2.1 To see checkpoints and results
+1. Go to `out/GatedGCN_ZINC_LSPE_noLapEigLoss/results` to view all result text files.
+2. Directory `out/GatedGCN_ZINC_LSPE_noLapEigLoss/checkpoints` contains model checkpoints.
+
+#### 2.2 To see the training logs in Tensorboard on local machine
+1. Go to the logs directory, i.e. `out/GatedGCN_ZINC_LSPE_noLapEigLoss/logs/`.
+2. Run the commands
+```
+source activate gnn_lspe
+tensorboard --logdir='./' --port 6006
+```
+3. Open `http://localhost:6006` in your browser. Note that the port information (here 6006 but it may change) appears on the terminal immediately after starting tensorboard.
+
+
+#### 2.3 To see the training logs in Tensorboard on remote machine
+1. Go to the logs directory, i.e. `out/GatedGCN_ZINC_LSPE_noLapEigLoss/logs/`.
+2. Run the [script](../scripts/TensorBoard/script_tensorboard.sh) with `bash script_tensorboard.sh`.
+3. On your local machine, run the command `ssh -N -f -L localhost:6006:localhost:6006 user@xx.xx.xx.xx`.
+4. Open `http://localhost:6006` in your browser. Note that `user@xx.xx.xx.xx` corresponds to your user login and the IP of the remote machine.
+
+
+
+
+
+## 3. Reproduce results
+
+
+```
+# At the root of the project
+
+bash scripts/ZINC/script_ZINC_all.sh
+bash scripts/OGBMOL/script_MOLTOX21_all.sh
+bash scripts/OGBMOL/script_MOLPCBA_all.sh
+
+```
+
+Scripts are [located](../scripts/) in the `scripts/` directory of the repository.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/gnn-lspe.png b/docs/gnn-lspe.png
new file mode 100644
index 0000000..c962007
Binary files /dev/null and b/docs/gnn-lspe.png differ
diff --git a/environment_cpu.yml b/environment_cpu.yml
new file mode 100644
index 0000000..9986997
--- /dev/null
+++ b/environment_cpu.yml
@@ -0,0 +1,44 @@
+name: gnn_lspe
+channels:
+- pytorch
+- dglteam
+- conda-forge
+- anaconda
+- defaults
+dependencies:
+- python=3.7.4
+- python-dateutil=2.8.0
+- pip=19.2.3
+- pytorch=1.6.0
+- torchvision==0.7.0
+- pillow==6.1
+- dgl=0.6.1
+- numpy=1.19.2
+- matplotlib=3.1.0
+- tensorboard=1.14.0
+- tensorboardx=1.8
+- future=0.18.2
+- absl-py
+- networkx=2.3
+- scikit-learn=0.21.2
+- scipy=1.3.0
+- notebook=6.0.0
+- h5py=2.9.0
+- mkl=2019.4
+- ipykernel=5.1.2
+- ipython=7.7.0
+- ipython_genutils=0.2.0
+- ipywidgets=7.5.1
+- jupyter=1.0.0
+- jupyter_client=5.3.1
+- jupyter_console=6.0.0
+- jupyter_core=4.5.0
+- plotly=4.1.1
+- scikit-image=0.15.0
+- requests==2.22.0
+- tqdm==4.43.0
+- pip:
+ - tensorflow==2.1.0
+ - tensorflow-estimator==2.1.0
+ - tensorboard==2.1.1
+ - ogb==1.3.1
\ No newline at end of file
diff --git a/environment_gpu.yml b/environment_gpu.yml
new file mode 100644
index 0000000..5192a66
--- /dev/null
+++ b/environment_gpu.yml
@@ -0,0 +1,47 @@
+name: gnn_lspe
+channels:
+- pytorch
+- dglteam
+- conda-forge
+- fragcolor
+- anaconda
+- defaults
+dependencies:
+- cudatoolkit=10.2
+- cudnn=7.6.5
+- python=3.7.4
+- python-dateutil=2.8.0
+- pip=19.2.3
+- pytorch=1.6.0
+- torchvision==0.7.0
+- pillow==6.1
+- dgl-cuda10.2=0.6.1
+- numpy=1.19.2
+- matplotlib=3.1.0
+- tensorboard=1.14.0
+- tensorboardx=1.8
+- future=0.18.2
+- absl-py
+- networkx=2.3
+- scikit-learn=0.21.2
+- scipy=1.3.0
+- notebook=6.0.0
+- h5py=2.9.0
+- mkl=2019.4
+- ipykernel=5.1.2
+- ipython=7.7.0
+- ipython_genutils=0.2.0
+- ipywidgets=7.5.1
+- jupyter=1.0.0
+- jupyter_client=5.3.1
+- jupyter_console=6.0.0
+- jupyter_core=4.5.0
+- plotly=4.1.1
+- scikit-image=0.15.0
+- requests==2.22.0
+- tqdm==4.43.0
+- pip:
+ - tensorflow-gpu==2.1.0
+ - tensorflow-estimator==2.1.0
+ - tensorboard==2.1.1
+ - ogb==1.3.1
\ No newline at end of file
diff --git a/layers/gatedgcn_layer.py b/layers/gatedgcn_layer.py
new file mode 100644
index 0000000..258fea4
--- /dev/null
+++ b/layers/gatedgcn_layer.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import dgl.function as fn
+
+"""
+ GatedGCN: Residual Gated Graph ConvNets
+ An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent)
+ https://arxiv.org/pdf/1711.07553v2.pdf
+"""
+
+class GatedGCNLayer(nn.Module):
+ """
+ Param: []
+ """
+ def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False, graph_norm=True):
+ super().__init__()
+ self.in_channels = input_dim
+ self.out_channels = output_dim
+ self.dropout = dropout
+ self.batch_norm = batch_norm
+ self.graph_norm = graph_norm
+ self.residual = residual
+
+ if input_dim != output_dim:
+ self.residual = False
+
+ self.A = nn.Linear(input_dim, output_dim, bias=True)
+ self.B = nn.Linear(input_dim, output_dim, bias=True)
+ self.C = nn.Linear(input_dim, output_dim, bias=True)
+ self.D = nn.Linear(input_dim, output_dim, bias=True)
+ self.E = nn.Linear(input_dim, output_dim, bias=True)
+ self.bn_node_h = nn.BatchNorm1d(output_dim)
+ self.bn_node_e = nn.BatchNorm1d(output_dim)
+
+ def forward(self, g, h, p=None, e=None, snorm_n=None):
+
+ h_in = h # for residual connection
+ e_in = e # for residual connection
+
+ g.ndata['h'] = h
+ g.ndata['Ah'] = self.A(h)
+ g.ndata['Bh'] = self.B(h)
+ g.ndata['Dh'] = self.D(h)
+ g.ndata['Eh'] = self.E(h)
+ g.edata['e'] = e
+ g.edata['Ce'] = self.C(e)
+
+ g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
+ g.edata['e'] = g.edata['DEh'] + g.edata['Ce']
+ g.edata['sigma'] = torch.sigmoid(g.edata['e'])
+ g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
+ g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
+ g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6)
+
+ h = g.ndata['h'] # result of graph convolution
+ e = g.edata['e'] # result of graph convolution
+
+ # GN from benchmarking-gnns-v1
+ if self.graph_norm:
+ h = h * snorm_n
+
+ if self.batch_norm:
+ h = self.bn_node_h(h) # batch normalization
+ e = self.bn_node_e(e) # batch normalization
+
+ h = F.relu(h) # non-linear activation
+ e = F.relu(e) # non-linear activation
+
+ if self.residual:
+ h = h_in + h # residual connection
+ e = e_in + e # residual connection
+
+ h = F.dropout(h, self.dropout, training=self.training)
+ e = F.dropout(e, self.dropout, training=self.training)
+
+ return h, None, e
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__,
+ self.in_channels,
+ self.out_channels)
+
+
\ No newline at end of file
diff --git a/layers/gatedgcn_lspe_layer.py b/layers/gatedgcn_lspe_layer.py
new file mode 100644
index 0000000..78fcfe0
--- /dev/null
+++ b/layers/gatedgcn_lspe_layer.py
@@ -0,0 +1,133 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import dgl.function as fn
+
+import dgl
+
+"""
+ GatedGCNLSPE: GatedGCN with LSPE
+"""
+
+class GatedGCNLSPELayer(nn.Module):
+ """
+ Param: []
+ """
+ def __init__(self, input_dim, output_dim, dropout, batch_norm, use_lapeig_loss=False, residual=False):
+ super().__init__()
+ self.in_channels = input_dim
+ self.out_channels = output_dim
+ self.dropout = dropout
+ self.batch_norm = batch_norm
+ self.residual = residual
+ self.use_lapeig_loss = use_lapeig_loss
+
+ if input_dim != output_dim:
+ self.residual = False
+
+ self.A1 = nn.Linear(input_dim*2, output_dim, bias=True)
+ self.A2 = nn.Linear(input_dim*2, output_dim, bias=True)
+ self.B1 = nn.Linear(input_dim, output_dim, bias=True)
+ self.B2 = nn.Linear(input_dim, output_dim, bias=True)
+ self.B3 = nn.Linear(input_dim, output_dim, bias=True)
+ self.C1 = nn.Linear(input_dim, output_dim, bias=True)
+ self.C2 = nn.Linear(input_dim, output_dim, bias=True)
+
+ self.bn_node_h = nn.BatchNorm1d(output_dim)
+ self.bn_node_e = nn.BatchNorm1d(output_dim)
+ # self.bn_node_p = nn.BatchNorm1d(output_dim)
+
+ def message_func_for_vij(self, edges):
+ hj = edges.src['h'] # h_j
+ pj = edges.src['p'] # p_j
+ vij = self.A2(torch.cat((hj, pj), -1))
+ return {'v_ij': vij}
+
+ def message_func_for_pj(self, edges):
+ pj = edges.src['p'] # p_j
+ return {'C2_pj': self.C2(pj)}
+
+ def compute_normalized_eta(self, edges):
+ return {'eta_ij': edges.data['sigma_hat_eta'] / (edges.dst['sum_sigma_hat_eta'] + 1e-6)} # sigma_hat_eta_ij/ sum_j' sigma_hat_eta_ij'
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ with g.local_scope():
+
+ # for residual connection
+ h_in = h
+ p_in = p
+ e_in = e
+
+ # For the h's
+ g.ndata['h'] = h
+ g.ndata['A1_h'] = self.A1(torch.cat((h, p), -1))
+ # self.A2 being used in message_func_for_vij() function
+ g.ndata['B1_h'] = self.B1(h)
+ g.ndata['B2_h'] = self.B2(h)
+
+ # For the p's
+ g.ndata['p'] = p
+ g.ndata['C1_p'] = self.C1(p)
+ # self.C2 being used in message_func_for_pj() function
+
+ # For the e's
+ g.edata['e'] = e
+ g.edata['B3_e'] = self.B3(e)
+
+ #--------------------------------------------------------------------------------------#
+ # Calculation of h
+ g.apply_edges(fn.u_add_v('B1_h', 'B2_h', 'B1_B2_h'))
+ g.edata['hat_eta'] = g.edata['B1_B2_h'] + g.edata['B3_e']
+ g.edata['sigma_hat_eta'] = torch.sigmoid(g.edata['hat_eta'])
+ g.update_all(fn.copy_e('sigma_hat_eta', 'm'), fn.sum('m', 'sum_sigma_hat_eta')) # sum_j' sigma_hat_eta_ij'
+ g.apply_edges(self.compute_normalized_eta) # sigma_hat_eta_ij/ sum_j' sigma_hat_eta_ij'
+ g.apply_edges(self.message_func_for_vij) # v_ij
+ g.edata['eta_mul_v'] = g.edata['eta_ij'] * g.edata['v_ij'] # eta_ij * v_ij
+ g.update_all(fn.copy_e('eta_mul_v', 'm'), fn.sum('m', 'sum_eta_v')) # sum_j eta_ij * v_ij
+ g.ndata['h'] = g.ndata['A1_h'] + g.ndata['sum_eta_v']
+
+ # Calculation of p
+ g.apply_edges(self.message_func_for_pj) # p_j
+ g.edata['eta_mul_p'] = g.edata['eta_ij'] * g.edata['C2_pj'] # eta_ij * C2_pj
+ g.update_all(fn.copy_e('eta_mul_p', 'm'), fn.sum('m', 'sum_eta_p')) # sum_j eta_ij * C2_pj
+ g.ndata['p'] = g.ndata['C1_p'] + g.ndata['sum_eta_p']
+
+ #--------------------------------------------------------------------------------------#
+
+ # passing towards output
+ h = g.ndata['h']
+ p = g.ndata['p']
+ e = g.edata['hat_eta']
+
+ # GN from benchmarking-gnns-v1
+ h = h * snorm_n
+
+ # batch normalization
+ if self.batch_norm:
+ h = self.bn_node_h(h)
+ e = self.bn_node_e(e)
+ # No BN for p
+
+ # non-linear activation
+ h = F.relu(h)
+ e = F.relu(e)
+ p = torch.tanh(p)
+
+ # residual connection
+ if self.residual:
+ h = h_in + h
+ p = p_in + p
+ e = e_in + e
+
+ # dropout
+ h = F.dropout(h, self.dropout, training=self.training)
+ p = F.dropout(p, self.dropout, training=self.training)
+ e = F.dropout(e, self.dropout, training=self.training)
+
+ return h, p, e
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__,
+ self.in_channels,
+ self.out_channels)
\ No newline at end of file
diff --git a/layers/graphit_gt_layer.py b/layers/graphit_gt_layer.py
new file mode 100644
index 0000000..972e40e
--- /dev/null
+++ b/layers/graphit_gt_layer.py
@@ -0,0 +1,273 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+import dgl.function as fn
+import numpy as np
+
+"""
+ GraphiT-GT
+
+"""
+
+"""
+ Util functions
+"""
+def src_dot_dst(src_field, dst_field, out_field):
+ def func(edges):
+ return {out_field: (edges.src[src_field] * edges.dst[dst_field])}
+ return func
+
+
+def scaling(field, scale_constant):
+ def func(edges):
+ return {field: ((edges.data[field]) / scale_constant)}
+ return func
+
+# Improving implicit attention scores with explicit edge features, if available
+def imp_exp_attn(implicit_attn, explicit_edge):
+ """
+    implicit_attn: the implicit attention scores from the Q-K product
+ explicit_edge: the explicit edge features
+ """
+ def func(edges):
+ return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}
+ return func
+
+
+def exp(field):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
+ return func
+
+def adaptive_edge_PE(field, adaptive_weight):
+ def func(edges):
+ # initial shape was: adaptive_weight: [edges,1]; data: [edges, num_heads, 1]
+ # repeating adaptive_weight to have: [edges, num_heads, 1]
+ edges.data['tmp'] = edges.data[adaptive_weight].repeat(1, edges.data[field].shape[1]).unsqueeze(-1)
+ return {'score_soft': edges.data['tmp'] * edges.data[field]}
+ return func
+
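+# Summary (added for readability) of the attention computed in propagate_attention() below:
+#   score_uv = (K_u * Q_v) / sqrt(d) * E_uv             # the *_2 projections are used on fake edges
+#   a_uv     = exp( clamp( sum_d score_uv, -5, 5 ) )     # optionally rescaled by k_RW_uv
+#   h_v      = ( sum_u a_uv * V_u ) / ( sum_u a_uv + 1e-6 )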
+
+"""
+ Single Attention Head
+"""
+
+class MultiHeadAttentionLayer(nn.Module):
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, adaptive_edge_PE, attention_for):
+ super().__init__()
+
+
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.gamma = gamma
+ self.full_graph=full_graph
+ self.attention_for = attention_for
+ self.adaptive_edge_PE = adaptive_edge_PE
+
+ if self.attention_for == "h":
+ if use_bias:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ else:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ def propagate_attention(self, g):
+
+
+ if self.full_graph:
+ real_ids = torch.nonzero(g.edata['real']).squeeze()
+ fake_ids = torch.nonzero(g.edata['real']==0).squeeze()
+
+ else:
+ real_ids = g.edges(form='eid')
+
+ g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids)
+
+
+ # scale scores by sqrt(d)
+ g.apply_edges(scaling('score', np.sqrt(self.out_dim)))
+
+ # Use available edge features to modify the scores for edges
+ g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids)
+
+ g.apply_edges(exp('score'))
+
+ # Adaptive weighting with k_RW_eij
+ # Only applicable to full graph, For NOW
+ if self.adaptive_edge_PE and self.full_graph:
+ g.apply_edges(adaptive_edge_PE('score_soft', 'k_RW'))
+ del g.edata['tmp']
+
+ # Send weighted values to target nodes
+ eids = g.edges()
+ g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV'))
+ g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z'))
+
+
+ def forward(self, g, h, e):
+
+ Q_h = self.Q(h)
+ K_h = self.K(h)
+ E = self.E(e)
+
+ if self.full_graph:
+ Q_2h = self.Q_2(h)
+ K_2h = self.K_2(h)
+ E_2 = self.E_2(e)
+
+ V_h = self.V(h)
+
+
+ # Reshaping into [num_nodes, num_heads, feat_dim] to
+ # get projections for multi-head attention
+ g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E'] = E.view(-1, self.num_heads, self.out_dim)
+
+
+ if self.full_graph:
+ g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim)
+
+ g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim)
+
+ self.propagate_attention(g)
+
+ h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6))
+
+ del g.ndata['wV']
+ del g.ndata['z']
+ del g.ndata['Q_h']
+ del g.ndata['K_h']
+ del g.edata['E']
+
+ if self.full_graph:
+ del g.ndata['Q_2h']
+ del g.ndata['K_2h']
+ del g.edata['E_2']
+
+ return h_out
+
+
+class GraphiT_GT_Layer(nn.Module):
+ """
+ Param:
+ """
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0,
+ layer_norm=False, batch_norm=True, residual=True, adaptive_edge_PE=False, use_bias=False):
+ super().__init__()
+
+ self.in_channels = in_dim
+ self.out_channels = out_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.residual = residual
+ self.layer_norm = layer_norm
+ self.batch_norm = batch_norm
+
+ self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads,
+ full_graph, use_bias, adaptive_edge_PE, attention_for="h")
+
+ self.O_h = nn.Linear(out_dim, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm1_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm1_h = nn.BatchNorm1d(out_dim)
+
+ # FFN for h
+ self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2)
+ self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm2_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm2_h = nn.BatchNorm1d(out_dim)
+
+
+ def forward(self, g, h, p, e, snorm_n):
+ h_in1 = h # for first residual connection
+
+ # [START] For calculation of h -----------------------------------------------------------------
+
+ # multi-head attention out
+ h_attn_out = self.attention_h(g, h, e)
+
+ #Concat multi-head outputs
+ h = h_attn_out.view(-1, self.out_channels)
+
+ h = F.dropout(h, self.dropout, training=self.training)
+
+ h = self.O_h(h)
+
+ if self.residual:
+ h = h_in1 + h # residual connection
+
+ # # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm1_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm1_h(h)
+
+ h_in2 = h # for second residual connection
+
+ # FFN for h
+ h = self.FFN_h_layer1(h)
+ h = F.relu(h)
+ h = F.dropout(h, self.dropout, training=self.training)
+ h = self.FFN_h_layer2(h)
+
+ if self.residual:
+ h = h_in2 + h # residual connection
+
+ # # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm2_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm2_h(h)
+
+ # [END] For calculation of h -----------------------------------------------------------------
+
+
+ return h, None
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
+ self.in_channels,
+ self.out_channels, self.num_heads, self.residual)
\ No newline at end of file
diff --git a/layers/graphit_gt_lspe_layer.py b/layers/graphit_gt_lspe_layer.py
new file mode 100644
index 0000000..9c81a75
--- /dev/null
+++ b/layers/graphit_gt_lspe_layer.py
@@ -0,0 +1,325 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+import dgl.function as fn
+import numpy as np
+
+"""
+ GraphiT-GT-LSPE: GraphiT-GT with LSPE
+
+"""
+
+"""
+ Util functions
+"""
+def src_dot_dst(src_field, dst_field, out_field):
+ def func(edges):
+ return {out_field: (edges.src[src_field] * edges.dst[dst_field])}
+ return func
+
+
+def scaling(field, scale_constant):
+ def func(edges):
+ return {field: ((edges.data[field]) / scale_constant)}
+ return func
+
+# Improving implicit attention scores with explicit edge features, if available
+def imp_exp_attn(implicit_attn, explicit_edge):
+ """
+    implicit_attn: the implicit attention scores from the Q-K product
+ explicit_edge: the explicit edge features
+ """
+ def func(edges):
+ return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}
+ return func
+
+
+def exp(field):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
+ return func
+
+def adaptive_edge_PE(field, adaptive_weight):
+ def func(edges):
+ # initial shape was: adaptive_weight: [edges,1]; data: [edges, num_heads, 1]
+ # repeating adaptive_weight to have: [edges, num_heads, 1]
+ edges.data['tmp'] = edges.data[adaptive_weight].repeat(1, edges.data[field].shape[1]).unsqueeze(-1)
+ return {'score_soft': edges.data['tmp'] * edges.data[field]}
+ return func
+
+
+"""
+ Single Attention Head
+"""
+
+class MultiHeadAttentionLayer(nn.Module):
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, adaptive_edge_PE, attention_for):
+ super().__init__()
+
+
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.gamma = gamma
+ self.full_graph=full_graph
+ self.attention_for = attention_for
+ self.adaptive_edge_PE = adaptive_edge_PE
+
+ if self.attention_for == "h": # attention module for h has input h = [h,p], so 2*in_dim for Q,K,V
+ if use_bias:
+ self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+
+ else:
+ self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+
+ elif self.attention_for == "p": # attention module for p
+ if use_bias:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ else:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ def propagate_attention(self, g):
+
+
+ if self.full_graph:
+ real_ids = torch.nonzero(g.edata['real']).squeeze()
+ fake_ids = torch.nonzero(g.edata['real']==0).squeeze()
+
+ else:
+ real_ids = g.edges(form='eid')
+
+ g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids)
+
+
+ # scale scores by sqrt(d)
+ g.apply_edges(scaling('score', np.sqrt(self.out_dim)))
+
+ # Use available edge features to modify the scores for edges
+ g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids)
+
+ g.apply_edges(exp('score'))
+
+ # Adaptive weighting with k_RW_eij
+ # Only applicable to full graph, For NOW
+ if self.adaptive_edge_PE and self.full_graph:
+ g.apply_edges(adaptive_edge_PE('score_soft', 'k_RW'))
+ del g.edata['tmp']
+
+ # Send weighted values to target nodes
+ eids = g.edges()
+ g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV'))
+ g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z'))
+
+
+ def forward(self, g, h, p, e):
+ if self.attention_for == "h":
+ h = torch.cat((h, p), -1)
+ elif self.attention_for == "p":
+ h = p
+
+ Q_h = self.Q(h)
+ K_h = self.K(h)
+ E = self.E(e)
+
+ if self.full_graph:
+ Q_2h = self.Q_2(h)
+ K_2h = self.K_2(h)
+ E_2 = self.E_2(e)
+
+ V_h = self.V(h)
+
+
+ # Reshaping into [num_nodes, num_heads, feat_dim] to
+ # get projections for multi-head attention
+ g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E'] = E.view(-1, self.num_heads, self.out_dim)
+
+
+ if self.full_graph:
+ g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim)
+
+ g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim)
+
+ self.propagate_attention(g)
+
+ h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6))
+
+ del g.ndata['wV']
+ del g.ndata['z']
+ del g.ndata['Q_h']
+ del g.ndata['K_h']
+ del g.edata['E']
+
+ if self.full_graph:
+ del g.ndata['Q_2h']
+ del g.ndata['K_2h']
+ del g.edata['E_2']
+
+ return h_out
+
+
+class GraphiT_GT_LSPE_Layer(nn.Module):
+ """
+ Param:
+ """
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0,
+ layer_norm=False, batch_norm=True, residual=True, adaptive_edge_PE=False, use_bias=False):
+ super().__init__()
+
+ self.in_channels = in_dim
+ self.out_channels = out_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.residual = residual
+ self.layer_norm = layer_norm
+ self.batch_norm = batch_norm
+
+ self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads,
+ full_graph, use_bias, adaptive_edge_PE, attention_for="h")
+ self.attention_p = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads,
+ full_graph, use_bias, adaptive_edge_PE, attention_for="p")
+
+ self.O_h = nn.Linear(out_dim, out_dim)
+ self.O_p = nn.Linear(out_dim, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm1_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm1_h = nn.BatchNorm1d(out_dim)
+
+ # FFN for h
+ self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2)
+ self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm2_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm2_h = nn.BatchNorm1d(out_dim)
+
+
+ def forward(self, g, h, p, e, snorm_n):
+ h_in1 = h # for first residual connection
+ p_in1 = p # for first residual connection
+
+ # [START] For calculation of h -----------------------------------------------------------------
+
+ # multi-head attention out
+ h_attn_out = self.attention_h(g, h, p, e)
+
+ #Concat multi-head outputs
+ h = h_attn_out.view(-1, self.out_channels)
+
+ h = F.dropout(h, self.dropout, training=self.training)
+
+ h = self.O_h(h)
+
+ if self.residual:
+ h = h_in1 + h # residual connection
+
+ # # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm1_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm1_h(h)
+
+ h_in2 = h # for second residual connection
+
+ # FFN for h
+ h = self.FFN_h_layer1(h)
+ h = F.relu(h)
+ h = F.dropout(h, self.dropout, training=self.training)
+ h = self.FFN_h_layer2(h)
+
+ if self.residual:
+ h = h_in2 + h # residual connection
+
+ # # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm2_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm2_h(h)
+
+ # [END] For calculation of h -----------------------------------------------------------------
+
+
+ # [START] For calculation of p -----------------------------------------------------------------
+
+ # multi-head attention out
+ p_attn_out = self.attention_p(g, None, p, e)
+
+ #Concat multi-head outputs
+ p = p_attn_out.view(-1, self.out_channels)
+
+ p = F.dropout(p, self.dropout, training=self.training)
+
+ p = self.O_p(p)
+
+ p = torch.tanh(p)
+
+ if self.residual:
+ p = p_in1 + p # residual connection
+
+ # [END] For calculation of p -----------------------------------------------------------------
+
+ return h, p
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
+ self.in_channels,
+ self.out_channels, self.num_heads, self.residual)
\ No newline at end of file
diff --git a/layers/mlp_readout_layer.py b/layers/mlp_readout_layer.py
new file mode 100644
index 0000000..20a4463
--- /dev/null
+++ b/layers/mlp_readout_layer.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+"""
+ MLP Layer used after graph vector representation
+"""
+
+class MLPReadout(nn.Module):
+
+ def __init__(self, input_dim, output_dim, L=2): #L=nb_hidden_layers
+ super().__init__()
+ list_FC_layers = [ nn.Linear( input_dim//2**l , input_dim//2**(l+1) , bias=True ) for l in range(L) ]
+ list_FC_layers.append(nn.Linear( input_dim//2**L , output_dim , bias=True ))
+ self.FC_layers = nn.ModuleList(list_FC_layers)
+ self.L = L
+
+ def forward(self, x):
+ y = x
+ for l in range(self.L):
+ y = self.FC_layers[l](y)
+ y = F.relu(y)
+ y = self.FC_layers[self.L](y)
+ return y
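+
+    # Worked example (illustrative, not from the paper): input_dim=128, output_dim=1, L=2
+    # gives Linear layers 128->64, 64->32, 32->1, with ReLU after each hidden layer.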
+
+
+
+class MLPReadout2(nn.Module):
+
+ def __init__(self, input_dim, output_dim, dropout_2=0.0, L=2): # L=nb_hidden_layers
+ super().__init__()
+ list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)]
+ list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True))
+ self.FC_layers = nn.ModuleList(list_FC_layers)
+ self.L = L
+ self.dropout_2 = dropout_2
+
+ def forward(self, x):
+ y = x
+ for l in range(self.L):
+ y = F.dropout(y, self.dropout_2, training=self.training)
+ y = self.FC_layers[l](y)
+ y = F.relu(y)
+ y = self.FC_layers[self.L](y)
+ return y
\ No newline at end of file
diff --git a/layers/pna_layer.py b/layers/pna_layer.py
new file mode 100644
index 0000000..1140e39
--- /dev/null
+++ b/layers/pna_layer.py
@@ -0,0 +1,270 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import dgl.function as fn
+import dgl
+
+from .pna_utils import AGGREGATORS, SCALERS, MLP, FCLayer
+
+"""
+ PNA: Principal Neighbourhood Aggregation
+ Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic
+ https://arxiv.org/abs/2004.05718
+"""
+
+
+class PNATower(nn.Module):
+ def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d,
+ pretrans_layers, posttrans_layers, edge_features, edge_dim):
+ super().__init__()
+ self.dropout = dropout
+ self.graph_norm = graph_norm
+ self.batch_norm = batch_norm
+ self.edge_features = edge_features
+
+ self.batchnorm_h = nn.BatchNorm1d(out_dim)
+ self.aggregators = aggregators
+ self.scalers = scalers
+ self.pretrans_h = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim,
+ out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none')
+
+ self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+
+ self.avg_d = avg_d
+
+ def pretrans_edges(self, edges):
+ if self.edge_features:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1)
+ else:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
+
+ return {'e_for_h': self.pretrans_h(z2_for_h)}
+
+ # Message func for h
+ def message_func_for_h(self, edges):
+ return {'e_for_h': edges.data['e_for_h']}
+
+ # Reduce func for h
+ def reduce_func_for_h(self, nodes):
+ h = nodes.mailbox['e_for_h']
+ D = h.shape[-2]
+ h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
+ h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+ return {'h': h}
+
+ def forward(self, g, h, e, snorm_n):
+ g.ndata['h'] = h
+
+ if self.edge_features: # add the edges information only if edge_features = True
+ g.edata['ef'] = e
+
+ # pretransformation
+ g.apply_edges(self.pretrans_edges)
+
+ # aggregation for h
+ g.update_all(self.message_func_for_h, self.reduce_func_for_h)
+ h = torch.cat([h, g.ndata['h']], dim=1)
+
+ # posttransformation
+ h = self.posttrans_h(h)
+
+ # graph and batch normalization
+ if self.graph_norm:
+ h = h * snorm_n
+
+ if self.batch_norm:
+ h = self.batchnorm_h(h)
+
+ h = F.dropout(h, self.dropout, training=self.training)
+
+ return h
+
+
+class PNALayer(nn.Module):
+
+ def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, graph_norm, batch_norm, towers=1,
+ pretrans_layers=1, posttrans_layers=1, divide_input=True, residual=False, edge_features=False,
+ edge_dim=0):
+ """
+ :param in_dim: size of the input per node
+ :param out_dim: size of the output per node
+ :param aggregators: set of aggregation function identifiers
+ :param scalers: set of scaling functions identifiers
+ :param avg_d: average degree of nodes in the training set, used by scalers to normalize
+ :param dropout: dropout used
+ :param graph_norm: whether to use graph normalisation
+ :param batch_norm: whether to use batch normalisation
+ :param towers: number of towers to use
+ :param pretrans_layers: number of layers in the transformation before the aggregation
+ :param posttrans_layers: number of layers in the transformation after the aggregation
+ :param divide_input: whether the input features should be split between towers or not
+ :param residual: whether to add a residual connection
+ :param edge_features: whether to use the edge features
+ :param edge_dim: size of the edge features
+ """
+ super().__init__()
+ assert ((not divide_input) or in_dim % towers == 0), "if divide_input is set the number of towers has to divide in_dim"
+ assert (out_dim % towers == 0), "the number of towers has to divide the out_dim"
+ assert avg_d is not None
+
+ # retrieve the aggregators and scalers functions
+ aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()]
+ scalers = [SCALERS[scale] for scale in scalers.split()]
+
+ self.divide_input = divide_input
+ self.input_tower = in_dim // towers if divide_input else in_dim
+ self.output_tower = out_dim // towers
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.edge_features = edge_features
+ self.residual = residual
+ if in_dim != out_dim:
+ self.residual = False
+
+ # convolution
+ self.towers = nn.ModuleList()
+ for _ in range(towers):
+ self.towers.append(PNATower(in_dim=self.input_tower, out_dim=self.output_tower, aggregators=aggregators,
+ scalers=scalers, avg_d=avg_d, pretrans_layers=pretrans_layers,
+ posttrans_layers=posttrans_layers, batch_norm=batch_norm, dropout=dropout,
+ graph_norm=graph_norm, edge_features=edge_features, edge_dim=edge_dim))
+ # mixing network
+ self.mixing_network_h = FCLayer(out_dim, out_dim, activation='LeakyReLU')
+
+ def forward(self, g, h, p, e, snorm_n):
+ h_in = h # for residual connection
+
+ if self.divide_input:
+ tower_outs = [tower(g,
+ h[:, n_tower * self.input_tower: (n_tower + 1) * self.input_tower],
+ e,
+ snorm_n) for n_tower, tower in enumerate(self.towers)]
+ h_tower_outs = tower_outs
+ h_cat = torch.cat(h_tower_outs, dim=1)
+        else:
+            # PNATower.forward takes (g, h, e, snorm_n); p is not used in the non-LSPE layer
+            tower_outs = [tower(g, h, e, snorm_n) for tower in self.towers]
+            h_tower_outs = tower_outs
+            h_cat = torch.cat(h_tower_outs, dim=1)
+
+ h_out = self.mixing_network_h(h_cat)
+
+ if self.residual:
+ h_out = h_in + h_out # residual connection
+
+ return h_out, None
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim)
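+
+
+# Illustrative instantiation sketch (hyperparameter values below are assumptions, not the
+# paper's settings; aggregator/scaler key names are assumed from the reference PNA code):
+#
+#   layer = PNALayer(in_dim=64, out_dim=64, aggregators="mean max min std",
+#                    scalers="identity amplification attenuation", avg_d={'log': 1.0},
+#                    dropout=0.0, graph_norm=True, batch_norm=True, towers=4,
+#                    edge_features=True, edge_dim=64)
+#   h_out, _ = layer(g, h, None, e, snorm_n)   # p is unused in the non-LSPE layer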
+
+
+
+
+
+# The layer class below has no towers
+# and is similar to the DGNLayerComplex used for the best PNA score on MOLPCBA,
+# implemented here: https://github.com/Saro00/DGN/blob/master/models/dgl/dgn_layer.py
+
+class PNANoTowersLayer(nn.Module):
+ def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d,
+ pretrans_layers, posttrans_layers, residual, edge_features, edge_dim=0, use_lapeig_loss=False):
+ super().__init__()
+ self.dropout = dropout
+ self.graph_norm = graph_norm
+ self.batch_norm = batch_norm
+ self.edge_features = edge_features
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.residual = residual
+ if in_dim != out_dim:
+ self.residual = False
+
+ self.batchnorm_h = nn.BatchNorm1d(out_dim)
+
+ # retrieve the aggregators and scalers functions
+ aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()]
+ scalers = [SCALERS[scale] for scale in scalers.split()]
+
+ self.aggregators = aggregators
+ self.scalers = scalers
+
+ if self.edge_features:
+ self.pretrans_h = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim,
+ out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none')
+
+ self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+ else:
+ self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers)) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+
+ self.avg_d = avg_d
+
+ def pretrans_edges(self, edges):
+ if self.edge_features:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1)
+ else:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
+
+ return {'e_for_h': self.pretrans_h(z2_for_h)}
+
+ # Message func for h
+ def message_func_for_h(self, edges):
+ return {'e_for_h': edges.data['e_for_h']}
+
+ # Reduce func for h
+ def reduce_func_for_h(self, nodes):
+ if self.edge_features:
+ h = nodes.mailbox['e_for_h']
+ else:
+ h = nodes.mailbox['m_h']
+ D = h.shape[-2]
+ h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
+ if len(self.scalers) > 1:
+ h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+ return {'h': h}
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ h = F.dropout(h, self.dropout, training=self.training)
+
+ h_in = h # for residual connection
+
+ g.ndata['h'] = h
+
+ if self.edge_features: # add the edges information only if edge_features = True
+ g.edata['ef'] = e
+
+ if self.edge_features:
+ # pretransformation
+ g.apply_edges(self.pretrans_edges)
+
+ if self.edge_features:
+ # aggregation for h
+ g.update_all(self.message_func_for_h, self.reduce_func_for_h)
+ h = torch.cat([h, g.ndata['h']], dim=1)
+ else:
+ # aggregation for h
+ g.update_all(fn.copy_u('h', 'm_h'), self.reduce_func_for_h)
+ h = g.ndata['h']
+
+ # posttransformation
+ h = self.posttrans_h(h)
+
+ # graph and batch normalization
+ if self.graph_norm and self.edge_features:
+ h = h * snorm_n
+
+ if self.batch_norm:
+ h = self.batchnorm_h(h)
+
+ h = F.relu(h)
+
+ if self.residual:
+ h = h_in + h # residual connection
+
+ return h, None
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim)
+
diff --git a/layers/pna_lspe_layer.py b/layers/pna_lspe_layer.py
new file mode 100644
index 0000000..4fd35ba
--- /dev/null
+++ b/layers/pna_lspe_layer.py
@@ -0,0 +1,350 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import dgl.function as fn
+import dgl
+
+from .pna_utils import AGGREGATORS, SCALERS, MLP, FCLayer
+
+"""
+ PNALSPE: PNA with LSPE
+
+ PNA: Principal Neighbourhood Aggregation
+ Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic
+ https://arxiv.org/abs/2004.05718
+"""
+
+
+class PNATower(nn.Module):
+ def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d,
+ pretrans_layers, posttrans_layers, edge_features, edge_dim):
+ super().__init__()
+ self.dropout = dropout
+ self.graph_norm = graph_norm
+ self.batch_norm = batch_norm
+ self.edge_features = edge_features
+
+ self.batchnorm_h = nn.BatchNorm1d(out_dim)
+ self.aggregators = aggregators
+ self.scalers = scalers
+ self.pretrans_h = MLP(in_size=2 * 2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim,
+ out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none')
+ self.pretrans_p = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim,
+ out_size=in_dim, layers=pretrans_layers, mid_activation='tanh', last_activation='none')
+
+ self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 2) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+ self.posttrans_p = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='tanh', last_activation='none')
+
+ self.avg_d = avg_d
+
+ def pretrans_edges(self, edges):
+ if self.edge_features:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1)
+ z2_for_p = torch.cat([edges.src['p'], edges.dst['p'], edges.data['ef']], dim=1)
+ else:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
+ z2_for_p = torch.cat([edges.src['p'], edges.dst['p']], dim=1)
+
+ return {'e_for_h': self.pretrans_h(z2_for_h), 'e_for_p': self.pretrans_p(z2_for_p)}
+
+ # Message func for h
+ def message_func_for_h(self, edges):
+ return {'e_for_h': edges.data['e_for_h']}
+
+ # Reduce func for h
+ def reduce_func_for_h(self, nodes):
+ h = nodes.mailbox['e_for_h']
+ D = h.shape[-2]
+ h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
+ h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+ return {'h': h}
+
+ # Message func for p
+ def message_func_for_p(self, edges):
+ return {'e_for_p': edges.data['e_for_p']}
+
+ # Reduce func for p
+ def reduce_func_for_p(self, nodes):
+ p = nodes.mailbox['e_for_p']
+ D = p.shape[-2]
+ p = torch.cat([aggregate(p) for aggregate in self.aggregators], dim=1)
+ p = torch.cat([scale(p, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+ return {'p': p}
+
+ def forward(self, g, h, p, e, snorm_n):
+ g.ndata['h'] = h
+ g.ndata['p'] = p
+
+ if self.edge_features: # add the edges information only if edge_features = True
+ g.edata['ef'] = e
+
+ # pretransformation
+ g.apply_edges(self.pretrans_edges)
+
+ # aggregation for h
+ g.update_all(self.message_func_for_h, self.reduce_func_for_h)
+ h = torch.cat([h, g.ndata['h']], dim=1)
+
+ # aggregation for p
+ g.update_all(self.message_func_for_p, self.reduce_func_for_p)
+ p = torch.cat([p, g.ndata['p']], dim=1)
+
+ # posttransformation
+ h = self.posttrans_h(h)
+ p = self.posttrans_p(p)
+
+ # graph and batch normalization
+ if self.graph_norm:
+ h = h * snorm_n
+
+ if self.batch_norm:
+ h = self.batchnorm_h(h)
+
+ h = F.dropout(h, self.dropout, training=self.training)
+ p = F.dropout(p, self.dropout, training=self.training)
+
+ return h, p
+
+
+class PNALSPELayer(nn.Module):
+
+ def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, graph_norm, batch_norm, towers=1,
+ pretrans_layers=1, posttrans_layers=1, divide_input=True, residual=False, edge_features=False,
+ edge_dim=0):
+ """
+ :param in_dim: size of the input per node
+ :param out_dim: size of the output per node
+ :param aggregators: set of aggregation function identifiers
+ :param scalers: set of scaling functions identifiers
+ :param avg_d: average degree of nodes in the training set, used by scalers to normalize
+ :param dropout: dropout used
+ :param graph_norm: whether to use graph normalisation
+ :param batch_norm: whether to use batch normalisation
+ :param towers: number of towers to use
+ :param pretrans_layers: number of layers in the transformation before the aggregation
+ :param posttrans_layers: number of layers in the transformation after the aggregation
+ :param divide_input: whether the input features should be split between towers or not
+ :param residual: whether to add a residual connection
+ :param edge_features: whether to use the edge features
+ :param edge_dim: size of the edge features
+ """
+ super().__init__()
+ assert ((not divide_input) or in_dim % towers == 0), "if divide_input is set the number of towers has to divide in_dim"
+ assert (out_dim % towers == 0), "the number of towers has to divide the out_dim"
+ assert avg_d is not None
+
+ # retrieve the aggregators and scalers functions
+ aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()]
+ scalers = [SCALERS[scale] for scale in scalers.split()]
+
+ self.divide_input = divide_input
+ self.input_tower = in_dim // towers if divide_input else in_dim
+ self.output_tower = out_dim // towers
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.edge_features = edge_features
+ self.residual = residual
+ if in_dim != out_dim:
+ self.residual = False
+
+ # convolution
+ self.towers = nn.ModuleList()
+ for _ in range(towers):
+ self.towers.append(PNATower(in_dim=self.input_tower, out_dim=self.output_tower, aggregators=aggregators,
+ scalers=scalers, avg_d=avg_d, pretrans_layers=pretrans_layers,
+ posttrans_layers=posttrans_layers, batch_norm=batch_norm, dropout=dropout,
+ graph_norm=graph_norm, edge_features=edge_features, edge_dim=edge_dim))
+ # mixing network
+ self.mixing_network_h = FCLayer(out_dim, out_dim, activation='LeakyReLU')
+ self.mixing_network_p = FCLayer(out_dim, out_dim, activation='tanh')
+
+ def forward(self, g, h, p, e, snorm_n):
+ h_in = h # for residual connection
+ p_in = p # for residual connection
+
+ # Concatenating p to h, as in PEGNN
+ h = torch.cat((h, p), -1)
+
+ if self.divide_input:
+ tower_outs = [tower(g,
+ h[:, n_tower * 2 * self.input_tower: (n_tower + 1) * 2 * self.input_tower],
+ p[:, n_tower * self.input_tower: (n_tower + 1) * self.input_tower],
+ e,
+ snorm_n) for n_tower, tower in enumerate(self.towers)]
+ h_tower_outs, p_tower_outs = map(list,zip(*tower_outs))
+ h_cat = torch.cat(h_tower_outs, dim=1)
+ p_cat = torch.cat(p_tower_outs, dim=1)
+ else:
+ tower_outs = [tower(g, h, p, e, snorm_n) for tower in self.towers]
+ h_tower_outs, p_tower_outs = map(list,zip(*tower_outs))
+ h_cat = torch.cat(h_tower_outs, dim=1)
+ p_cat = torch.cat(p_tower_outs, dim=1)
+
+ h_out = self.mixing_network_h(h_cat)
+ p_out = self.mixing_network_p(p_cat)
+
+
+ if self.residual:
+ h_out = h_in + h_out # residual connection
+ p_out = p_in + p_out # residual connection
+
+ return h_out, p_out
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim)
+
+
+
+
+
+# The layer below has no towers and is similar to DGNLayerComplex, which was used for the best PNA score
+# on MOLPCBA; implementation: https://github.com/Saro00/DGN/blob/master/models/dgl/dgn_layer.py
+
+class PNANoTowersLSPELayer(nn.Module):
+ def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d,
+ pretrans_layers, posttrans_layers, residual, edge_features, edge_dim=0, use_lapeig_loss=False):
+ super().__init__()
+ self.dropout = dropout
+ self.graph_norm = graph_norm
+ self.batch_norm = batch_norm
+ self.edge_features = edge_features
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.residual = residual
+ if in_dim != out_dim:
+ self.residual = False
+
+ self.use_lapeig_loss = use_lapeig_loss
+
+ self.batchnorm_h = nn.BatchNorm1d(out_dim)
+
+ # retrieve the aggregators and scalers functions
+ aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()]
+ scalers = [SCALERS[scale] for scale in scalers.split()]
+
+ self.aggregators = aggregators
+ self.scalers = scalers
+
+ if self.edge_features:
+ self.pretrans_h = MLP(in_size=2 * 2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim,
+ out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none')
+ self.pretrans_p = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim,
+ out_size=in_dim, layers=pretrans_layers, mid_activation='tanh', last_activation='none')
+
+ self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 2) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+ self.posttrans_p = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='tanh', last_activation='none')
+ else:
+ self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers)) * 2 * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+ self.posttrans_p = MLP(in_size=(len(aggregators) * len(scalers)) * in_dim, hidden_size=out_dim,
+ out_size=out_dim, layers=posttrans_layers, mid_activation='tanh', last_activation='none')
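+ # Note on the post-transformation input sizes above (comment added for clarity): with edge features, the
+ # h-update concatenates the [h, p] input (2 * in_dim) with the aggregated, scaled messages
+ # (len(aggregators) * len(scalers) * in_dim), giving the (len(aggregators) * len(scalers) + 2) * in_dim input;
+ # the p-update keeps p (in_dim) alongside its messages, hence the + 1. Without edge features, the raw [h, p]
+ # messages (2 * in_dim wide) replace h after aggregation, and p's messages are in_dim wide.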
+
+ self.avg_d = avg_d
+
+ def pretrans_edges(self, edges):
+ if self.edge_features:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1)
+ z2_for_p = torch.cat([edges.src['p'], edges.dst['p'], edges.data['ef']], dim=1)
+ else:
+ z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
+ z2_for_p = torch.cat([edges.src['p'], edges.dst['p']], dim=1)
+
+ return {'e_for_h': self.pretrans_h(z2_for_h), 'e_for_p': self.pretrans_p(z2_for_p)}
+
+ # Message func for h
+ def message_func_for_h(self, edges):
+ return {'e_for_h': edges.data['e_for_h']}
+
+ # Reduce func for h
+ def reduce_func_for_h(self, nodes):
+ if self.edge_features:
+ h = nodes.mailbox['e_for_h']
+ else:
+ h = nodes.mailbox['m_h']
+ D = h.shape[-2]
+ h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
+ if len(self.scalers) > 1:
+ h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+ return {'h': h}
+
+ # Message func for p
+ def message_func_for_p(self, edges):
+ return {'e_for_p': edges.data['e_for_p']}
+
+ # Reduce func for p
+ def reduce_func_for_p(self, nodes):
+ if self.edge_features:
+ p = nodes.mailbox['e_for_p']
+ else:
+ p = nodes.mailbox['m_p']
+ D = p.shape[-2]
+ p = torch.cat([aggregate(p) for aggregate in self.aggregators], dim=1)
+ if len(self.scalers) > 1:
+ p = torch.cat([scale(p, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+ return {'p': p}
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ h = F.dropout(h, self.dropout, training=self.training)
+ p = F.dropout(p, self.dropout, training=self.training)
+
+ h_in = h # for residual connection
+ p_in = p # for residual connection
+
+ # Concatenating p to h, as in PEGNN
+ h = torch.cat((h, p), -1)
+
+ g.ndata['h'] = h
+ g.ndata['p'] = p
+
+ if self.edge_features: # add the edges information only if edge_features = True
+ g.edata['ef'] = e
+
+ if self.edge_features:
+ # pretransformation
+ g.apply_edges(self.pretrans_edges)
+
+ if self.edge_features:
+ # aggregation for h
+ g.update_all(self.message_func_for_h, self.reduce_func_for_h)
+ h = torch.cat([h, g.ndata['h']], dim=1)
+
+ # aggregation for p
+ g.update_all(self.message_func_for_p, self.reduce_func_for_p)
+ p = torch.cat([p, g.ndata['p']], dim=1)
+ else:
+ # aggregation for h
+ g.update_all(fn.copy_u('h', 'm_h'), self.reduce_func_for_h)
+ h = g.ndata['h']
+
+ # aggregation for p
+ g.update_all(fn.copy_u('p', 'm_p'), self.reduce_func_for_p)
+ p = g.ndata['p']
+
+ # posttransformation
+ h = self.posttrans_h(h)
+ p = self.posttrans_p(p)
+
+ # graph and batch normalization
+ if self.graph_norm and self.edge_features:
+ h = h * snorm_n
+
+ if self.batch_norm:
+ h = self.batchnorm_h(h)
+
+ h = F.relu(h)
+ p = torch.tanh(p)
+
+ if self.residual:
+ h = h_in + h # residual connection
+ p = p_in + p # residual connection
+
+ return h, p
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim)
\ No newline at end of file
diff --git a/layers/pna_utils.py b/layers/pna_utils.py
new file mode 100644
index 0000000..61c9fd5
--- /dev/null
+++ b/layers/pna_utils.py
@@ -0,0 +1,407 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+# PNA Aggregators ------------------------------------------------------------------------------
+
+EPS = 1e-5
+
+
+def aggregate_mean(h):
+ return torch.mean(h, dim=1)
+
+
+def aggregate_max(h):
+ return torch.max(h, dim=1)[0]
+
+
+def aggregate_min(h):
+ return torch.min(h, dim=1)[0]
+
+
+def aggregate_std(h):
+ return torch.sqrt(aggregate_var(h) + EPS)
+
+
+def aggregate_var(h):
+ h_mean_squares = torch.mean(h * h, dim=-2)
+ h_mean = torch.mean(h, dim=-2)
+ var = torch.relu(h_mean_squares - h_mean * h_mean)
+ return var
+
+
+def aggregate_moment(h, n=3):
+ # for each node (E[(X-E[X])^n])^{1/n}
+ # EPS is added to the absolute value of expectation before taking the nth root for stability
+ h_mean = torch.mean(h, dim=1, keepdim=True)
+ h_n = torch.mean(torch.pow(h - h_mean, n), dim=1) # reduce over the neighbour dimension (per node)
+ rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + EPS, 1. / n)
+ return rooted_h_n
+
+
+def aggregate_moment_3(h):
+ return aggregate_moment(h, n=3)
+
+
+def aggregate_moment_4(h):
+ return aggregate_moment(h, n=4)
+
+
+def aggregate_moment_5(h):
+ return aggregate_moment(h, n=5)
+
+
+def aggregate_sum(h):
+ return torch.sum(h, dim=1)
+
+
+AGGREGATORS = {'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min,
+ 'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3, 'moment4': aggregate_moment_4,
+ 'moment5': aggregate_moment_5}
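+
+# Illustrative usage (not part of the original code): each aggregator reduces over the neighbour
+# dimension of the gathered messages, e.g. with hypothetical shapes
+#   msgs = torch.randn(10, 5, 16)              # (num_nodes, num_neighbours, feat_dim)
+#   out = AGGREGATORS['mean'](msgs)            # shape (10, 16)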
+
+
+
+
+# PNA Scalers ---------------------------------------------------------------------------------
+
+
+# each scaler is a function that takes as input h (the aggregated features), D (the number of messages /
+# in-degree per node) and avg_d (dictionary containing degree averages over the training set),
+# and returns the scaled features as output
+
+def scale_identity(h, D=None, avg_d=None):
+ return h
+
+
+def scale_amplification(h, D, avg_d):
+ # log(D + 1) / d * h where d is the average of the ``log(D + 1)`` in the training set
+ return h * (np.log(D + 1) / avg_d["log"])
+
+
+def scale_attenuation(h, D, avg_d):
+ # d / log(D + 1) * h where d is the average of the ``log(D + 1)`` in the training set
+ return h * (avg_d["log"] / np.log(D + 1))
+
+
+SCALERS = {'identity': scale_identity, 'amplification': scale_amplification, 'attenuation': scale_attenuation}
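+
+# Illustrative example (not part of the original code): with avg_d = {'log': 1.0} and degree D = 3,
+# scale_amplification multiplies the features by log(4) ≈ 1.39 and scale_attenuation by 1 / log(4) ≈ 0.72,
+# boosting or damping high-degree nodes relative to the training-set average.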
+
+
+
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+SUPPORTED_ACTIVATION_MAP = {'ReLU', 'Sigmoid', 'Tanh', 'ELU', 'SELU', 'GLU', 'LeakyReLU', 'Softplus', 'None'}
+
+
+def get_activation(activation):
+ """ returns the activation function represented by the input string """
+ if activation and callable(activation):
+ # activation is already a function
+ return activation
+ # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation
+ activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()]
+ assert len(activation) == 1 and isinstance(activation[0], str), 'Unhandled activation function'
+ activation = activation[0]
+ if activation.lower() == 'none':
+ return None
+ return vars(torch.nn.modules.activation)[activation]()
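+
+# Usage sketch (illustrative): get_activation('leakyrelu') returns nn.LeakyReLU(), get_activation('none')
+# returns None, and passing an already-callable activation returns it unchanged.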
+
+
+class Set2Set(torch.nn.Module):
+ r"""
+ Set2Set global pooling operator from the "Order Matters: Sequence to sequence for sets" paper.
+ This pooling layer performs the following operation
+
+ .. math::
+ \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})
+
+ \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)
+
+ \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i
+
+ \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,
+
+ where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice
+ the dimensionality as the input.
+
+ Arguments
+ ---------
+ input_dim: int
+ Size of each input sample.
+ hidden_dim: int, optional
+ the dim of set representation which corresponds to the input dim of the LSTM in Set2Set.
+ This is typically the sum of the input dim and the lstm output dim. If not provided, it will be set to :obj:`input_dim*2`
+ steps: int, optional
+ Number of iterations :math:`T`. If not provided, the number of nodes will be used.
+ num_layers : int, optional
+ Number of recurrent layers (e.g., :obj:`num_layers=2` would mean stacking two LSTMs together)
+ (Default, value = 1)
+ """
+
+ def __init__(self, nin, nhid=None, steps=None, num_layers=1, activation=None, device='cpu'):
+ super(Set2Set, self).__init__()
+ self.steps = steps
+ self.nin = nin
+ self.nhid = nin * 2 if nhid is None else nhid
+ if self.nhid <= self.nin:
+ raise ValueError('Set2Set hidden_dim should be larger than input_dim')
+ # the hidden is a concatenation of weighted sum of embedding and LSTM output
+ self.lstm_output_dim = self.nhid - self.nin
+ self.num_layers = num_layers
+ self.lstm = nn.LSTM(self.nhid, self.nin, num_layers=num_layers, batch_first=True).to(device)
+ self.softmax = nn.Softmax(dim=1)
+
+ def forward(self, x):
+ r"""
+ Applies the pooling on input tensor x
+
+ Arguments
+ ----------
+ x: torch.FloatTensor
+ Input tensor of size (B, N, D)
+
+ Returns
+ -------
+ x: `torch.FloatTensor`
+ Tensor resulting from the set2set pooling operation.
+ """
+
+ batch_size = x.shape[0]
+ n = self.steps or x.shape[1]
+
+ h = (x.new_zeros((self.num_layers, batch_size, self.nin)),
+ x.new_zeros((self.num_layers, batch_size, self.nin)))
+
+ q_star = x.new_zeros(batch_size, 1, self.nhid)
+
+ for i in range(n):
+ # q: batch_size x 1 x input_dim
+ q, h = self.lstm(q_star, h)
+ # e: batch_size x n x 1
+ e = torch.matmul(x, torch.transpose(q, 1, 2))
+ a = self.softmax(e)
+ r = torch.sum(a * x, dim=1, keepdim=True)
+ q_star = torch.cat([q, r], dim=-1)
+
+ return torch.squeeze(q_star, dim=1)
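+
+ # Usage sketch (illustrative): with the default nhid = 2 * nin, pooling node features x of shape
+ # (B, N, nin) yields a graph-level representation of shape (B, 2 * nin).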
+
+
+class FCLayer(nn.Module):
+ r"""
+ A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module.
+ The order in which transformations are applied is:
+
+ #. Dense Layer
+ #. Activation
+ #. Dropout (if applicable)
+ #. Batch Normalization (if applicable)
+
+ Arguments
+ ----------
+ in_size: int
+ Input dimension of the layer (the torch.nn.Linear)
+ out_size: int
+ Output dimension of the layer.
+ dropout: float, optional
+ The ratio of units to dropout. No dropout by default.
+ (Default value = 0.)
+ activation: str or callable, optional
+ Activation function to use.
+ (Default value = relu)
+ b_norm: bool, optional
+ Whether to use batch normalization
+ (Default value = False)
+ bias: bool, optional
+ Whether to enable bias in for the linear layer.
+ (Default value = True)
+ init_fn: callable, optional
+ Initialization function to use for the weight of the layer. Default is
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with :math:`k=\frac{1}{ \text{in_size}}`
+ (Default value = None)
+
+ Attributes
+ ----------
+ dropout: int
+ The ratio of units to dropout.
+ b_norm: int
+ Whether to use batch normalization
+ linear: torch.nn.Linear
+ The linear layer
+ activation: the torch.nn.Module
+ The activation layer
+ init_fn: function
+ Initialization function used for the weight of the layer
+ in_size: int
+ Input dimension of the linear layer
+ out_size: int
+ Output dimension of the linear layer
+ """
+
+ def __init__(self, in_size, out_size, activation='relu', dropout=0., b_norm=False, bias=True, init_fn=None,
+ device='cpu'):
+ super(FCLayer, self).__init__()
+
+ self.__params = locals()
+ del self.__params['__class__']
+ del self.__params['self']
+ self.in_size = in_size
+ self.out_size = out_size
+ self.bias = bias
+ self.linear = nn.Linear(in_size, out_size, bias=bias).to(device)
+ self.dropout = None
+ self.b_norm = None
+ if dropout:
+ self.dropout = nn.Dropout(p=dropout)
+ if b_norm:
+ self.b_norm = nn.BatchNorm1d(out_size).to(device)
+ self.activation = get_activation(activation)
+ self.init_fn = nn.init.xavier_uniform_
+
+ self.reset_parameters()
+
+ def reset_parameters(self, init_fn=None):
+ init_fn = init_fn or self.init_fn
+ if init_fn is not None:
+ init_fn(self.linear.weight, 1 / self.in_size)
+ if self.bias:
+ self.linear.bias.data.zero_()
+
+ def forward(self, x):
+ h = self.linear(x)
+ if self.activation is not None:
+ h = self.activation(h)
+ if self.dropout is not None:
+ h = self.dropout(h)
+ if self.b_norm is not None:
+ if h.shape[1] != self.out_size:
+ h = self.b_norm(h.transpose(1, 2)).transpose(1, 2)
+ else:
+ h = self.b_norm(h)
+ return h
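+
+ # Usage sketch (illustrative): FCLayer(16, 32, activation='relu') maps a (batch, 16) tensor to (batch, 32),
+ # applying the linear layer followed by ReLU (dropout and batch norm are disabled by default).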
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + str(self.in_size) + ' -> ' \
+ + str(self.out_size) + ')'
+
+
+class MLP(nn.Module):
+ """
+ Simple multi-layer perceptron, built of a series of FCLayers
+ """
+
+ def __init__(self, in_size, hidden_size, out_size, layers, mid_activation='relu', last_activation='none',
+ dropout=0., mid_b_norm=False, last_b_norm=False, device='cpu'):
+ super(MLP, self).__init__()
+
+ self.in_size = in_size
+ self.hidden_size = hidden_size
+ self.out_size = out_size
+
+ self.fully_connected = nn.ModuleList()
+ if layers <= 1:
+ self.fully_connected.append(FCLayer(in_size, out_size, activation=last_activation, b_norm=last_b_norm,
+ device=device, dropout=dropout))
+ else:
+ self.fully_connected.append(FCLayer(in_size, hidden_size, activation=mid_activation, b_norm=mid_b_norm,
+ device=device, dropout=dropout))
+ for _ in range(layers - 2):
+ self.fully_connected.append(FCLayer(hidden_size, hidden_size, activation=mid_activation,
+ b_norm=mid_b_norm, device=device, dropout=dropout))
+ self.fully_connected.append(FCLayer(hidden_size, out_size, activation=last_activation, b_norm=last_b_norm,
+ device=device, dropout=dropout))
+
+ def forward(self, x):
+ for fc in self.fully_connected:
+ x = fc(x)
+ return x
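+
+ # Usage sketch (illustrative): MLP(in_size=16, hidden_size=32, out_size=8, layers=3) stacks
+ # FCLayer(16 -> 32, relu), FCLayer(32 -> 32, relu) and FCLayer(32 -> 8, no activation).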
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + str(self.in_size) + ' -> ' \
+ + str(self.out_size) + ')'
+
+
+class GRU(nn.Module):
+ """
+ Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself
+ """
+
+ def __init__(self, input_size, hidden_size, device):
+ super(GRU, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size).to(device)
+
+ def forward(self, x, y):
+ """
+ :param x: shape: (B, N, Din) where Din <= input_size (difference is padded)
+ :param y: shape: (B, N, Dh) where Dh <= hidden_size (difference is padded)
+ :return: shape: (B, N, Dh)
+ """
+ assert (x.shape[-1] <= self.input_size and y.shape[-1] <= self.hidden_size)
+
+ (B, N, _) = x.shape
+ x = x.reshape(1, B * N, -1).contiguous()
+ y = y.reshape(1, B * N, -1).contiguous()
+
+ # padding if necessary
+ if x.shape[-1] < self.input_size:
+ x = F.pad(input=x, pad=[0, self.input_size - x.shape[-1]], mode='constant', value=0)
+ if y.shape[-1] < self.hidden_size:
+ y = F.pad(input=y, pad=[0, self.hidden_size - y.shape[-1]], mode='constant', value=0)
+
+ x = self.gru(x, y)[1]
+ x = x.reshape(B, N, -1)
+ return x
+
+
+class S2SReadout(nn.Module):
+ """
+ Performs a Set2Set aggregation of all the graph nodes' features followed by a series of fully connected layers
+ """
+
+ def __init__(self, in_size, hidden_size, out_size, fc_layers=3, device='cpu', final_activation='relu'):
+ super(S2SReadout, self).__init__()
+
+ # set2set aggregation
+ self.set2set = Set2Set(in_size, device=device)
+
+ # fully connected layers
+ self.mlp = MLP(in_size=2 * in_size, hidden_size=hidden_size, out_size=out_size, layers=fc_layers,
+ mid_activation="relu", last_activation=final_activation, mid_b_norm=True, last_b_norm=False,
+ device=device)
+
+ def forward(self, x):
+ x = self.set2set(x)
+ return self.mlp(x)
+
+
+
+class GRU(nn.Module):
+ """
+ Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself
+ """
+
+ def __init__(self, input_size, hidden_size, device):
+ super(GRU, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size).to(device)
+
+ def forward(self, x, y):
+ """
+ :param x: shape: (B, N, Din) where Din <= input_size (difference is padded)
+ :param y: shape: (B, N, Dh) where Dh <= hidden_size (difference is padded)
+ :return: shape: (B, N, Dh)
+ """
+ assert (x.shape[-1] <= self.input_size and y.shape[-1] <= self.hidden_size)
+ x = x.unsqueeze(0)
+ y = y.unsqueeze(0)
+ x = self.gru(x, y)[1]
+ x = x.squeeze()
+ return x
diff --git a/layers/san_gt_layer.py b/layers/san_gt_layer.py
new file mode 100644
index 0000000..e71f1b5
--- /dev/null
+++ b/layers/san_gt_layer.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+import dgl.function as fn
+import numpy as np
+
+"""
+ SAN-GT
+
+"""
+
+"""
+ Util functions
+"""
+def src_dot_dst(src_field, dst_field, out_field):
+ def func(edges):
+ return {out_field: (edges.src[src_field] * edges.dst[dst_field])}
+ return func
+
+
+def scaling(field, scale_constant):
+ def func(edges):
+ return {field: ((edges.data[field]) / scale_constant)}
+ return func
+
+# Improving implicit attention scores with explicit edge features, if available
+def imp_exp_attn(implicit_attn, explicit_edge):
+ """
+ implicit_attn: the output of the Q·K computation (implicit attention scores)
+ explicit_edge: the explicit edge features
+ """
+ def func(edges):
+ return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}
+ return func
+
+def exp_real(field, L):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)}
+ return func
+
+
+def exp_fake(field, L):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': L*torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)}
+ return func
+
+def exp(field):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
+ return func
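+
+# Note (added for clarity): gamma is clamped to [0, 1] in propagate_attention, so scores on real edges are
+# weighted by 1 / (1 + gamma) via exp_real and scores on the added full-graph edges by gamma / (1 + gamma)
+# via exp_fake, interpolating between sparse and full-graph attention.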
+
+
+"""
+ Single Attention Head
+"""
+
+class MultiHeadAttentionLayer(nn.Module):
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, attention_for):
+ super().__init__()
+
+
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.full_graph=full_graph
+ self.attention_for = attention_for
+ self.gamma = gamma
+
+ if self.attention_for == "h":
+ if use_bias:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ else:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+
+ def propagate_attention(self, g):
+
+
+ if self.full_graph:
+ real_ids = torch.nonzero(g.edata['real']).squeeze()
+ fake_ids = torch.nonzero(g.edata['real']==0).squeeze()
+
+ else:
+ real_ids = g.edges(form='eid')
+
+ g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids)
+
+
+ # scale scores by sqrt(d)
+ g.apply_edges(scaling('score', np.sqrt(self.out_dim)))
+
+ # Use available edge features to modify the scores for edges
+ g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids)
+
+
+ if self.full_graph:
+ # softmax and scaling by gamma
+ L = torch.clamp(self.gamma, min=0.0, max=1.0) # Gamma \in [0,1]
+ g.apply_edges(exp_real('score', L), edges=real_ids)
+ g.apply_edges(exp_fake('score', L), edges=fake_ids)
+
+ else:
+ g.apply_edges(exp('score'), edges=real_ids)
+
+ # Send weighted values to target nodes
+ eids = g.edges()
+ g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV'))
+ g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z'))
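+ # 'wV' accumulates the score-weighted values per destination node and 'z' the corresponding score sums;
+ # their ratio in forward() yields the softmax-normalised attention output.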
+
+
+ def forward(self, g, h, e):
+
+ Q_h = self.Q(h)
+ K_h = self.K(h)
+ E = self.E(e)
+
+ if self.full_graph:
+ Q_2h = self.Q_2(h)
+ K_2h = self.K_2(h)
+ E_2 = self.E_2(e)
+
+ V_h = self.V(h)
+
+
+ # Reshaping into [num_nodes, num_heads, feat_dim] to
+ # get projections for multi-head attention
+ g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E'] = E.view(-1, self.num_heads, self.out_dim)
+
+
+ if self.full_graph:
+ g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim)
+
+ g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim)
+
+ self.propagate_attention(g)
+
+ h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6))
+
+ return h_out
+
+
+class SAN_GT_Layer(nn.Module):
+ """
+ Param:
+ """
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0,
+ layer_norm=False, batch_norm=True, residual=True, use_bias=False):
+ super().__init__()
+
+ self.in_channels = in_dim
+ self.out_channels = out_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.residual = residual
+ self.layer_norm = layer_norm
+ self.batch_norm = batch_norm
+
+ self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads,
+ full_graph, use_bias, attention_for="h")
+
+ self.O_h = nn.Linear(out_dim, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm1_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm1_h = nn.BatchNorm1d(out_dim)
+
+ # FFN for h
+ self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2)
+ self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm2_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm2_h = nn.BatchNorm1d(out_dim)
+
+
+ def forward(self, g, h, p, e, snorm_n):
+ h_in1 = h # for first residual connection
+
+ # [START] For calculation of h -----------------------------------------------------------------
+
+ # multi-head attention out
+ h_attn_out = self.attention_h(g, h, e)
+
+ #Concat multi-head outputs
+ h = h_attn_out.view(-1, self.out_channels)
+
+ h = F.dropout(h, self.dropout, training=self.training)
+
+ h = self.O_h(h)
+
+ if self.residual:
+ h = h_in1 + h # residual connection
+
+ # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm1_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm1_h(h)
+
+ h_in2 = h # for second residual connection
+
+ # FFN for h
+ h = self.FFN_h_layer1(h)
+ h = F.relu(h)
+ h = F.dropout(h, self.dropout, training=self.training)
+ h = self.FFN_h_layer2(h)
+
+ if self.residual:
+ h = h_in2 + h # residual connection
+
+ # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm2_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm2_h(h)
+
+ # [END] For calculation of h -----------------------------------------------------------------
+
+ return h, None
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
+ self.in_channels,
+ self.out_channels, self.num_heads, self.residual)
\ No newline at end of file
diff --git a/layers/san_gt_lspe_layer.py b/layers/san_gt_lspe_layer.py
new file mode 100644
index 0000000..e18361c
--- /dev/null
+++ b/layers/san_gt_lspe_layer.py
@@ -0,0 +1,318 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+import dgl.function as fn
+import numpy as np
+
+"""
+ SAN-GT-LSPE: SAN-GT with LSPE
+
+"""
+
+"""
+ Util functions
+"""
+def src_dot_dst(src_field, dst_field, out_field):
+ def func(edges):
+ return {out_field: (edges.src[src_field] * edges.dst[dst_field])}
+ return func
+
+
+def scaling(field, scale_constant):
+ def func(edges):
+ return {field: ((edges.data[field]) / scale_constant)}
+ return func
+
+# Improving implicit attention scores with explicit edge features, if available
+def imp_exp_attn(implicit_attn, explicit_edge):
+ """
+ implicit_attn: the output of the Q·K computation (implicit attention scores)
+ explicit_edge: the explicit edge features
+ """
+ def func(edges):
+ return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}
+ return func
+
+def exp_real(field, L):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)}
+ return func
+
+
+def exp_fake(field, L):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': L*torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)}
+ return func
+
+def exp(field):
+ def func(edges):
+ # clamp for softmax numerical stability
+ return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
+ return func
+
+
+"""
+ Single Attention Head
+"""
+
+class MultiHeadAttentionLayer(nn.Module):
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, attention_for):
+ super().__init__()
+
+
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.full_graph=full_graph
+ self.attention_for = attention_for
+ self.gamma = gamma
+
+ if self.attention_for == "h": # attention module for h has input h = [h,p], so 2*in_dim for Q,K,V
+ if use_bias:
+ self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=True)
+
+ else:
+ self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=False)
+
+ elif self.attention_for == "p": # attention module for p
+ if use_bias:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
+
+ else:
+ self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ if self.full_graph:
+ self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+ self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)
+
+ def propagate_attention(self, g):
+
+ if self.full_graph:
+ real_ids = torch.nonzero(g.edata['real']).squeeze()
+ fake_ids = torch.nonzero(g.edata['real']==0).squeeze()
+
+ else:
+ real_ids = g.edges(form='eid')
+
+ g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids)
+
+
+ # scale scores by sqrt(d)
+ g.apply_edges(scaling('score', np.sqrt(self.out_dim)))
+
+ # Use available edge features to modify the scores for edges
+ g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids)
+
+ if self.full_graph:
+ g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids)
+
+
+ if self.full_graph:
+ # softmax and scaling by gamma
+ L = torch.clamp(self.gamma, min=0.0, max=1.0) # Gamma \in [0,1]
+ g.apply_edges(exp_real('score', L), edges=real_ids)
+ g.apply_edges(exp_fake('score', L), edges=fake_ids)
+
+ else:
+ g.apply_edges(exp('score'), edges=real_ids)
+
+ # Send weighted values to target nodes
+ eids = g.edges()
+ g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV'))
+ g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z'))
+
+
+ def forward(self, g, h, p, e):
+ if self.attention_for == "h":
+ h = torch.cat((h, p), -1)
+ elif self.attention_for == "p":
+ h = p
+
+ Q_h = self.Q(h)
+ K_h = self.K(h)
+ E = self.E(e)
+
+ if self.full_graph:
+ Q_2h = self.Q_2(h)
+ K_2h = self.K_2(h)
+ E_2 = self.E_2(e)
+
+ V_h = self.V(h)
+
+
+ # Reshaping into [num_nodes, num_heads, feat_dim] to
+ # get projections for multi-head attention
+ g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E'] = E.view(-1, self.num_heads, self.out_dim)
+
+
+ if self.full_graph:
+ g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim)
+ g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim)
+ g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim)
+
+ g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim)
+
+ self.propagate_attention(g)
+
+ h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6))
+
+ return h_out
+
+
+class SAN_GT_LSPE_Layer(nn.Module):
+ """
+ Param:
+ """
+ def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0,
+ layer_norm=False, batch_norm=True, residual=True, use_bias=False):
+ super().__init__()
+
+ self.in_channels = in_dim
+ self.out_channels = out_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.residual = residual
+ self.layer_norm = layer_norm
+ self.batch_norm = batch_norm
+
+ self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads,
+ full_graph, use_bias, attention_for="h")
+ self.attention_p = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads,
+ full_graph, use_bias, attention_for="p")
+
+ self.O_h = nn.Linear(out_dim, out_dim)
+ self.O_p = nn.Linear(out_dim, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm1_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm1_h = nn.BatchNorm1d(out_dim)
+
+ # FFN for h
+ self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2)
+ self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim)
+
+ if self.layer_norm:
+ self.layer_norm2_h = nn.LayerNorm(out_dim)
+
+ if self.batch_norm:
+ self.batch_norm2_h = nn.BatchNorm1d(out_dim)
+
+
+ def forward(self, g, h, p, e, snorm_n):
+ h_in1 = h # for first residual connection
+ p_in1 = p # for first residual connection
+
+ # [START] For calculation of h -----------------------------------------------------------------
+
+ # multi-head attention out
+ h_attn_out = self.attention_h(g, h, p, e)
+
+ #Concat multi-head outputs
+ h = h_attn_out.view(-1, self.out_channels)
+
+ h = F.dropout(h, self.dropout, training=self.training)
+
+ h = self.O_h(h)
+
+ if self.residual:
+ h = h_in1 + h # residual connection
+
+ # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm1_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm1_h(h)
+
+ h_in2 = h # for second residual connection
+
+ # FFN for h
+ h = self.FFN_h_layer1(h)
+ h = F.relu(h)
+ h = F.dropout(h, self.dropout, training=self.training)
+ h = self.FFN_h_layer2(h)
+
+ if self.residual:
+ h = h_in2 + h # residual connection
+
+ # GN from benchmarking-gnns-v1
+ # h = h * snorm_n
+
+ if self.layer_norm:
+ h = self.layer_norm2_h(h)
+
+ if self.batch_norm:
+ h = self.batch_norm2_h(h)
+
+ # [END] For calculation of h -----------------------------------------------------------------
+
+
+ # [START] For calculation of p -----------------------------------------------------------------
+
+ # multi-head attention out
+ p_attn_out = self.attention_p(g, None, p, e)
+
+ #Concat multi-head outputs
+ p = p_attn_out.view(-1, self.out_channels)
+
+ p = F.dropout(p, self.dropout, training=self.training)
+
+ p = self.O_p(p)
+
+ p = torch.tanh(p)
+
+ if self.residual:
+ p = p_in1 + p # residual connection
+
+ # [END] For calculation of p -----------------------------------------------------------------
+
+ return h, p
+
+ def __repr__(self):
+ return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
+ self.in_channels,
+ self.out_channels, self.num_heads, self.residual)
\ No newline at end of file
diff --git a/main_OGBMOL_graph_classification.py b/main_OGBMOL_graph_classification.py
new file mode 100644
index 0000000..e12eedc
--- /dev/null
+++ b/main_OGBMOL_graph_classification.py
@@ -0,0 +1,504 @@
+
+
+
+
+
+"""
+ IMPORTING LIBS
+"""
+import dgl
+
+import numpy as np
+import os
+import socket
+import time
+import random
+import glob
+import argparse, json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from tensorboardX import SummaryWriter
+from tqdm import tqdm
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+class DotDict(dict):
+ def __init__(self, **kwds):
+ self.update(kwds)
+ self.__dict__ = self
+
+
+
+
+
+
+"""
+ IMPORTING CUSTOM MODULES/METHODS
+"""
+
+from nets.OGBMOL_graph_classification.load_net import gnn_model # import GNNs
+from data.data import LoadData # import dataset
+
+
+
+
+"""
+ GPU Setup
+"""
+def gpu_setup(use_gpu, gpu_id):
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
+ if torch.cuda.is_available() and use_gpu:
+ print('cuda available with GPU:',torch.cuda.get_device_name(0))
+ device = torch.device("cuda")
+ else:
+ print('cuda not available')
+ device = torch.device("cpu")
+ return device
+
+
+
+
+
+
+
+"""
+ VIEWING MODEL CONFIG AND PARAMS
+"""
+def view_model_param(MODEL_NAME, net_params):
+ model = gnn_model(MODEL_NAME, net_params)
+ total_param = 0
+ print("MODEL DETAILS:\n")
+ # print(model)
+ for param in model.parameters():
+ # print(param.data.size())
+ total_param += np.prod(list(param.data.size()))
+ print('MODEL/Total parameters:', MODEL_NAME, total_param)
+ return total_param
+
+
+def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
+ t0 = time.time()
+ per_epoch_time = []
+
+ DATASET_NAME = dataset.name
+
+ if net_params['pe_init'] == 'lap_pe':
+ tt = time.time()
+ print("[!] -LapPE: Initializing graph positional encoding with Laplacian PE.")
+ dataset._add_lap_positional_encodings(net_params['pos_enc_dim'])
+ print("[!] Time taken: ", time.time()-tt)
+ elif net_params['pe_init'] == 'rand_walk':
+ tt = time.time()
+ print("[!] -LSPE: Initializing graph positional encoding with rand walk features.")
+ dataset._init_positional_encodings(net_params['pos_enc_dim'], net_params['pe_init'])
+ print("[!] Time taken: ", time.time()-tt)
+
+ tt = time.time()
+ print("[!] -LSPE (For viz later): Adding lapeigvecs to key 'eigvec' for every graph.")
+ dataset._add_eig_vecs(net_params['pos_enc_dim'])
+ print("[!] Time taken: ", time.time()-tt)
+
+ if MODEL_NAME in ['SAN', 'GraphiT']:
+ if net_params['full_graph']:
+ st = time.time()
+ print("[!] Adding full graph connectivity..")
+ dataset._make_full_graph() if MODEL_NAME == 'SAN' else dataset._make_full_graph((net_params['p_steps'], net_params['gamma']))
+ print('Time taken to add full graph connectivity: ',time.time()-st)
+
+ trainset, valset, testset = dataset.train, dataset.val, dataset.test
+
+ evaluator = dataset.evaluator
+
+ root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir = dirs
+ device = net_params['device']
+
+ # Write the network and optimization hyper-parameters in folder config/
+ with open(write_config_file + '.txt', 'w') as f:
+ f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""" .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))
+
+ log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
+ writer = SummaryWriter(log_dir=log_dir)
+
+ # setting seeds
+ random.seed(params['seed'])
+ np.random.seed(params['seed'])
+ torch.manual_seed(params['seed'])
+ if device.type == 'cuda':
+ torch.cuda.manual_seed(params['seed'])
+ torch.cuda.manual_seed_all(params['seed'])
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ print("Training Graphs: ", len(trainset))
+ print("Validation Graphs: ", len(valset))
+ print("Test Graphs: ", len(testset))
+
+ model = gnn_model(MODEL_NAME, net_params)
+ model = model.to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
+ factor=params['lr_reduce_factor'],
+ patience=params['lr_schedule_patience'],
+ verbose=True)
+
+ epoch_train_losses, epoch_val_losses = [], []
+ epoch_train_accs, epoch_val_accs, epoch_test_accs = [], [], []
+
+ # import train functions for all GNNs
+ from train.train_OGBMOL_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network
+
+ train_loader = DataLoader(trainset, num_workers=4, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate, pin_memory=True)
+ val_loader = DataLoader(valset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate, pin_memory=True)
+ test_loader = DataLoader(testset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate, pin_memory=True)
+
+ # At any point you can hit Ctrl + C to break out of training early.
+ try:
+ with tqdm(range(params['epochs'])) as t:
+ for epoch in t:
+
+ t.set_description('Epoch %d' % epoch)
+
+ start = time.time()
+
+ epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch, evaluator)
+
+ epoch_val_loss, epoch_val_acc, __ = evaluate_network(model, device, val_loader, epoch, evaluator)
+ _, epoch_test_acc, __ = evaluate_network(model, device, test_loader, epoch, evaluator)
+ del __
+
+ epoch_train_losses.append(epoch_train_loss)
+ epoch_val_losses.append(epoch_val_loss)
+ epoch_train_accs.append(epoch_train_acc)
+ epoch_val_accs.append(epoch_val_acc)
+ epoch_test_accs.append(epoch_test_acc)
+
+ writer.add_scalar('train/_loss', epoch_train_loss, epoch)
+ writer.add_scalar('val/_loss', epoch_val_loss, epoch)
+ writer.add_scalar('train/_avg_prec', epoch_train_acc, epoch)
+ writer.add_scalar('val/_avg_prec', epoch_val_acc, epoch)
+ writer.add_scalar('test/_avg_prec', epoch_test_acc, epoch)
+ writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
+
+ if dataset.name == "ogbg-moltox21":
+ t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
+ train_loss=epoch_train_loss, val_loss=epoch_val_loss,
+ train_AUC=epoch_train_acc, val_AUC=epoch_val_acc,
+ test_AUC=epoch_test_acc)
+ elif dataset.name == "ogbg-molpcba":
+ t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
+ train_loss=epoch_train_loss, val_loss=epoch_val_loss,
+ train_AP=epoch_train_acc, val_AP=epoch_val_acc,
+ test_AP=epoch_test_acc)
+
+ per_epoch_time.append(time.time()-start)
+
+ # Saving checkpoint
+ ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
+ if not os.path.exists(ckpt_dir):
+ os.makedirs(ckpt_dir)
+ torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
+
+ files = glob.glob(ckpt_dir + '/*.pkl')
+ for file in files:
+ epoch_nb = file.split('_')[-1]
+ epoch_nb = int(epoch_nb.split('.')[0])
+ if epoch_nb < epoch-1:
+ os.remove(file)
+
+ scheduler.step(epoch_val_loss)
+
+ if optimizer.param_groups[0]['lr'] < params['min_lr']:
+ print("\n!! LR EQUAL TO MIN LR SET.")
+ break
+
+ # Stop training after params['max_time'] hours
+ if time.time()-t0 > params['max_time']*3600:
+ print('-' * 89)
+ print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time']))
+ break
+
+ except KeyboardInterrupt:
+ print('-' * 89)
+ print('Exiting from training early because of KeyboardInterrupt')
+
+ # ___, __, g_outs_train = evaluate_network(model, device, train_loader, epoch, evaluator)
+ ___, __, g_outs_test = evaluate_network(model, device, test_loader, epoch, evaluator)
+ del ___
+ del __
+
+ # OGB: Test scores at best val epoch
+ epoch_best = epoch_val_accs.index(max(epoch_val_accs))
+
+ test_acc = epoch_test_accs[epoch_best]
+ train_acc = epoch_train_accs[epoch_best]
+ val_acc = epoch_val_accs[epoch_best]
+
+ if dataset.name == "ogbg-moltox21":
+ print("Test AUC: {:.4f}".format(test_acc))
+ print("Train AUC: {:.4f}".format(train_acc))
+ print("Val AUC: {:.4f}".format(val_acc))
+ elif dataset.name == "ogbg-molpcba":
+ print("Test Avg Precision: {:.4f}".format(test_acc))
+ print("Train Avg Precision: {:.4f}".format(train_acc))
+ print("Convergence Time (Epochs): {:.4f}".format(epoch))
+ print("TOTAL TIME TAKEN: {:.4f}s".format(time.time()-t0))
+ print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
+
+ if net_params['pe_init'] == 'rand_walk' and g_outs_test is not None:
+ # Visualize actual and predicted/learned eigenvecs
+ from utils.plot_util import plot_graph_eigvec
+ if not os.path.exists(viz_dir):
+ os.makedirs(viz_dir)
+
+ sample_graph_ids = [153,103,123]
+
+ for f_idx, graph_id in enumerate(sample_graph_ids):
+
+ # Test graphs
+ g_dgl = g_outs_test[graph_id]
+
+ f = plt.figure(f_idx, figsize=(12,6))
+
+ plt1 = f.add_subplot(121)
+ plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True)
+
+ plt2 = f.add_subplot(122)
+ plot_graph_eigvec(plt2, graph_id, g_dgl, feature_key='p', predicted_eigvecs=True)
+
+ f.savefig(viz_dir+'/test'+str(graph_id)+'.jpg')
+
+ writer.close()
+
+ """
+ Write the results in out_dir/results folder
+ """
+ if dataset.name == "ogbg-moltox21":
+ with open(write_file_name + '.txt', 'w') as f:
+ f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
+ FINAL RESULTS\nTEST AUC: {:.4f}\nTRAIN AUC: {:.4f}\nVAL AUC: {:.4f}\n\n
+ Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
+ .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
+ test_acc, train_acc, val_acc, epoch, (time.time()-t0)/3600, np.mean(per_epoch_time)))
+ elif dataset.name == "ogbg-molpcba":
+ with open(write_file_name + '.txt', 'w') as f:
+ f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
+ FINAL RESULTS\nTEST AVG PRECISION: {:.4f}\nTRAIN AVG PRECISION: {:.4f}\nVAL AVG PRECISION: {:.4f}\n\n
+ Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
+ .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
+ test_acc, train_acc, val_acc, epoch, (time.time()-t0)/3600, np.mean(per_epoch_time)))
+
+
+
+
+def main():
+ """
+ USER CONTROLS
+ """
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details")
+ parser.add_argument('--gpu_id', help="Please give a value for gpu id")
+ parser.add_argument('--model', help="Please give a value for model name")
+ parser.add_argument('--dataset', help="Please give a value for dataset name")
+ parser.add_argument('--out_dir', help="Please give a value for out_dir")
+ parser.add_argument('--seed', help="Please give a value for seed")
+ parser.add_argument('--epochs', help="Please give a value for epochs")
+ parser.add_argument('--batch_size', help="Please give a value for batch_size")
+ parser.add_argument('--init_lr', help="Please give a value for init_lr")
+ parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor")
+ parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience")
+ parser.add_argument('--min_lr', help="Please give a value for min_lr")
+ parser.add_argument('--weight_decay', help="Please give a value for weight_decay")
+ parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval")
+ parser.add_argument('--L', help="Please give a value for L")
+ parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim")
+ parser.add_argument('--out_dim', help="Please give a value for out_dim")
+ parser.add_argument('--residual', help="Please give a value for residual")
+ parser.add_argument('--edge_feat', help="Please give a value for edge_feat")
+ parser.add_argument('--readout', help="Please give a value for readout")
+ parser.add_argument('--kernel', help="Please give a value for kernel")
+ parser.add_argument('--n_heads', help="Please give a value for n_heads")
+ parser.add_argument('--gated', help="Please give a value for gated")
+ parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout")
+ parser.add_argument('--dropout', help="Please give a value for dropout")
+ parser.add_argument('--layer_norm', help="Please give a value for layer_norm")
+ parser.add_argument('--batch_norm', help="Please give a value for batch_norm")
+ parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator")
+ parser.add_argument('--data_mode', help="Please give a value for data_mode")
+ parser.add_argument('--num_pool', help="Please give a value for num_pool")
+ parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block")
+ parser.add_argument('--embedding_dim', help="Please give a value for embedding_dim")
+ parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio")
+ parser.add_argument('--linkpred', help="Please give a value for linkpred")
+ parser.add_argument('--cat', help="Please give a value for cat")
+ parser.add_argument('--self_loop', help="Please give a value for self_loop")
+ parser.add_argument('--max_time', help="Please give a value for max_time")
+ parser.add_argument('--pos_enc_dim', help="Please give a value for pos_enc_dim")
+ parser.add_argument('--alpha_loss', help="Please give a value for alpha_loss")
+ parser.add_argument('--lambda_loss', help="Please give a value for lambda_loss")
+ parser.add_argument('--pe_init', help="Please give a value for pe_init")
+ args = parser.parse_args()
+ with open(args.config) as f:
+ config = json.load(f)
+ # device
+ if args.gpu_id is not None:
+ config['gpu']['id'] = int(args.gpu_id)
+ config['gpu']['use'] = True
+ device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
+ # model, dataset, out_dir
+ if args.model is not None:
+ MODEL_NAME = args.model
+ else:
+ MODEL_NAME = config['model']
+ if args.dataset is not None:
+ DATASET_NAME = args.dataset
+ else:
+ DATASET_NAME = config['dataset']
+ dataset = LoadData(DATASET_NAME)
+ if args.out_dir is not None:
+ out_dir = args.out_dir
+ else:
+ out_dir = config['out_dir']
+ # parameters
+ params = config['params']
+ if args.seed is not None:
+ params['seed'] = int(args.seed)
+ if args.epochs is not None:
+ params['epochs'] = int(args.epochs)
+ if args.batch_size is not None:
+ params['batch_size'] = int(args.batch_size)
+ if args.init_lr is not None:
+ params['init_lr'] = float(args.init_lr)
+ if args.lr_reduce_factor is not None:
+ params['lr_reduce_factor'] = float(args.lr_reduce_factor)
+ if args.lr_schedule_patience is not None:
+ params['lr_schedule_patience'] = int(args.lr_schedule_patience)
+ if args.min_lr is not None:
+ params['min_lr'] = float(args.min_lr)
+ if args.weight_decay is not None:
+ params['weight_decay'] = float(args.weight_decay)
+ if args.print_epoch_interval is not None:
+ params['print_epoch_interval'] = int(args.print_epoch_interval)
+ if args.max_time is not None:
+ params['max_time'] = float(args.max_time)
+ # network parameters
+ net_params = config['net_params']
+ net_params['device'] = device
+ net_params['gpu_id'] = config['gpu']['id']
+ net_params['batch_size'] = params['batch_size']
+ if args.L is not None:
+ net_params['L'] = int(args.L)
+ if args.hidden_dim is not None:
+ net_params['hidden_dim'] = int(args.hidden_dim)
+ if args.out_dim is not None:
+ net_params['out_dim'] = int(args.out_dim)
+ if args.residual is not None:
+ net_params['residual'] = True if args.residual=='True' else False
+ if args.edge_feat is not None:
+ net_params['edge_feat'] = True if args.edge_feat=='True' else False
+ if args.readout is not None:
+ net_params['readout'] = args.readout
+ if args.kernel is not None:
+ net_params['kernel'] = int(args.kernel)
+ if args.n_heads is not None:
+ net_params['n_heads'] = int(args.n_heads)
+ if args.gated is not None:
+ net_params['gated'] = True if args.gated=='True' else False
+ if args.in_feat_dropout is not None:
+ net_params['in_feat_dropout'] = float(args.in_feat_dropout)
+ if args.dropout is not None:
+ net_params['dropout'] = float(args.dropout)
+ if args.layer_norm is not None:
+ net_params['layer_norm'] = True if args.layer_norm=='True' else False
+ if args.batch_norm is not None:
+ net_params['batch_norm'] = True if args.batch_norm=='True' else False
+ if args.sage_aggregator is not None:
+ net_params['sage_aggregator'] = args.sage_aggregator
+ if args.data_mode is not None:
+ net_params['data_mode'] = args.data_mode
+ if args.num_pool is not None:
+ net_params['num_pool'] = int(args.num_pool)
+ if args.gnn_per_block is not None:
+ net_params['gnn_per_block'] = int(args.gnn_per_block)
+ if args.embedding_dim is not None:
+ net_params['embedding_dim'] = int(args.embedding_dim)
+ if args.pool_ratio is not None:
+ net_params['pool_ratio'] = float(args.pool_ratio)
+ if args.linkpred is not None:
+ net_params['linkpred'] = True if args.linkpred=='True' else False
+ if args.cat is not None:
+ net_params['cat'] = True if args.cat=='True' else False
+ if args.self_loop is not None:
+ net_params['self_loop'] = True if args.self_loop=='True' else False
+ if args.pos_enc_dim is not None:
+ net_params['pos_enc_dim'] = int(args.pos_enc_dim)
+ if args.alpha_loss is not None:
+ net_params['alpha_loss'] = float(args.alpha_loss)
+ if args.lambda_loss is not None:
+ net_params['lambda_loss'] = float(args.lambda_loss)
+ if args.pe_init is not None:
+ net_params['pe_init'] = args.pe_init
+
+
+
+ # OGBMOL*
+ num_classes = dataset.dataset.num_tasks # provided by OGB dataset class
+ net_params['n_classes'] = num_classes
+
+ if MODEL_NAME == 'PNA':
+ D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g, label in
+ dataset.train])
+ net_params['avg_d'] = dict(lin=torch.mean(D),
+ exp=torch.mean(torch.exp(torch.div(1, D)) - 1),
+ log=torch.mean(torch.log(D + 1)))
+
+ root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ viz_dir = out_dir + 'viz/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir
+
+ if not os.path.exists(out_dir + 'results'):
+ os.makedirs(out_dir + 'results')
+
+ if not os.path.exists(out_dir + 'configs'):
+ os.makedirs(out_dir + 'configs')
+
+ net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
+ train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
+
+
+
+
+
+
+
+
+main()
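+
+# Example invocation (illustrative config path; the flags are defined in the argparse block above):
+#   python main_OGBMOL_graph_classification.py --config <path/to/config.json> --gpu_id 0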
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/main_ZINC_graph_regression.py b/main_ZINC_graph_regression.py
new file mode 100644
index 0000000..91d028e
--- /dev/null
+++ b/main_ZINC_graph_regression.py
@@ -0,0 +1,453 @@
+
+
+
+
+
+"""
+ IMPORTING LIBS
+"""
+import dgl
+
+import numpy as np
+import os
+import socket
+import time
+import random
+import glob
+import argparse, json
+import pickle
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from tensorboardX import SummaryWriter
+from tqdm import tqdm
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+
+class DotDict(dict):
+ def __init__(self, **kwds):
+ self.update(kwds)
+ self.__dict__ = self
+
+
+
+
+
+
+"""
+ IMPORTING CUSTOM MODULES/METHODS
+"""
+from nets.ZINC_graph_regression.load_net import gnn_model # import all GNNS
+from data.data import LoadData # import dataset
+
+
+
+
+"""
+ GPU Setup
+"""
+def gpu_setup(use_gpu, gpu_id):
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
+ if torch.cuda.is_available() and use_gpu:
+ print('cuda available with GPU:',torch.cuda.get_device_name(0))
+ device = torch.device("cuda")
+ else:
+ print('cuda not available')
+ device = torch.device("cpu")
+ return device
+
+
+
+
+
+
+
+
+"""
+ VIEWING MODEL CONFIG AND PARAMS
+"""
+def view_model_param(MODEL_NAME, net_params):
+ model = gnn_model(MODEL_NAME, net_params)
+ total_param = 0
+ print("MODEL DETAILS:\n")
+ #print(model)
+ for param in model.parameters():
+ # print(param.data.size())
+ total_param += np.prod(list(param.data.size()))
+ print('MODEL/Total parameters:', MODEL_NAME, total_param)
+ return total_param
+
+
+"""
+ TRAINING CODE
+"""
+
+def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
+ t0 = time.time()
+ per_epoch_time = []
+
+ DATASET_NAME = dataset.name
+
+ if net_params['pe_init'] == 'lap_pe':
+ tt = time.time()
+ print("[!] -LapPE: Initializing graph positional encoding with Laplacian PE.")
+ dataset._add_lap_positional_encodings(net_params['pos_enc_dim'])
+ print("[!] Time taken: ", time.time()-tt)
+ elif net_params['pe_init'] == 'rand_walk':
+ tt = time.time()
+ print("[!] -LSPE: Initializing graph positional encoding with rand walk features.")
+ dataset._init_positional_encodings(net_params['pos_enc_dim'], net_params['pe_init'])
+ print("[!] Time taken: ", time.time()-tt)
+
+ tt = time.time()
+ print("[!] -LSPE (For viz later): Adding lapeigvecs to key 'eigvec' for every graph.")
+ dataset._add_eig_vecs(net_params['pos_enc_dim'])
+ print("[!] Time taken: ", time.time()-tt)
+
+ if MODEL_NAME in ['SAN', 'GraphiT']:
+ if net_params['full_graph']:
+ st = time.time()
+ print("[!] Adding full graph connectivity..")
+ dataset._make_full_graph() if MODEL_NAME == 'SAN' else dataset._make_full_graph((net_params['p_steps'], net_params['gamma']))
+ print('Time taken to add full graph connectivity: ',time.time()-st)
+
+ trainset, valset, testset = dataset.train, dataset.val, dataset.test
+
+ root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir = dirs
+ device = net_params['device']
+
+ # Write the network and optimization hyper-parameters in folder config/
+ with open(write_config_file + '.txt', 'w') as f:
+ f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""" .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))
+
+ log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
+ writer = SummaryWriter(log_dir=log_dir)
+
+ # setting seeds
+ random.seed(params['seed'])
+ np.random.seed(params['seed'])
+ torch.manual_seed(params['seed'])
+ if device.type == 'cuda':
+ torch.cuda.manual_seed(params['seed'])
+ torch.cuda.manual_seed_all(params['seed'])
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ print("Training Graphs: ", len(trainset))
+ print("Validation Graphs: ", len(valset))
+ print("Test Graphs: ", len(testset))
+
+ model = gnn_model(MODEL_NAME, net_params)
+ model = model.to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
+ factor=params['lr_reduce_factor'],
+ patience=params['lr_schedule_patience'],
+ verbose=True)
+
+ epoch_train_losses, epoch_val_losses = [], []
+ epoch_train_MAEs, epoch_val_MAEs = [], []
+
+ # import train functions for all GNNs
+ from train.train_ZINC_graph_regression import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network
+
+ train_loader = DataLoader(trainset, num_workers=4, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate)
+ val_loader = DataLoader(valset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate)
+ test_loader = DataLoader(testset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate)
+
+ # At any point you can hit Ctrl + C to break out of training early.
+ try:
+ with tqdm(range(params['epochs'])) as t:
+ for epoch in t:
+
+ t.set_description('Epoch %d' % epoch)
+
+ start = time.time()
+
+ epoch_train_loss, epoch_train_mae, optimizer = train_epoch(model, optimizer, device, train_loader, epoch)
+
+ epoch_val_loss, epoch_val_mae, __ = evaluate_network(model, device, val_loader, epoch)
+ epoch_test_loss, epoch_test_mae, __ = evaluate_network(model, device, test_loader, epoch)
+ del __
+
+ epoch_train_losses.append(epoch_train_loss)
+ epoch_val_losses.append(epoch_val_loss)
+ epoch_train_MAEs.append(epoch_train_mae)
+ epoch_val_MAEs.append(epoch_val_mae)
+
+ writer.add_scalar('train/_loss', epoch_train_loss, epoch)
+ writer.add_scalar('val/_loss', epoch_val_loss, epoch)
+ writer.add_scalar('train/_mae', epoch_train_mae, epoch)
+ writer.add_scalar('val/_mae', epoch_val_mae, epoch)
+ writer.add_scalar('test/_mae', epoch_test_mae, epoch)
+ writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
+
+
+ t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
+ train_loss=epoch_train_loss, val_loss=epoch_val_loss,
+ train_MAE=epoch_train_mae, val_MAE=epoch_val_mae,
+ test_MAE=epoch_test_mae)
+
+
+ per_epoch_time.append(time.time()-start)
+
+ # Saving checkpoint
+ ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
+ if not os.path.exists(ckpt_dir):
+ os.makedirs(ckpt_dir)
+ torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
+
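+                # keep only the two most recent checkpoints (current and previous epoch); remove older ones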
+ files = glob.glob(ckpt_dir + '/*.pkl')
+ for file in files:
+ epoch_nb = file.split('_')[-1]
+ epoch_nb = int(epoch_nb.split('.')[0])
+ if epoch_nb < epoch-1:
+ os.remove(file)
+
+ scheduler.step(epoch_val_loss)
+
+ if optimizer.param_groups[0]['lr'] < params['min_lr']:
+ print("\n!! LR EQUAL TO MIN LR SET.")
+ break
+
+ # Stop training after params['max_time'] hours
+ if time.time()-t0 > params['max_time']*3600:
+ print('-' * 89)
+ print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time']))
+ break
+
+ except KeyboardInterrupt:
+ print('-' * 89)
+ print('Exiting from training early because of KeyboardInterrupt')
+
+ test_loss_lapeig, test_mae, g_outs_test = evaluate_network(model, device, test_loader, epoch)
+ train_loss_lapeig, train_mae, g_outs_train = evaluate_network(model, device, train_loader, epoch)
+
+ print("Test MAE: {:.4f}".format(test_mae))
+ print("Train MAE: {:.4f}".format(train_mae))
+ print("Convergence Time (Epochs): {:.4f}".format(epoch))
+ print("TOTAL TIME TAKEN: {:.4f}s".format(time.time()-t0))
+ print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
+
+
+ if net_params['pe_init'] == 'rand_walk':
+ # Visualize actual and predicted/learned eigenvecs
+ from utils.plot_util import plot_graph_eigvec
+ if not os.path.exists(viz_dir):
+ os.makedirs(viz_dir)
+
+ sample_graph_ids = [15,25,45]
+
+ for f_idx, graph_id in enumerate(sample_graph_ids):
+
+ # Test graphs
+ g_dgl = g_outs_test[graph_id]
+
+ f = plt.figure(f_idx, figsize=(12,6))
+
+ plt1 = f.add_subplot(121)
+ plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True)
+
+ plt2 = f.add_subplot(122)
+ plot_graph_eigvec(plt2, graph_id, g_dgl, feature_key='p', predicted_eigvecs=True)
+
+ f.savefig(viz_dir+'/test'+str(graph_id)+'.jpg')
+
+ # Train graphs
+ g_dgl = g_outs_train[graph_id]
+
+ f = plt.figure(f_idx, figsize=(12,6))
+
+ plt1 = f.add_subplot(121)
+ plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True)
+
+ plt2 = f.add_subplot(122)
+ plot_graph_eigvec(plt2, graph_id, g_dgl, feature_key='p', predicted_eigvecs=True)
+
+ f.savefig(viz_dir+'/train'+str(graph_id)+'.jpg')
+
+ writer.close()
+
+ """
+ Write the results in out_dir/results folder
+ """
+ with open(write_file_name + '.txt', 'w') as f:
+ f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
+ FINAL RESULTS\nTEST MAE: {:.4f}\nTRAIN MAE: {:.4f}\n\n
+ Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
+ .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
+ test_mae, train_mae, epoch, (time.time()-t0)/3600, np.mean(per_epoch_time)))
+
+
+
+
+
+def main():
+ """
+ USER CONTROLS
+ """
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details")
+ parser.add_argument('--gpu_id', help="Please give a value for gpu id")
+ parser.add_argument('--model', help="Please give a value for model name")
+ parser.add_argument('--dataset', help="Please give a value for dataset name")
+ parser.add_argument('--out_dir', help="Please give a value for out_dir")
+ parser.add_argument('--seed', help="Please give a value for seed")
+ parser.add_argument('--epochs', help="Please give a value for epochs")
+ parser.add_argument('--batch_size', help="Please give a value for batch_size")
+ parser.add_argument('--init_lr', help="Please give a value for init_lr")
+ parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor")
+ parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience")
+ parser.add_argument('--min_lr', help="Please give a value for min_lr")
+ parser.add_argument('--weight_decay', help="Please give a value for weight_decay")
+ parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval")
+ parser.add_argument('--L', help="Please give a value for L")
+ parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim")
+ parser.add_argument('--out_dim', help="Please give a value for out_dim")
+ parser.add_argument('--residual', help="Please give a value for residual")
+ parser.add_argument('--edge_feat', help="Please give a value for edge_feat")
+ parser.add_argument('--readout', help="Please give a value for readout")
+ parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout")
+ parser.add_argument('--dropout', help="Please give a value for dropout")
+ parser.add_argument('--layer_norm', help="Please give a value for layer_norm")
+ parser.add_argument('--batch_norm', help="Please give a value for batch_norm")
+ parser.add_argument('--max_time', help="Please give a value for max_time")
+ parser.add_argument('--pos_enc_dim', help="Please give a value for pos_enc_dim")
+ parser.add_argument('--pos_enc', help="Please give a value for pos_enc")
+ parser.add_argument('--alpha_loss', help="Please give a value for alpha_loss")
+ parser.add_argument('--lambda_loss', help="Please give a value for lambda_loss")
+ parser.add_argument('--pe_init', help="Please give a value for pe_init")
+ args = parser.parse_args()
+ with open(args.config) as f:
+ config = json.load(f)
+
+ # device
+ if args.gpu_id is not None:
+ config['gpu']['id'] = int(args.gpu_id)
+ config['gpu']['use'] = True
+ device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
+ # model, dataset, out_dir
+ if args.model is not None:
+ MODEL_NAME = args.model
+ else:
+ MODEL_NAME = config['model']
+ if args.dataset is not None:
+ DATASET_NAME = args.dataset
+ else:
+ DATASET_NAME = config['dataset']
+ dataset = LoadData(DATASET_NAME)
+ if args.out_dir is not None:
+ out_dir = args.out_dir
+ else:
+ out_dir = config['out_dir']
+ # parameters
+ params = config['params']
+ if args.seed is not None:
+ params['seed'] = int(args.seed)
+ if args.epochs is not None:
+ params['epochs'] = int(args.epochs)
+ if args.batch_size is not None:
+ params['batch_size'] = int(args.batch_size)
+ if args.init_lr is not None:
+ params['init_lr'] = float(args.init_lr)
+ if args.lr_reduce_factor is not None:
+ params['lr_reduce_factor'] = float(args.lr_reduce_factor)
+ if args.lr_schedule_patience is not None:
+ params['lr_schedule_patience'] = int(args.lr_schedule_patience)
+ if args.min_lr is not None:
+ params['min_lr'] = float(args.min_lr)
+ if args.weight_decay is not None:
+ params['weight_decay'] = float(args.weight_decay)
+ if args.print_epoch_interval is not None:
+ params['print_epoch_interval'] = int(args.print_epoch_interval)
+ if args.max_time is not None:
+ params['max_time'] = float(args.max_time)
+ # network parameters
+ net_params = config['net_params']
+ net_params['device'] = device
+ net_params['gpu_id'] = config['gpu']['id']
+ net_params['batch_size'] = params['batch_size']
+ if args.L is not None:
+ net_params['L'] = int(args.L)
+ if args.hidden_dim is not None:
+ net_params['hidden_dim'] = int(args.hidden_dim)
+ if args.out_dim is not None:
+ net_params['out_dim'] = int(args.out_dim)
+ if args.residual is not None:
+ net_params['residual'] = True if args.residual=='True' else False
+ if args.edge_feat is not None:
+ net_params['edge_feat'] = True if args.edge_feat=='True' else False
+ if args.readout is not None:
+ net_params['readout'] = args.readout
+ if args.in_feat_dropout is not None:
+ net_params['in_feat_dropout'] = float(args.in_feat_dropout)
+ if args.dropout is not None:
+ net_params['dropout'] = float(args.dropout)
+ if args.layer_norm is not None:
+ net_params['layer_norm'] = True if args.layer_norm=='True' else False
+ if args.batch_norm is not None:
+ net_params['batch_norm'] = True if args.batch_norm=='True' else False
+ if args.pos_enc is not None:
+ net_params['pos_enc'] = True if args.pos_enc=='True' else False
+ if args.pos_enc_dim is not None:
+ net_params['pos_enc_dim'] = int(args.pos_enc_dim)
+ if args.alpha_loss is not None:
+ net_params['alpha_loss'] = float(args.alpha_loss)
+ if args.lambda_loss is not None:
+ net_params['lambda_loss'] = float(args.lambda_loss)
+ if args.pe_init is not None:
+ net_params['pe_init'] = args.pe_init
+
+
+ # ZINC
+ net_params['num_atom_type'] = dataset.num_atom_type
+ net_params['num_bond_type'] = dataset.num_bond_type
+
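+    # PNA's degree scalers need average-degree statistics (avg_d) computed over the training graphs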
+ if MODEL_NAME == 'PNA':
+ D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g in
+ dataset.train.graph_lists])
+ net_params['avg_d'] = dict(lin=torch.mean(D),
+ exp=torch.mean(torch.exp(torch.div(1, D)) - 1),
+ log=torch.mean(torch.log(D + 1)))
+
+ root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ viz_dir = out_dir + 'viz/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
+ dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir
+
+ if not os.path.exists(out_dir + 'results'):
+ os.makedirs(out_dir + 'results')
+
+ if not os.path.exists(out_dir + 'configs'):
+ os.makedirs(out_dir + 'configs')
+
+ net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
+ train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
+
+
+
+
+
+
+
+
+main()
+
+
+
+
+
diff --git a/nets/OGBMOL_graph_classification/gatedgcn_net.py b/nets/OGBMOL_graph_classification/gatedgcn_net.py
new file mode 100644
index 0000000..3bb2c9c
--- /dev/null
+++ b/nets/OGBMOL_graph_classification/gatedgcn_net.py
@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+
+from scipy import sparse as sp
+from scipy.sparse.linalg import norm
+
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+
+"""
+ GatedGCN and GatedGCN-LSPE
+
+"""
+
+from layers.gatedgcn_layer import GatedGCNLayer
+from layers.gatedgcn_lspe_layer import GatedGCNLSPELayer
+from layers.mlp_readout_layer import MLPReadout
+
+
+class GatedGCNNet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+ hidden_dim = net_params['hidden_dim']
+ out_dim = net_params['out_dim']
+ n_classes = net_params['n_classes']
+ dropout = net_params['dropout']
+ n_layers = net_params['L']
+ self.readout = net_params['readout']
+ self.batch_norm = net_params['batch_norm']
+ self.residual = net_params['residual']
+ self.device = net_params['device']
+ self.pe_init = net_params['pe_init']
+ self.n_classes = net_params['n_classes']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk', 'lap_pe']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim)
+
+ self.atom_encoder = AtomEncoder(hidden_dim)
+ self.bond_encoder = BondEncoder(hidden_dim)
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([ GatedGCNLSPELayer(hidden_dim, hidden_dim, dropout, self.batch_norm, self.residual)
+ for _ in range(n_layers-1) ])
+ self.layers.append(GatedGCNLSPELayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual))
+ else:
+ # NoPE or LapPE
+ self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout, self.batch_norm, self.residual)
+ for _ in range(n_layers-1) ])
+ self.layers.append(GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual))
+
+ self.MLP_layer = MLPReadout(out_dim, n_classes)
+
+ if self.pe_init == 'rand_walk':
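+            # p_out projects the learned PEs back to pos_enc_dim; Whp fuses node features and PEs before readout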
+ self.p_out = nn.Linear(out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ h = self.atom_encoder(h)
+ e = self.bond_encoder(e)
+
+ if self.pe_init in ['rand_walk', 'lap_pe']:
+ p = self.embedding_p(p)
+
+ if self.pe_init == 'lap_pe':
+ h = h + p
+ p = None
+
+ for conv in self.layers:
+ h, p, e = conv(g, h, p, e, snorm_n)
+
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ p = self.p_out(p)
+ g.ndata['p'] = p
+
+ if self.use_lapeig_loss:
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
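+                # (per graph in the batch: dgl.mean_nodes gives one mean per graph; repeat_interleave broadcasts it back to that graph's nodes)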
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms+1e-6)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ if self.pe_init == 'rand_walk':
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ if self.n_classes == 128:
+ return_g = None # not passing PCBA graphs due to memory
+ else:
+ return_g = g
+
+ return self.MLP_layer(hg), return_g
+
+ def loss(self, pred, labels):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels)
+
+ if self.use_lapeig_loss:
+ raise NotImplementedError
+ else:
+ loss = loss_a
+
+ return loss
\ No newline at end of file
diff --git a/nets/OGBMOL_graph_classification/graphit_net.py b/nets/OGBMOL_graph_classification/graphit_net.py
new file mode 100644
index 0000000..ec94cd5
--- /dev/null
+++ b/nets/OGBMOL_graph_classification/graphit_net.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+
+from scipy import sparse as sp
+from scipy.sparse.linalg import norm
+
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+
+"""
+ GraphiT-GT and GraphiT-GT-LSPE
+
+"""
+
+from layers.graphit_gt_layer import GraphiT_GT_Layer
+from layers.graphit_gt_lspe_layer import GraphiT_GT_LSPE_Layer
+from layers.mlp_readout_layer import MLPReadout
+
+class GraphiTNet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+
+ full_graph = net_params['full_graph']
+ gamma = net_params['gamma']
+ self.adaptive_edge_PE = net_params['adaptive_edge_PE']
+
+ GT_layers = net_params['L']
+ GT_hidden_dim = net_params['hidden_dim']
+ GT_out_dim = net_params['out_dim']
+ GT_n_heads = net_params['n_heads']
+
+ self.residual = net_params['residual']
+ self.readout = net_params['readout']
+ in_feat_dropout = net_params['in_feat_dropout']
+ dropout = net_params['dropout']
+
+ self.layer_norm = net_params['layer_norm']
+ self.batch_norm = net_params['batch_norm']
+
+ n_classes = net_params['n_classes']
+ self.device = net_params['device']
+ self.in_feat_dropout = nn.Dropout(in_feat_dropout)
+ self.pe_init = net_params['pe_init']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim)
+
+ self.embedding_h = AtomEncoder(GT_hidden_dim)
+ self.embedding_e = BondEncoder(GT_hidden_dim)
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([ GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ])
+ self.layers.append(GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE))
+ else:
+ # NoPE
+ self.layers = nn.ModuleList([ GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ])
+ self.layers.append(GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE))
+
+ self.MLP_layer = MLPReadout(GT_out_dim, n_classes)
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ h = self.embedding_h(h)
+ e = self.embedding_e(e)
+
+ if self.pe_init in ['rand_walk']:
+ p = self.embedding_p(p)
+
+ for conv in self.layers:
+ h, p = conv(g, h, p, e, snorm_n)
+
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ p = self.p_out(p)
+ g.ndata['p'] = p
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms+1e-6)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ return self.MLP_layer(hg), g
+
+ def loss(self, pred, labels):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels)
+
+ if self.use_lapeig_loss:
+ raise NotImplementedError
+ else:
+ loss = loss_a
+
+ return loss
\ No newline at end of file
diff --git a/nets/OGBMOL_graph_classification/load_net.py b/nets/OGBMOL_graph_classification/load_net.py
new file mode 100644
index 0000000..9c8c3b5
--- /dev/null
+++ b/nets/OGBMOL_graph_classification/load_net.py
@@ -0,0 +1,31 @@
+"""
+    Utility file to select the GraphNN model
+    specified by the user
+"""
+
+from nets.OGBMOL_graph_classification.gatedgcn_net import GatedGCNNet
+from nets.OGBMOL_graph_classification.pna_net import PNANet
+from nets.OGBMOL_graph_classification.san_net import SANNet
+from nets.OGBMOL_graph_classification.graphit_net import GraphiTNet
+
+def GatedGCN(net_params):
+ return GatedGCNNet(net_params)
+
+def PNA(net_params):
+ return PNANet(net_params)
+
+def SAN(net_params):
+ return SANNet(net_params)
+
+def GraphiT(net_params):
+ return GraphiTNet(net_params)
+
+def gnn_model(MODEL_NAME, net_params):
+ models = {
+ 'GatedGCN': GatedGCN,
+ 'PNA': PNA,
+ 'SAN': SAN,
+ 'GraphiT': GraphiT
+ }
+
+ return models[MODEL_NAME](net_params)
\ No newline at end of file
diff --git a/nets/OGBMOL_graph_classification/pna_net.py b/nets/OGBMOL_graph_classification/pna_net.py
new file mode 100644
index 0000000..00bf82c
--- /dev/null
+++ b/nets/OGBMOL_graph_classification/pna_net.py
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+
+from scipy import sparse as sp
+from scipy.sparse.linalg import norm
+
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+
+"""
+ PNA-LSPE
+
+"""
+
+from layers.pna_layer import PNANoTowersLayer as PNALayer
+from layers.pna_lspe_layer import PNANoTowersLSPELayer as PNALSPELayer
+from layers.pna_utils import GRU  # used below when net_params['gru'] is enabled
+from layers.mlp_readout_layer import MLPReadout2 as MLPReadout
+
+
+class PNANet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+
+ hidden_dim = net_params['hidden_dim']
+ out_dim = net_params['out_dim']
+ n_classes = net_params['n_classes']
+ dropout = net_params['dropout']
+ self.dropout_2 = net_params['dropout_2']
+
+ n_layers = net_params['L']
+ self.readout = net_params['readout']
+ self.graph_norm = net_params['graph_norm']
+ self.batch_norm = net_params['batch_norm']
+ self.aggregators = net_params['aggregators']
+ self.scalers = net_params['scalers']
+ self.avg_d = net_params['avg_d']
+ self.residual = net_params['residual']
+ self.edge_feat = net_params['edge_feat']
+ edge_dim = net_params['edge_dim']
+ pretrans_layers = net_params['pretrans_layers']
+ posttrans_layers = net_params['posttrans_layers']
+ self.gru_enable = net_params['gru']
+ device = net_params['device']
+ self.device = device
+ self.pe_init = net_params['pe_init']
+ self.n_classes = net_params['n_classes']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim)
+
+ self.embedding_h = AtomEncoder(emb_dim=hidden_dim)
+
+ if self.edge_feat:
+ self.embedding_e = BondEncoder(emb_dim=edge_dim)
+
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList(
+ [PNALSPELayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, graph_norm=self.graph_norm,
+ batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, residual=self.residual,
+ edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss)
+ for _ in range(n_layers - 1)])
+ self.layers.append(PNALSPELayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, graph_norm=self.graph_norm,
+ batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers,
+ residual=self.residual, edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss))
+
+ else:
+ # NoPE
+ self.layers = nn.ModuleList(
+ [PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, graph_norm=self.graph_norm,
+ batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, residual=self.residual,
+ edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss)
+ for _ in range(n_layers - 1)])
+ self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, graph_norm=self.graph_norm,
+ batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers,
+ residual=self.residual, edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss))
+
+ if self.gru_enable:
+ self.gru = GRU(hidden_dim, hidden_dim, device)
+
+ self.MLP_layer = MLPReadout(out_dim, n_classes, self.dropout_2)
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ h = self.embedding_h(h)
+
+ if self.pe_init in ['rand_walk']:
+ p = self.embedding_p(p)
+
+ if self.edge_feat:
+ e = self.embedding_e(e)
+
+ for i, conv in enumerate(self.layers):
+ h_t, p_t = conv(g, h, p, e, snorm_n)
+ if self.gru_enable and i != len(self.layers) - 1:
+ h_t = self.gru(h, h_t)
+ h, p = h_t, p_t
+
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ p = F.dropout(p, self.dropout_2, training=self.training)
+ p = self.p_out(p)
+ g.ndata['p'] = p
+
+ if self.use_lapeig_loss:
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms+1e-6)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ if self.pe_init == 'rand_walk':
+ # Concat h and p
+ hp = torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)
+ hp = F.dropout(hp, self.dropout_2, training=self.training)
+ hp = self.Whp(hp)
+ g.ndata['h'] = hp
+
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ if self.n_classes == 128:
+ return_g = None # not passing PCBA graphs due to memory
+ else:
+ return_g = g
+
+ return self.MLP_layer(hg), return_g
+
+ def loss(self, pred, labels):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels)
+
+ if self.use_lapeig_loss:
+ # Loss B: Laplacian Eigenvector Loss --------------------------------------------
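+            # loss_b = ( trace(p^T L p) + lambda * ||P^T P - I||_F^2 ) / (pos_enc_dim * batch_size * n)
+            # i.e. PEs should be smooth w.r.t. the graph Laplacian (first term) and orthonormal within each graph (second term)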
+ g = self.g
+ n = g.number_of_nodes()
+
+ # Laplacian
+ A = g.adjacency_matrix(scipy_fmt="csr")
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ L = sp.eye(n) - N * A * N
+
+ p = g.ndata['p']
+ pT = torch.transpose(p, 1, 0)
+ loss_b_1 = torch.trace(torch.mm(torch.mm(pT, torch.Tensor(L.todense()).to(self.device)), p))
+
+ # Correct batch-graph wise loss_b_2 implementation; using a block diagonal matrix
+ bg = dgl.unbatch(g)
+ batch_size = len(bg)
+ P = sp.block_diag([bg[i].ndata['p'].detach().cpu() for i in range(batch_size)])
+ PTP_In = P.T * P - sp.eye(P.shape[1])
+ loss_b_2 = torch.tensor(norm(PTP_In, 'fro')**2).float().to(self.device)
+
+ loss_b = ( loss_b_1 + self.lambda_loss * loss_b_2 ) / ( self.pos_enc_dim * batch_size * n)
+
+ del bg, P, PTP_In, loss_b_1, loss_b_2
+
+ loss = loss_a + self.alpha_loss * loss_b
+ else:
+ loss = loss_a
+
+ return loss
\ No newline at end of file
diff --git a/nets/OGBMOL_graph_classification/san_net.py b/nets/OGBMOL_graph_classification/san_net.py
new file mode 100644
index 0000000..62f9791
--- /dev/null
+++ b/nets/OGBMOL_graph_classification/san_net.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+
+from scipy import sparse as sp
+from scipy.sparse.linalg import norm
+
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+
+"""
+ SAN-GT and SAN-GT-LSPE
+
+"""
+
+from layers.san_gt_layer import SAN_GT_Layer
+from layers.san_gt_lspe_layer import SAN_GT_LSPE_Layer
+from layers.mlp_readout_layer import MLPReadout
+
+
+class SANNet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+
+ full_graph = net_params['full_graph']
+ init_gamma = net_params['init_gamma']
+
+ # learn gamma
+ self.gamma = nn.Parameter(torch.FloatTensor([init_gamma]))
+
+ GT_layers = net_params['L']
+ GT_hidden_dim = net_params['hidden_dim']
+ GT_out_dim = net_params['out_dim']
+ GT_n_heads = net_params['n_heads']
+
+ self.residual = net_params['residual']
+ self.readout = net_params['readout']
+ in_feat_dropout = net_params['in_feat_dropout']
+ dropout = net_params['dropout']
+
+ self.layer_norm = net_params['layer_norm']
+ self.batch_norm = net_params['batch_norm']
+
+ n_classes = net_params['n_classes']
+ self.device = net_params['device']
+ self.in_feat_dropout = nn.Dropout(in_feat_dropout)
+ self.pe_init = net_params['pe_init']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim)
+
+ self.embedding_h = AtomEncoder(GT_hidden_dim)
+ self.embedding_e = BondEncoder(GT_hidden_dim)
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([ SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ])
+ self.layers.append(SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual))
+ else:
+ # NoPE
+ self.layers = nn.ModuleList([ SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ])
+ self.layers.append(SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual))
+
+ self.MLP_layer = MLPReadout(GT_out_dim, n_classes)
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ h = self.embedding_h(h)
+ e = self.embedding_e(e)
+
+ if self.pe_init in ['rand_walk']:
+ p = self.embedding_p(p)
+
+ for conv in self.layers:
+ h, p = conv(g, h, p, e, snorm_n)
+
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ p = self.p_out(p)
+ g.ndata['p'] = p
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms+1e-6)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ return self.MLP_layer(hg), g
+
+ def loss(self, pred, labels):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels)
+
+ if self.use_lapeig_loss:
+ raise NotImplementedError
+ else:
+ loss = loss_a
+
+ return loss
\ No newline at end of file
diff --git a/nets/ZINC_graph_regression/gatedgcn_net.py b/nets/ZINC_graph_regression/gatedgcn_net.py
new file mode 100644
index 0000000..d3150a1
--- /dev/null
+++ b/nets/ZINC_graph_regression/gatedgcn_net.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+
+from scipy import sparse as sp
+from scipy.sparse.linalg import norm
+
+"""
+ GatedGCN and GatedGCN-LSPE
+
+"""
+from layers.gatedgcn_layer import GatedGCNLayer
+from layers.gatedgcn_lspe_layer import GatedGCNLSPELayer
+from layers.mlp_readout_layer import MLPReadout
+
+class GatedGCNNet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+ num_atom_type = net_params['num_atom_type']
+ num_bond_type = net_params['num_bond_type']
+ hidden_dim = net_params['hidden_dim']
+ out_dim = net_params['out_dim']
+ in_feat_dropout = net_params['in_feat_dropout']
+ dropout = net_params['dropout']
+ self.n_layers = net_params['L']
+ self.readout = net_params['readout']
+ self.batch_norm = net_params['batch_norm']
+ self.residual = net_params['residual']
+ self.edge_feat = net_params['edge_feat']
+ self.device = net_params['device']
+ self.pe_init = net_params['pe_init']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk', 'lap_pe']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim)
+
+ self.embedding_h = nn.Embedding(num_atom_type, hidden_dim)
+
+ if self.edge_feat:
+ self.embedding_e = nn.Embedding(num_bond_type, hidden_dim)
+ else:
+ self.embedding_e = nn.Linear(1, hidden_dim)
+
+ self.in_feat_dropout = nn.Dropout(in_feat_dropout)
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([ GatedGCNLSPELayer(hidden_dim, hidden_dim, dropout,
+ self.batch_norm, residual=self.residual) for _ in range(self.n_layers-1) ])
+ self.layers.append(GatedGCNLSPELayer(hidden_dim, out_dim, dropout, self.batch_norm, residual=self.residual))
+ else:
+ # NoPE or LapPE
+ self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout,
+ self.batch_norm, residual=self.residual, graph_norm=False) for _ in range(self.n_layers-1) ])
+ self.layers.append(GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, residual=self.residual, graph_norm=False))
+
+ self.MLP_layer = MLPReadout(out_dim, 1) # 1 out dim since regression problem
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ # input embedding
+ h = self.embedding_h(h)
+ h = self.in_feat_dropout(h)
+
+ if self.pe_init in ['rand_walk', 'lap_pe']:
+ p = self.embedding_p(p)
+
+ if self.pe_init == 'lap_pe':
+ h = h + p
+ p = None
+
+ if not self.edge_feat: # edge feature set to 1
+ e = torch.ones(e.size(0),1).to(self.device)
+ e = self.embedding_e(e)
+
+
+ # convnets
+ for conv in self.layers:
+ h, p, e = conv(g, h, p, e, snorm_n)
+
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ p = self.p_out(p)
+ g.ndata['p'] = p
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ # readout
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ return self.MLP_layer(hg), g
+
+ def loss(self, scores, targets):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = nn.L1Loss()(scores, targets)
+
+ if self.use_lapeig_loss:
+ # Loss B: Laplacian Eigenvector Loss --------------------------------------------
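+            # loss_b = ( trace(p^T L p) + lambda * ||P^T P - I||_F^2 ) / (pos_enc_dim * batch_size * n)
+            # i.e. PEs should be smooth w.r.t. the graph Laplacian (first term) and orthonormal within each graph (second term)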
+ g = self.g
+ n = g.number_of_nodes()
+
+ # Laplacian
+ A = g.adjacency_matrix(scipy_fmt="csr")
+ N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
+ L = sp.eye(n) - N * A * N
+
+ p = g.ndata['p']
+ pT = torch.transpose(p, 1, 0)
+ loss_b_1 = torch.trace(torch.mm(torch.mm(pT, torch.Tensor(L.todense()).to(self.device)), p))
+
+ # Correct batch-graph wise loss_b_2 implementation; using a block diagonal matrix
+ bg = dgl.unbatch(g)
+ batch_size = len(bg)
+ P = sp.block_diag([bg[i].ndata['p'].detach().cpu() for i in range(batch_size)])
+ PTP_In = P.T * P - sp.eye(P.shape[1])
+ loss_b_2 = torch.tensor(norm(PTP_In, 'fro')**2).float().to(self.device)
+
+ loss_b = ( loss_b_1 + self.lambda_loss * loss_b_2 ) / ( self.pos_enc_dim * batch_size * n)
+
+ del bg, P, PTP_In, loss_b_1, loss_b_2
+
+ loss = loss_a + self.alpha_loss * loss_b
+ else:
+ loss = loss_a
+
+ return loss
+
+
+
+
+
\ No newline at end of file
diff --git a/nets/ZINC_graph_regression/graphit_net.py b/nets/ZINC_graph_regression/graphit_net.py
new file mode 100644
index 0000000..48b28f1
--- /dev/null
+++ b/nets/ZINC_graph_regression/graphit_net.py
@@ -0,0 +1,145 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+import numpy as np
+
+from scipy import sparse as sp
+
+"""
+ GraphiT-GT and GraphiT-GT-LSPE
+
+"""
+from layers.graphit_gt_layer import GraphiT_GT_Layer
+from layers.graphit_gt_lspe_layer import GraphiT_GT_LSPE_Layer
+from layers.mlp_readout_layer import MLPReadout
+
+class GraphiTNet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+
+ num_atom_type = net_params['num_atom_type']
+ num_bond_type = net_params['num_bond_type']
+
+ full_graph = net_params['full_graph']
+ gamma = net_params['gamma']
+ self.adaptive_edge_PE = net_params['adaptive_edge_PE']
+
+ GT_layers = net_params['L']
+ GT_hidden_dim = net_params['hidden_dim']
+ GT_out_dim = net_params['out_dim']
+ GT_n_heads = net_params['n_heads']
+
+ self.residual = net_params['residual']
+ self.readout = net_params['readout']
+ in_feat_dropout = net_params['in_feat_dropout']
+ dropout = net_params['dropout']
+
+ self.layer_norm = net_params['layer_norm']
+ self.batch_norm = net_params['batch_norm']
+
+ self.device = net_params['device']
+ self.in_feat_dropout = nn.Dropout(in_feat_dropout)
+ self.pe_init = net_params['pe_init']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim)
+
+ self.embedding_h = nn.Embedding(num_atom_type, GT_hidden_dim)
+ self.embedding_e = nn.Embedding(num_bond_type, GT_hidden_dim)
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([ GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, dropout,
+ self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ])
+ self.layers.append(GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, dropout,
+ self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE))
+ else:
+ # NoPE
+ self.layers = nn.ModuleList([ GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, dropout,
+ self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ])
+ self.layers.append(GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, dropout,
+ self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE))
+
+ self.MLP_layer = MLPReadout(GT_out_dim, 1) # 1 out dim since regression problem
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ # input embedding
+ h = self.embedding_h(h)
+ e = self.embedding_e(e)
+
+ h = self.in_feat_dropout(h)
+
+ if self.pe_init in ['rand_walk']:
+ p = self.embedding_p(p)
+
+ # GNN
+ for conv in self.layers:
+ h, p = conv(g, h, p, e, snorm_n)
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ p = self.p_out(p)
+ g.ndata['p'] = p
+
+ if self.use_lapeig_loss and self.pe_init == 'rand_walk':
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms+1e-6)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ if self.pe_init == 'rand_walk':
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ # readout
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ return self.MLP_layer(hg), g
+
+ def loss(self, scores, targets):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = nn.L1Loss()(scores, targets)
+
+ if self.use_lapeig_loss:
+ raise NotImplementedError
+ else:
+ loss = loss_a
+
+ return loss
\ No newline at end of file
diff --git a/nets/ZINC_graph_regression/load_net.py b/nets/ZINC_graph_regression/load_net.py
new file mode 100644
index 0000000..03dd401
--- /dev/null
+++ b/nets/ZINC_graph_regression/load_net.py
@@ -0,0 +1,31 @@
+"""
+    Utility file to select the GraphNN model
+    specified by the user
+"""
+
+from nets.ZINC_graph_regression.gatedgcn_net import GatedGCNNet
+from nets.ZINC_graph_regression.pna_net import PNANet
+from nets.ZINC_graph_regression.san_net import SANNet
+from nets.ZINC_graph_regression.graphit_net import GraphiTNet
+
+def GatedGCN(net_params):
+ return GatedGCNNet(net_params)
+
+def PNA(net_params):
+ return PNANet(net_params)
+
+def SAN(net_params):
+ return SANNet(net_params)
+
+def GraphiT(net_params):
+ return GraphiTNet(net_params)
+
+def gnn_model(MODEL_NAME, net_params):
+ models = {
+ 'GatedGCN': GatedGCN,
+ 'PNA': PNA,
+ 'SAN': SAN,
+ 'GraphiT': GraphiT
+ }
+
+ return models[MODEL_NAME](net_params)
\ No newline at end of file
diff --git a/nets/ZINC_graph_regression/pna_net.py b/nets/ZINC_graph_regression/pna_net.py
new file mode 100644
index 0000000..8ca7f2b
--- /dev/null
+++ b/nets/ZINC_graph_regression/pna_net.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import dgl
+from scipy import sparse as sp
+from scipy.sparse.linalg import norm
+
+
+"""
+ PNA and PNA-LSPE
+
+"""
+
+from layers.pna_layer import PNALayer
+from layers.pna_lspe_layer import PNALSPELayer
+from layers.pna_utils import GRU
+from layers.mlp_readout_layer import MLPReadout
+
+class PNANet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+ num_atom_type = net_params['num_atom_type']
+ num_bond_type = net_params['num_bond_type']
+ hidden_dim = net_params['hidden_dim']
+ out_dim = net_params['out_dim']
+ in_feat_dropout = net_params['in_feat_dropout']
+ dropout = net_params['dropout']
+ n_layers = net_params['L']
+ self.readout = net_params['readout']
+ self.graph_norm = net_params['graph_norm']
+ self.batch_norm = net_params['batch_norm']
+ self.residual = net_params['residual']
+ self.aggregators = net_params['aggregators']
+ self.scalers = net_params['scalers']
+ self.avg_d = net_params['avg_d']
+ self.towers = net_params['towers']
+ self.divide_input_first = net_params['divide_input_first']
+ self.divide_input_last = net_params['divide_input_last']
+ self.edge_feat = net_params['edge_feat']
+ edge_dim = net_params['edge_dim']
+ pretrans_layers = net_params['pretrans_layers']
+ posttrans_layers = net_params['posttrans_layers']
+ self.gru_enable = net_params['gru']
+ device = net_params['device']
+ self.device = device
+ self.pe_init = net_params['pe_init']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim)
+
+ self.in_feat_dropout = nn.Dropout(in_feat_dropout)
+
+ self.embedding_h = nn.Embedding(num_atom_type, hidden_dim)
+
+ if self.edge_feat:
+ self.embedding_e = nn.Embedding(num_bond_type, edge_dim)
+
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([PNALSPELayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout,
+ graph_norm=self.graph_norm, batch_norm=self.batch_norm,
+ residual=self.residual, aggregators=self.aggregators, scalers=self.scalers,
+ avg_d=self.avg_d, towers=self.towers, edge_features=self.edge_feat,
+ edge_dim=edge_dim, divide_input=self.divide_input_first,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers) for _
+ in range(n_layers - 1)])
+ self.layers.append(PNALSPELayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout,
+ graph_norm=self.graph_norm, batch_norm=self.batch_norm,
+ residual=self.residual, aggregators=self.aggregators, scalers=self.scalers,
+ avg_d=self.avg_d, towers=self.towers, divide_input=self.divide_input_last,
+ edge_features=self.edge_feat, edge_dim=edge_dim,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers))
+ else:
+ # NoPE
+ self.layers = nn.ModuleList([PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout,
+ graph_norm=self.graph_norm, batch_norm=self.batch_norm,
+ residual=self.residual, aggregators=self.aggregators, scalers=self.scalers,
+ avg_d=self.avg_d, towers=self.towers, edge_features=self.edge_feat,
+ edge_dim=edge_dim, divide_input=self.divide_input_first,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers) for _
+ in range(n_layers - 1)])
+ self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout,
+ graph_norm=self.graph_norm, batch_norm=self.batch_norm,
+ residual=self.residual, aggregators=self.aggregators, scalers=self.scalers,
+ avg_d=self.avg_d, towers=self.towers, divide_input=self.divide_input_last,
+ edge_features=self.edge_feat, edge_dim=edge_dim,
+ pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers))
+
+ if self.gru_enable:
+ self.gru = GRU(hidden_dim, hidden_dim, device)
+
+ self.MLP_layer = MLPReadout(out_dim, 1) # 1 out dim since regression problem
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+ def forward(self, g, h, p, e, snorm_n):
+ h = self.embedding_h(h)
+ h = self.in_feat_dropout(h)
+
+ if self.pe_init in ['rand_walk']:
+ p = self.embedding_p(p)
+
+ if self.edge_feat:
+ e = self.embedding_e(e)
+
+ for i, conv in enumerate(self.layers):
+ h_t, p_t = conv(g, h, p, e, snorm_n)
+ if self.gru_enable and i != len(self.layers) - 1:
+ h_t = self.gru(h, h_t)
+ h, p = h_t, p_t
+
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ p = self.p_out(p)
+ g.ndata['p'] = p
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ return self.MLP_layer(hg), g
+
+ def loss(self, scores, targets):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = nn.L1Loss()(scores, targets)
+
+ if self.use_lapeig_loss:
+ raise NotImplementedError
+ else:
+ loss = loss_a
+
+ return loss
\ No newline at end of file
diff --git a/nets/ZINC_graph_regression/san_net.py b/nets/ZINC_graph_regression/san_net.py
new file mode 100644
index 0000000..0951e49
--- /dev/null
+++ b/nets/ZINC_graph_regression/san_net.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl
+import numpy as np
+
+from scipy import sparse as sp
+
+"""
+ SAN-GT and SAN-GT-LSPE
+
+"""
+from layers.san_gt_layer import SAN_GT_Layer
+from layers.san_gt_lspe_layer import SAN_GT_LSPE_Layer
+from layers.mlp_readout_layer import MLPReadout
+
+class SANNet(nn.Module):
+ def __init__(self, net_params):
+ super().__init__()
+
+ num_atom_type = net_params['num_atom_type']
+ num_bond_type = net_params['num_bond_type']
+
+ full_graph = net_params['full_graph']
+ init_gamma = net_params['init_gamma']
+
+ # learn gamma
+ self.gamma = nn.Parameter(torch.FloatTensor([init_gamma]))
+
+ GT_layers = net_params['L']
+ GT_hidden_dim = net_params['hidden_dim']
+ GT_out_dim = net_params['out_dim']
+ GT_n_heads = net_params['n_heads']
+
+ self.residual = net_params['residual']
+ self.readout = net_params['readout']
+ in_feat_dropout = net_params['in_feat_dropout']
+ dropout = net_params['dropout']
+
+ self.layer_norm = net_params['layer_norm']
+ self.batch_norm = net_params['batch_norm']
+
+ self.device = net_params['device']
+ self.in_feat_dropout = nn.Dropout(in_feat_dropout)
+ self.pe_init = net_params['pe_init']
+
+ self.use_lapeig_loss = net_params['use_lapeig_loss']
+ self.lambda_loss = net_params['lambda_loss']
+ self.alpha_loss = net_params['alpha_loss']
+
+ self.pos_enc_dim = net_params['pos_enc_dim']
+
+ if self.pe_init in ['rand_walk']:
+ self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim)
+
+ self.embedding_h = nn.Embedding(num_atom_type, GT_hidden_dim)
+ self.embedding_e = nn.Embedding(num_bond_type, GT_hidden_dim)
+
+ if self.pe_init == 'rand_walk':
+ # LSPE
+ self.layers = nn.ModuleList([ SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ])
+ self.layers.append(SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual))
+ else:
+ # NoPE
+ self.layers = nn.ModuleList([ SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ])
+ self.layers.append(SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph,
+ dropout, self.layer_norm, self.batch_norm, self.residual))
+
+ self.MLP_layer = MLPReadout(GT_out_dim, 1) # 1 out dim since regression problem
+
+ if self.pe_init == 'rand_walk':
+ self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim)
+ self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim)
+
+ self.g = None # For util; To be accessed in loss() function
+
+
+ def forward(self, g, h, p, e, snorm_n):
+
+ # input embedding
+ h = self.embedding_h(h)
+ e = self.embedding_e(e)
+
+ h = self.in_feat_dropout(h)
+
+ if self.pe_init in ['rand_walk']:
+ p = self.embedding_p(p)
+
+ # GNN
+ for conv in self.layers:
+ h, p = conv(g, h, p, e, snorm_n)
+ g.ndata['h'] = h
+
+ if self.pe_init == 'rand_walk':
+ p = self.p_out(p)
+ g.ndata['p'] = p
+
+ if self.use_lapeig_loss and self.pe_init == 'rand_walk':
+ # Implementing p_g = p_g - torch.mean(p_g, dim=0)
+ means = dgl.mean_nodes(g, 'p')
+ batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p - batch_wise_p_means
+
+ # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0)
+ g.ndata['p'] = p
+ g.ndata['p2'] = g.ndata['p']**2
+ norms = dgl.sum_nodes(g, 'p2')
+ norms = torch.sqrt(norms+1e-6)
+ batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0)
+ p = p / batch_wise_p_l2_norms
+ g.ndata['p'] = p
+
+ if self.pe_init == 'rand_walk':
+ # Concat h and p
+ hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1))
+ g.ndata['h'] = hp
+
+ # readout
+ if self.readout == "sum":
+ hg = dgl.sum_nodes(g, 'h')
+ elif self.readout == "max":
+ hg = dgl.max_nodes(g, 'h')
+ elif self.readout == "mean":
+ hg = dgl.mean_nodes(g, 'h')
+ else:
+ hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
+
+ self.g = g # For util; To be accessed in loss() function
+
+ return self.MLP_layer(hg), g
+
+ def loss(self, scores, targets):
+
+ # Loss A: Task loss -------------------------------------------------------------
+ loss_a = nn.L1Loss()(scores, targets)
+
+ if self.use_lapeig_loss:
+ raise NotImplementedError
+ else:
+ loss = loss_a
+
+ return loss
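+
+
+# ---------------------------------------------------------------------------
+# Minimal numerical sketch (not used by the model): it checks that the
+# batch-wise centering / L2-normalisation of the positional features done in
+# forward() with dgl.mean_nodes / dgl.sum_nodes + repeat_interleave matches
+# the per-graph torch.mean / torch.norm operations it is meant to implement.
+# The two random graphs below are illustrative only.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+    import torch
+    import dgl
+
+    g1, g2 = dgl.rand_graph(5, 10), dgl.rand_graph(3, 6)
+    bg = dgl.batch([g1, g2])
+    p = torch.randn(bg.num_nodes(), 4)
+
+    # batched version, as in forward()
+    bg.ndata['p'] = p
+    means = dgl.mean_nodes(bg, 'p').repeat_interleave(bg.batch_num_nodes(), 0)
+    p_centered = p - means
+    bg.ndata['p2'] = p_centered ** 2
+    norms = torch.sqrt(dgl.sum_nodes(bg, 'p2') + 1e-6)
+    p_batched = p_centered / norms.repeat_interleave(bg.batch_num_nodes(), 0)
+
+    # per-graph reference
+    ref = []
+    for pg in (p[:5], p[5:]):
+        pg = pg - torch.mean(pg, dim=0)
+        ref.append(pg / torch.sqrt((pg ** 2).sum(dim=0) + 1e-6))
+    print(torch.allclose(p_batched, torch.cat(ref), atol=1e-5))  # expected: True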
\ No newline at end of file
diff --git a/scripts/OGBMOL/script_MOLPCBA_all.sh b/scripts/OGBMOL/script_MOLPCBA_all.sh
new file mode 100644
index 0000000..d9231f2
--- /dev/null
+++ b/scripts/OGBMOL/script_MOLPCBA_all.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+
+############
+# Usage
+############
+
+# bash script_MOLPCBA_all.sh
+
+
+####################################
+# MOLPCBA - 4 SEED RUNS OF EACH EXPERIMENT
+####################################
+
+seed0=41
+seed1=95
+seed2=12
+seed3=35
+code=main_OGBMOL_graph_classification.py
+dataset=OGBG-MOLPCBA
+tmux new -s gnn_lspe_PCBA -d
+tmux send-keys "source ~/.bashrc" C-m
+tmux send-keys "source activate gnn_lspe" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLPCBA_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLPCBA_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLPCBA_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLPCBA_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLPCBA_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLPCBA_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLPCBA_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLPCBA_LSPE.json' &
+wait" C-m
+tmux send-keys "tmux kill-session -t gnn_lspe_PCBA" C-m
+
+
+
+
+
+
+
+
+
+
+
diff --git a/scripts/OGBMOL/script_MOLTOX21_all.sh b/scripts/OGBMOL/script_MOLTOX21_all.sh
new file mode 100644
index 0000000..4aa698e
--- /dev/null
+++ b/scripts/OGBMOL/script_MOLTOX21_all.sh
@@ -0,0 +1,95 @@
+#!/bin/bash
+
+
+############
+# Usage
+############
+
+# bash script_MOLTOX21_all.sh
+
+
+####################################
+# MOLTOX21 - 4 SEED RUNS OF EACH EXPERIMENT
+####################################
+
+seed0=41
+seed1=95
+seed2=12
+seed3=35
+code=main_OGBMOL_graph_classification.py
+dataset=OGBG-MOLTOX21
+tmux new -s gnn_lspe_TOX21 -d
+tmux send-keys "source ~/.bashrc" C-m
+tmux send-keys "source activate gnn_lspe" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLTOX21_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLTOX21_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_MOLTOX21_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_MOLTOX21_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_MOLTOX21_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_MOLTOX21_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_MOLTOX21_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_MOLTOX21_LSPE.json' &
+wait" C-m
+tmux send-keys "tmux kill-session -t gnn_lspe_TOX21" C-m
+
+
+
+
+
+
+
+
+
+
+
diff --git a/scripts/StatisticalResults/generate_statistics_OGBMOL.ipynb b/scripts/StatisticalResults/generate_statistics_OGBMOL.ipynb
new file mode 100644
index 0000000..714efbf
--- /dev/null
+++ b/scripts/StatisticalResults/generate_statistics_OGBMOL.ipynb
@@ -0,0 +1,176 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_dir = \"../../out/PNA_MOLTOX21_NoPE/results/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file = []\n",
+ "dataset = []\n",
+ "model = []\n",
+ "layer = []\n",
+ "params = []\n",
+ "acc_test = []\n",
+ "acc_train = []\n",
+ "convergence = []\n",
+ "total_time = []\n",
+ "epoch_time = []\n",
+ "\n",
+ "for filename in os.listdir(f_dir):\n",
+ "\n",
+ " if filename[-4:] == \".txt\":\n",
+ " file.append( filename )\n",
+ " \n",
+ " with open(os.path.join(f_dir, filename), \"r\") as f:\n",
+ " lines = f.readlines()\n",
+ "\n",
+ " for line in lines:\n",
+ " # print('h1c',line)\n",
+ "\n",
+ " if line[:9] == \"Dataset: \":\n",
+ " dataset.append( line[9:-2] )\n",
+ "\n",
+ " if line[:7] == \"Model: \":\n",
+ " model.append( line[7:-1] )\n",
+ "\n",
+ " if line[:17] == \"net_params={'L': \":\n",
+ " layer.append( line[17:18] )\n",
+ " \n",
+ " if line[:56] == \"net_params={'full_graph': True, 'init_gamma': 0.1, 'L': \":\n",
+ " layer.append( line[56:58] )\n",
+ " \n",
+ " if line[:37] == \"net_params={'full_graph': True, 'L': \":\n",
+ " layer.append( line[37:39] )\n",
+ " \n",
+ " if line[:18] == \"Total Parameters: \":\n",
+ " params.append( line[18:-1] )\n",
+ "\n",
+ " if line[:10] == \"TEST AUC: \":\n",
+ " acc_test.append( float(line[10:-1]) )\n",
+ " \n",
+ " if line[:11] == \"TRAIN AUC: \":\n",
+ " acc_train.append( float(line[11:-1]) )\n",
+ " \n",
+ " if line[:35] == \" Convergence Time (Epochs): \":\n",
+ " convergence.append( float(line[35:-1]) )\n",
+ "\n",
+ " if line[:18] == \"Total Time Taken: \":\n",
+ " total_time.append( float(line[18:-4]) )\n",
+ "\n",
+ " if line[:24] == 'Average Time Per Epoch: ':\n",
+ " epoch_time.append( float(line[24:-2]) )\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ "# print('file',file)\n",
+ "# print('dataset',dataset)\n",
+ "# print('model',model)\n",
+ "# print('layer',layer)\n",
+ "# print('params',params)\n",
+ "# print('acc_test',acc_test)\n",
+ "# print('acc_train',acc_train)\n",
+ "# print('convergence',convergence)\n",
+ "# print('total_time',total_time)\n",
+ "# print('epoch_time',epoch_time)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "list_datasets = ['ogbg-moltox21']\n",
+ "#print('list_datasets',list_datasets)\n",
+ "\n",
+ "list_gnns = ['GatedGCN', 'PNA', 'SAN', 'GraphiT']\n",
+ "#print('list_gnns',list_gnns)\n",
+ "\n",
+ "\n",
+ " \n",
+ "for data in list_datasets:\n",
+ " #print(data)\n",
+ "\n",
+ " for gnn in list_gnns:\n",
+ " #print('gnn:',gnn)\n",
+ "\n",
+ " acc_test_one_gnn = []\n",
+ " acc_train_one_gnn = []\n",
+ " convergence_one_gnn = []\n",
+ " total_time_one_gnn = []\n",
+ " epoch_time_one_gnn = []\n",
+ " nb_seeds = 0\n",
+ "\n",
+ " for i in range(len(file)):\n",
+ " #print(params[i])\n",
+ " \n",
+ " if data==dataset[i] and gnn==model[i]:\n",
+ " params_one_gnn = params[i]\n",
+ " acc_test_one_gnn.append(acc_test[i])\n",
+ " acc_train_one_gnn.append(acc_train[i])\n",
+ " convergence_one_gnn.append(convergence[i])\n",
+ " total_time_one_gnn.append(total_time[i])\n",
+ " epoch_time_one_gnn.append(epoch_time[i])\n",
+ " L = layer[i]\n",
+ " nb_seeds = nb_seeds + 1\n",
+ "\n",
+ " #print(params_one_gnn)\n",
+ " if len(acc_test_one_gnn)>0:\n",
+ " print(acc_test_one_gnn)\n",
+ " latex_str = f\"{data} & {nb_seeds} & {gnn} & {L} & {params_one_gnn} & {np.mean(acc_test_one_gnn):.3f}$\\pm${np.std(acc_test_one_gnn):.3f} & {np.mean(acc_train_one_gnn):.3f}$\\pm${np.std(acc_train_one_gnn):.3f} & {np.mean(convergence_one_gnn):.2f} & {np.mean(epoch_time_one_gnn):.2f}s/{np.mean(total_time_one_gnn):.2f}hr\"\n",
+ " print(\"\\nDataset & #Seeds & Model & L & Param & Acc_test & Acc_train & Speed & Epoch/Time\\n{}\".format(latex_str,nb_seeds))\n",
+ "\n",
+ " \n",
+ "\n",
+ "print(\"\\n\")\n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/scripts/StatisticalResults/generate_statistics_ZINC.ipynb b/scripts/StatisticalResults/generate_statistics_ZINC.ipynb
new file mode 100644
index 0000000..01ff269
--- /dev/null
+++ b/scripts/StatisticalResults/generate_statistics_ZINC.ipynb
@@ -0,0 +1,174 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_dir = \"../../out/GraphiT_ZINC_LSPE_noLapEigLoss/results/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file = []\n",
+ "dataset = []\n",
+ "model = []\n",
+ "layer = []\n",
+ "params = []\n",
+ "acc_test = []\n",
+ "acc_train = []\n",
+ "convergence = []\n",
+ "total_time = []\n",
+ "epoch_time = []\n",
+ "\n",
+ "for filename in os.listdir(f_dir):\n",
+ "\n",
+ " if filename[-4:] == \".txt\":\n",
+ " file.append( filename )\n",
+ " \n",
+ " with open(os.path.join(f_dir, filename), \"r\") as f:\n",
+ " lines = f.readlines()\n",
+ "\n",
+ " for line in lines:\n",
+ " #print('h1c',line)\n",
+ "\n",
+ " if line[:9] == \"Dataset: \":\n",
+ " dataset.append( line[9:-2] )\n",
+ "\n",
+ " if line[:7] == \"Model: \":\n",
+ " model.append( line[7:-1] )\n",
+ "\n",
+ " if line[:17] == \"net_params={'L': \":\n",
+ " layer.append( line[17:18] )\n",
+ " \n",
+ " if line[:56] == \"net_params={'full_graph': True, 'init_gamma': 0.1, 'L': \":\n",
+ " layer.append( line[56:58])\n",
+ " \n",
+ " if line[:37] == \"net_params={'full_graph': True, 'L': \":\n",
+ " layer.append( line[37:39])\n",
+ " \n",
+ " if line[:18] == \"Total Parameters: \":\n",
+ " params.append( line[18:-1] )\n",
+ "\n",
+ " if line[:10] == \"TEST MAE: \":\n",
+ " acc_test.append( float(line[10:-1]) )\n",
+ " \n",
+ " if line[:11] == \"TRAIN MAE: \":\n",
+ " acc_train.append( float(line[11:-1]) )\n",
+ " \n",
+ " if line[4:31] == \"Convergence Time (Epochs): \":\n",
+ " convergence.append( float(line[31:-1]) )\n",
+ "\n",
+ " if line[:18] == \"Total Time Taken: \":\n",
+ " total_time.append( float(line[18:-4]) )\n",
+ "\n",
+ " if line[:24] == 'Average Time Per Epoch: ':\n",
+ " epoch_time.append( float(line[24:-2]) )\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ "# print('file',file)\n",
+ "# print('dataset',dataset)\n",
+ "# print('model',model)\n",
+ "# print('layer',layer)\n",
+ "# print('params',params)\n",
+ "# print('acc_test',acc_test)\n",
+ "# print('acc_train',acc_train)\n",
+ "# print('convergence',convergence)\n",
+ "# print('total_time',total_time)\n",
+ "# print('epoch_time',epoch_time)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "list_datasets = ['ZINC']\n",
+ "#print('list_datasets',list_datasets)\n",
+ "\n",
+ "list_gnns = ['GatedGCN', 'PNA', 'SAN', 'GraphiT']\n",
+ "#print('list_gnns',list_gnns)\n",
+ "\n",
+ "\n",
+ "for data in list_datasets:\n",
+ " #print(data)\n",
+ "\n",
+ " for gnn in list_gnns:\n",
+ " #print('gnn:',gnn)\n",
+ "\n",
+ " acc_test_one_gnn = []\n",
+ " acc_train_one_gnn = []\n",
+ " convergence_one_gnn = []\n",
+ " total_time_one_gnn = []\n",
+ " epoch_time_one_gnn = []\n",
+ " nb_seeds = 0\n",
+ "\n",
+ " for i in range(len(file)):\n",
+ " #print(params[i])\n",
+ " if data==dataset[i] and gnn==model[i]:\n",
+ " params_one_gnn = params[i]\n",
+ " acc_test_one_gnn.append(acc_test[i])\n",
+ " acc_train_one_gnn.append(acc_train[i])\n",
+ " convergence_one_gnn.append(convergence[i])\n",
+ " total_time_one_gnn.append(total_time[i])\n",
+ " epoch_time_one_gnn.append(epoch_time[i])\n",
+ " L = layer[i]\n",
+ " nb_seeds = nb_seeds + 1\n",
+ "\n",
+ " #print(params_one_gnn)\n",
+ " if len(acc_test_one_gnn)>0:\n",
+ " print(acc_test_one_gnn)\n",
+ " latex_str = f\"{data} & {nb_seeds} & {gnn} & {L} & {params_one_gnn} & {np.mean(acc_test_one_gnn):.3f}$\\pm${np.std(acc_test_one_gnn):.3f} & {np.mean(acc_train_one_gnn):.3f}$\\pm${np.std(acc_train_one_gnn):.3f} & {np.mean(convergence_one_gnn):.2f} & {np.mean(epoch_time_one_gnn):.2f}s/{np.mean(total_time_one_gnn):.2f}hr\"\n",
+ " print(\"\\nDataset & #Seeds & Model & L & Param & Acc_test & Acc_train & Speed & Epoch/Time\\n{}\".format(latex_str,nb_seeds))\n",
+ "\n",
+ " \n",
+ "\n",
+ "print(\"\\n\")\n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/scripts/TensorBoard/script_tensorboard.sh b/scripts/TensorBoard/script_tensorboard.sh
new file mode 100644
index 0000000..81ddd33
--- /dev/null
+++ b/scripts/TensorBoard/script_tensorboard.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+
+# bash script_tensorboard.sh
+
+
+
+
+
+tmux new -s tensorboard -d
+tmux send-keys "source activate gnn_lspe" C-m
+tmux send-keys "tensorboard --logdir out/ --port 6006" C-m
+
+
+
+
+
+
+
+
+
diff --git a/scripts/ZINC/script_ZINC_all.sh b/scripts/ZINC/script_ZINC_all.sh
new file mode 100644
index 0000000..30cb7bc
--- /dev/null
+++ b/scripts/ZINC/script_ZINC_all.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+
+
+############
+# Usage
+############
+
+# bash script_ZINC_all.sh
+
+
+####################################
+# ZINC - 4 SEED RUNS OF EACH EXPERIMENT
+####################################
+
+seed0=41
+seed1=95
+seed2=12
+seed3=35
+code=main_ZINC_graph_regression.py
+dataset=ZINC
+tmux new -s gnn_lspe_ZINC -d
+tmux send-keys "source ~/.bashrc" C-m
+tmux send-keys "source activate gnn_lspe" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_LapPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_LapPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_LapPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_LapPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_ZINC_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_ZINC_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_ZINC_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_ZINC_LSPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_ZINC_NoPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_ZINC_NoPE.json' &
+wait" C-m
+tmux send-keys "
+python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_ZINC_LSPE.json' &
+python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_ZINC_LSPE.json' &
+wait" C-m
+tmux send-keys "tmux kill-session -t gnn_lspe_ZINC" C-m
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/train/metrics.py b/train/metrics.py
new file mode 100644
index 0000000..b584da0
--- /dev/null
+++ b/train/metrics.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sklearn.metrics import confusion_matrix
+from sklearn.metrics import f1_score
+import numpy as np
+
+
+def MAE(scores, targets):
+ MAE = F.l1_loss(scores, targets)
+ MAE = MAE.detach().item()
+ return MAE
+
+
+def accuracy_TU(scores, targets):
+ scores = scores.detach().argmax(dim=1)
+ acc = (scores==targets).float().sum().item()
+ return acc
+
+
+def accuracy_MNIST_CIFAR(scores, targets):
+ scores = scores.detach().argmax(dim=1)
+ acc = (scores==targets).float().sum().item()
+ return acc
+
+def accuracy_CITATION_GRAPH(scores, targets):
+ scores = scores.detach().argmax(dim=1)
+ acc = (scores==targets).float().sum().item()
+ acc = acc / len(targets)
+ return acc
+
+
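+# Class-averaged accuracy (in %): mean of the per-class accuracies, with empty
+# classes counted as 0 (used for SBM-style node classification).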
+def accuracy_SBM(scores, targets):
+ S = targets.cpu().numpy()
+ C = np.argmax( torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy() , axis=1 )
+ CM = confusion_matrix(S,C).astype(np.float32)
+ nb_classes = CM.shape[0]
+ targets = targets.cpu().detach().numpy()
+ nb_non_empty_classes = 0
+ pr_classes = np.zeros(nb_classes)
+ for r in range(nb_classes):
+ cluster = np.where(targets==r)[0]
+ if cluster.shape[0] != 0:
+ pr_classes[r] = CM[r,r]/ float(cluster.shape[0])
+ if CM[r,r]>0:
+ nb_non_empty_classes += 1
+ else:
+ pr_classes[r] = 0.0
+ acc = 100.* np.sum(pr_classes)/ float(nb_classes)
+ return acc
+
+
+def binary_f1_score(scores, targets):
+ """Computes the F1 score using scikit-learn for binary class labels.
+
+ Returns the F1 score for the positive class, i.e. labelled '1'.
+ """
+ y_true = targets.cpu().numpy()
+ y_pred = scores.argmax(dim=1).cpu().numpy()
+ return f1_score(y_true, y_pred, average='binary')
+
+
+def accuracy_VOC(scores, targets):
+    scores = scores.detach().argmax(dim=1).cpu().numpy()
+    targets = targets.cpu().detach().numpy()
+    acc = f1_score(targets, scores, average='weighted')  # f1_score expects (y_true, y_pred)
+ return acc
diff --git a/train/train_OGBMOL_graph_classification.py b/train/train_OGBMOL_graph_classification.py
new file mode 100644
index 0000000..ba74bef
--- /dev/null
+++ b/train/train_OGBMOL_graph_classification.py
@@ -0,0 +1,131 @@
+"""
+ Utility functions for training one epoch
+ and evaluating one epoch
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+import math
+from tqdm import tqdm
+
+import dgl
+
+
+def train_epoch_sparse(model, optimizer, device, data_loader, epoch, evaluator):
+ model.train()
+
+ epoch_loss = 0
+ nb_data = 0
+
+ y_true = []
+ y_pred = []
+
+ for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(data_loader):
+ optimizer.zero_grad()
+
+ batch_graphs = batch_graphs.to(device)
+ batch_x = batch_graphs.ndata['feat'].to(device)
+ batch_e = batch_graphs.edata['feat'].to(device)
+ batch_labels = batch_labels.to(device)
+ batch_snorm_n = batch_snorm_n.to(device)
+
+ try:
+ batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
+ except KeyError:
+ batch_pos_enc = None
+
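+        # LapPE only: Laplacian eigenvectors are defined up to a sign, so their
+        # signs are randomly flipped at every training step (standard LapPE augmentation)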
+ if model.pe_init == 'lap_pe':
+ sign_flip = torch.rand(batch_pos_enc.size(1)).to(device)
+ sign_flip[sign_flip>=0.5] = 1.0; sign_flip[sign_flip<0.5] = -1.0
+ batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0)
+
+ batch_pred, __ = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n)
+ del __
+
+        # ignore nan labels (unlabeled) when computing the training loss;
+        # NaN != NaN, so this comparison masks out the unlabeled entries
+        is_labeled = batch_labels == batch_labels
+ loss = model.loss(batch_pred.to(torch.float32)[is_labeled], batch_labels.to(torch.float32)[is_labeled])
+
+ loss.backward()
+ optimizer.step()
+
+ y_true.append(batch_labels.view(batch_pred.shape).detach().cpu())
+ y_pred.append(batch_pred.detach().cpu())
+
+ epoch_loss += loss.detach().item()
+ nb_data += batch_labels.size(0)
+
+ epoch_loss /= (iter + 1)
+
+ y_true = torch.cat(y_true, dim = 0).numpy()
+ y_pred = torch.cat(y_pred, dim = 0).numpy()
+
+ # compute performance metric using OGB evaluator
+ input_dict = {"y_true": y_true, "y_pred": y_pred}
+ perf = evaluator.eval(input_dict)
+
+    if batch_labels.size(1) == 128:    # MOLPCBA (128 tasks, average precision)
+        return_perf = perf['ap']
+    elif batch_labels.size(1) == 12:   # MOLTOX21 (12 tasks, ROC-AUC)
+        return_perf = perf['rocauc']
+    else:
+        raise ValueError("Unexpected number of tasks: {}".format(batch_labels.size(1)))
+
+ return epoch_loss, return_perf, optimizer
+
+def evaluate_network_sparse(model, device, data_loader, epoch, evaluator):
+ model.eval()
+
+ epoch_loss = 0
+ nb_data = 0
+
+ y_true = []
+ y_pred = []
+
+ out_graphs_for_lapeig_viz = []
+
+ with torch.no_grad():
+ for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(data_loader):
+ batch_graphs = batch_graphs.to(device)
+ batch_x = batch_graphs.ndata['feat'].to(device)
+ batch_e = batch_graphs.edata['feat'].to(device)
+ batch_labels = batch_labels.to(device)
+ batch_snorm_n = batch_snorm_n.to(device)
+
+
+ try:
+ batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
+ except KeyError:
+ batch_pos_enc = None
+
+ batch_pred, batch_g = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n)
+
+
+            # ignore nan labels (unlabeled) when computing the loss;
+            # NaN != NaN, so this comparison masks out the unlabeled entries
+            is_labeled = batch_labels == batch_labels
+ loss = model.loss(batch_pred.to(torch.float32)[is_labeled], batch_labels.to(torch.float32)[is_labeled])
+
+ y_true.append(batch_labels.view(batch_pred.shape).detach().cpu())
+ y_pred.append(batch_pred.detach().cpu())
+
+ epoch_loss += loss.detach().item()
+ nb_data += batch_labels.size(0)
+
+ if batch_g is not None:
+ out_graphs_for_lapeig_viz += dgl.unbatch(batch_g)
+ else:
+ out_graphs_for_lapeig_viz = None
+
+ epoch_loss /= (iter + 1)
+
+ y_true = torch.cat(y_true, dim = 0).numpy()
+ y_pred = torch.cat(y_pred, dim = 0).numpy()
+
+ # compute performance metric using OGB evaluator
+ input_dict = {"y_true": y_true, "y_pred": y_pred}
+ perf = evaluator.eval(input_dict)
+
+    if batch_labels.size(1) == 128:    # MOLPCBA (128 tasks, average precision)
+        return_perf = perf['ap']
+    elif batch_labels.size(1) == 12:   # MOLTOX21 (12 tasks, ROC-AUC)
+        return_perf = perf['rocauc']
+    else:
+        raise ValueError("Unexpected number of tasks: {}".format(batch_labels.size(1)))
+
+ return epoch_loss, return_perf, out_graphs_for_lapeig_viz
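+
+
+# ---------------------------------------------------------------------------
+# Sketch of how the `evaluator` argument used above is expected to be built by
+# the caller (standard OGB API); the arrays below are dummy data.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+    from ogb.graphproppred import Evaluator
+
+    evaluator = Evaluator(name='ogbg-molpcba')   # reports 'ap' over 128 tasks
+    y_true = np.vstack([np.zeros((4, 128)), np.ones((4, 128))])
+    y_pred = np.random.rand(8, 128)
+    print(evaluator.eval({'y_true': y_true, 'y_pred': y_pred})['ap'])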
\ No newline at end of file
diff --git a/train/train_ZINC_graph_regression.py b/train/train_ZINC_graph_regression.py
new file mode 100644
index 0000000..40c522c
--- /dev/null
+++ b/train/train_ZINC_graph_regression.py
@@ -0,0 +1,84 @@
+"""
+ Utility functions for training one epoch
+ and evaluating one epoch
+"""
+import torch
+import torch.nn as nn
+import math
+
+import dgl
+
+from train.metrics import MAE
+
+
+
+def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
+ model.train()
+ epoch_loss = 0
+ epoch_train_mae = 0
+ nb_data = 0
+ gpu_mem = 0
+ for iter, (batch_graphs, batch_targets, batch_snorm_n) in enumerate(data_loader):
+ batch_graphs = batch_graphs.to(device)
+ batch_x = batch_graphs.ndata['feat'].to(device) # num x feat
+ batch_e = batch_graphs.edata['feat'].to(device)
+ batch_targets = batch_targets.to(device)
+ batch_snorm_n = batch_snorm_n.to(device)
+ optimizer.zero_grad()
+
+ try:
+ batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
+ except KeyError:
+ batch_pos_enc = None
+
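+        # LapPE only: Laplacian eigenvectors are defined up to a sign, so their
+        # signs are randomly flipped at every training step (standard LapPE augmentation)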
+ if model.pe_init == 'lap_pe':
+ sign_flip = torch.rand(batch_pos_enc.size(1)).to(device)
+ sign_flip[sign_flip>=0.5] = 1.0; sign_flip[sign_flip<0.5] = -1.0
+ batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0)
+
+ batch_scores, __ = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n)
+ del __
+
+ loss = model.loss(batch_scores, batch_targets)
+ loss.backward()
+ optimizer.step()
+ epoch_loss += loss.detach().item()
+ epoch_train_mae += MAE(batch_scores, batch_targets)
+ nb_data += batch_targets.size(0)
+ epoch_loss /= (iter + 1)
+ epoch_train_mae /= (iter + 1)
+
+ return epoch_loss, epoch_train_mae, optimizer
+
+def evaluate_network_sparse(model, device, data_loader, epoch):
+ model.eval()
+ epoch_test_loss = 0
+ epoch_test_mae = 0
+ nb_data = 0
+ out_graphs_for_lapeig_viz = []
+ with torch.no_grad():
+ for iter, (batch_graphs, batch_targets, batch_snorm_n) in enumerate(data_loader):
+ batch_graphs = batch_graphs.to(device)
+ batch_x = batch_graphs.ndata['feat'].to(device)
+ batch_e = batch_graphs.edata['feat'].to(device)
+ batch_targets = batch_targets.to(device)
+ batch_snorm_n = batch_snorm_n.to(device)
+
+ try:
+ batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
+ except KeyError:
+ batch_pos_enc = None
+
+ batch_scores, batch_g = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n)
+
+ loss = model.loss(batch_scores, batch_targets)
+ epoch_test_loss += loss.detach().item()
+ epoch_test_mae += MAE(batch_scores, batch_targets)
+ nb_data += batch_targets.size(0)
+
+ out_graphs_for_lapeig_viz += dgl.unbatch(batch_g)
+ epoch_test_loss /= (iter + 1)
+ epoch_test_mae /= (iter + 1)
+
+ return epoch_test_loss, epoch_test_mae, out_graphs_for_lapeig_viz
+
diff --git a/utils/cleaner_main.py b/utils/cleaner_main.py
new file mode 100644
index 0000000..af383c8
--- /dev/null
+++ b/utils/cleaner_main.py
@@ -0,0 +1,102 @@
+
+# Clean the main.py file after conversion from notebook.
+# Any notebook code is removed from the main.py file.
+
+
+import subprocess
+
+
+def cleaner_main(filename):
+
+ # file names
+ file_notebook = filename + '.ipynb'
+ file_python = filename + '.py'
+
+
+ # convert notebook to python file
+ print('Convert ' + file_notebook + ' to ' + file_python)
+ subprocess.check_output('jupyter nbconvert --to script ' + str(file_notebook) , shell=True)
+
+ print('Clean ' + file_python)
+
+ # open file
+ with open(file_python, "r") as f_in:
+ lines_in = f_in.readlines()
+
+ # remove cell indices
+ lines_in = [ line for i,line in enumerate(lines_in) if '# In[' not in line ]
+
+ # remove comments
+ lines_in = [ line for i,line in enumerate(lines_in) if line[0]!='#' ]
+
+ # remove "in_ipynb()" function
+ idx_start_fnc = next((i for i, x in enumerate(lines_in) if 'def in_ipynb' in x), None)
+ if idx_start_fnc!=None:
+ idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:4] not in ['\n',' ']), None)
+ lines_in = [ line for i,line in enumerate(lines_in) if i not in range(idx_start_fnc,idx_end_fnc+1) ]
+ list_elements_to_remove = ['in_ipynb()', 'print(notebook_mode)']
+ for elem in list_elements_to_remove:
+ lines_in = [ line for i,line in enumerate(lines_in) if elem not in line ]
+
+ # unindent "if notebook_mode==False" block
+ idx_start_fnc = next((i for i, x in enumerate(lines_in) if 'if notebook_mode==False' in x), None)
+ if idx_start_fnc!=None:
+ idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:8] not in ['\n',' ']), None)
+ for i in range(idx_start_fnc,idx_end_fnc+1):
+ lines_in[i] = lines_in[i][4:]
+ lines_in.pop(idx_start_fnc)
+ list_elements_to_remove = ['# notebook mode', '# terminal mode']
+ for elem in list_elements_to_remove:
+ lines_in = [ line for i,line in enumerate(lines_in) if elem not in line ]
+
+ # remove remaining "if notebook_mode==True" blocks - single indent
+ run = True
+ while run:
+ idx_start_fnc = next((i for i, x in enumerate(lines_in) if x[:16]=='if notebook_mode'), None)
+ if idx_start_fnc!=None:
+ idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:4] not in ['\n',' ']), None)
+ lines_in = [ line for i,line in enumerate(lines_in) if i not in range(idx_start_fnc,idx_end_fnc+1) ]
+ else:
+ run = False
+
+ # remove "if notebook_mode==True" block - double indents
+ idx_start_fnc = next((i for i, x in enumerate(lines_in) if x[:20]==' if notebook_mode'), None)
+ if idx_start_fnc!=None:
+ idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:8] not in ['\n',' ']), None)
+ lines_in = [ line for i,line in enumerate(lines_in) if i not in range(idx_start_fnc,idx_end_fnc+1) ]
+
+ # prepare main() for terminal mode
+ idx = next((i for i, x in enumerate(lines_in) if 'def main' in x), None)
+    if idx!=None: lines_in[idx] = 'def main():\n'
+ idx = next((i for i, x in enumerate(lines_in) if x[:5]=='else:'), None)
+ if idx!=None: lines_in.pop(idx)
+ idx = next((i for i, x in enumerate(lines_in) if x[:10]==' main()'), None)
+    if idx!=None: lines_in[idx] = 'main()\n'
+
+ # remove notebook variables
+ idx = next((i for i, x in enumerate(lines_in) if 'use_gpu = True' in x), None)
+ if idx!=None: lines_in.pop(idx)
+ idx = next((i for i, x in enumerate(lines_in) if 'gpu_id = -1' in x), None)
+ if idx!=None: lines_in.pop(idx)
+ idx = next((i for i, x in enumerate(lines_in) if 'device = None' in x), None)
+ if idx!=None: lines_in.pop(idx)
+ run = True
+ while run:
+ idx = next((i for i, x in enumerate(lines_in) if x[:10]=='MODEL_NAME'), None)
+ if idx!=None:
+ lines_in.pop(idx)
+ else:
+ run = False
+
+ # save clean file
+ lines_out = str()
+ for line in lines_in: lines_out += line
+ with open(file_python, 'w') as f_out:
+ f_out.write(lines_out)
+
+ print('Done. ')
+
+
+
+
+
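+# Example usage from the repo root (hypothetical notebook name; assumes a
+# notebook 'main_ZINC_graph_regression.ipynb' is present):
+#
+#   from utils.cleaner_main import cleaner_main
+#   cleaner_main('main_ZINC_graph_regression')   # converts the notebook and strips notebook-only code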
diff --git a/utils/plot_util.py b/utils/plot_util.py
new file mode 100644
index 0000000..0c20d65
--- /dev/null
+++ b/utils/plot_util.py
@@ -0,0 +1,46 @@
+"""
+ Util function to plot graph with eigenvectors
+ x-axis: first dim
+ y-axis: second dim
+"""
+
+import networkx as nx
+
+def plot_graph_eigvec(plt, g_id, g_dgl, feature_key, actual_eigvecs=False, predicted_eigvecs=False):
+
+ if actual_eigvecs:
+ plt.set_xlabel('first eigenvec')
+ plt.set_ylabel('second eigenvec')
+ else:
+ plt.set_xlabel('first predicted pe')
+ plt.set_ylabel('second predicted pe')
+
+ g_dgl = g_dgl.cpu()
+ g_dgl.ndata['feats'] = g_dgl.ndata[feature_key][:,:2]
+ g_nx = g_dgl.to_networkx(node_attrs=['feats'])
+
+ labels = {}
+ for idx, node in enumerate(g_nx.nodes()):
+ labels[node] = str(idx)
+
+ num_nodes = g_dgl.num_nodes()
+ num_edges = g_dgl.num_edges()
+
+ edge_list = []
+ srcs, dsts = g_dgl.edges()
+ for edge_i in range(num_edges):
+ edge_list.append((srcs[edge_i].item(), dsts[edge_i].item()))
+
+ # fig, ax = plt.subplots()
+ # first 2-dim of eigenvecs are x,y coordinates, and the 3rd dim of eigenvec is plotted as node intensity
+ # intensities = g_dgl.ndata['feats'][:,2]
+ nx.draw_networkx_nodes(g_nx, g_dgl.ndata['feats'][:,:2].numpy(), node_color='r', node_size=180, label=list(range(g_dgl.number_of_nodes())))
+ nx.draw_networkx_edges(g_nx, g_dgl.ndata['feats'][:,:2].numpy(), edge_list, alpha=0.3)
+ nx.draw_networkx_labels(g_nx, g_dgl.ndata['feats'][:,:2].numpy(), labels, font_size=16)
+ plt.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
+
+ title = "Graph ID: " + str(g_id)
+
+ title += " | Actual eigvecs" if actual_eigvecs else " | Predicted PEs"
+ plt.title.set_text(title)
+
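+
+
+# Minimal usage sketch (assumes `g` is a DGL graph whose `feature_key` node
+# feature, e.g. 'pos_enc' or 'p', has at least 2 dimensions):
+#
+#   import matplotlib.pyplot as plt
+#   fig, ax = plt.subplots()
+#   plot_graph_eigvec(ax, 0, g, 'p', predicted_eigvecs=True)
+#   plt.show()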
\ No newline at end of file
diff --git a/utils/visualize_RWPE_studies.ipynb b/utils/visualize_RWPE_studies.ipynb
new file mode 100644
index 0000000..63e30a6
--- /dev/null
+++ b/utils/visualize_RWPE_studies.ipynb
@@ -0,0 +1,197 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b941618a",
+ "metadata": {},
+ "source": [
+ "## ZINC Visualization with Init PE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "58a892b7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.chdir('../') # go to root folder of the project\n",
+ "print(os.getcwd())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ab99754",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import networkx as nx\n",
+ "from itertools import count\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from itertools import count"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2adb4608",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from data.data import LoadData\n",
+ "zinc_d = LoadData('ZINC')\n",
+ "\n",
+ "pos_enc_dim = 24\n",
+ "zinc_d._init_positional_encodings(pos_enc_dim, 'rand_walk')\n",
+ "zinc_d._add_eig_vecs(36)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6b12e503",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "num_nodes = []\n",
+ "num_unique_RWPEs = []\n",
+ "num_unique_LapPEs = []\n",
+ "\n",
+ "for g_ in zinc_d.val:\n",
+ " num_nodes.append(g_[0].number_of_nodes())\n",
+ " num_unique_RWPEs.append(len(torch.unique(g_[0].ndata['pos_enc'], dim=0)))\n",
+ "for g_ in zinc_d.val:\n",
+ " num_unique_LapPEs.append(len(torch.unique(g_[0].ndata['eigvec'], dim=0)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "edcf8c63",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot_init_PE_comparison(num_nodes, num_unique_PEs, PE='RWPE'):\n",
+ " fig = plt.figure(dpi=100, figsize=(7, 6))\n",
+ " ax = plt.axes()\n",
+ " plt.xlabel(\"Number of nodes\", fontsize=10)\n",
+ " if PE == 'RWPE':\n",
+ " plt.title(\"ZINC val (1K): Comparison of no. of nodes v/s no. of unique RWPEs\", fontsize=15)\n",
+ " plt.ylabel(\"Number of unique RWPEs\", fontsize=10)\n",
+ " elif PE == 'LapPE':\n",
+ " plt.title(\"ZINC val (1K): Comparison of no. of nodes v/s no. of unique LapPEs\", fontsize=15)\n",
+ " plt.ylabel(\"Number of unique LapPEs\", fontsize=10)\n",
+ " x = np.array(num_nodes)\n",
+ " y1 = np.array(num_unique_PEs)\n",
+ " #plt.xticks(x)\n",
+ " plt.hist2d(x, y1, (50, 50), cmap=plt.cm.Reds)# plt.cm.jet)\n",
+ " plt.colorbar()\n",
+ " plt.xlim([10, 35])\n",
+ " plt.ylim([10, 35])\n",
+ " #plt.scatter(x, y1, marker=\"o\", color=\"green\", linewidth=0.5)\n",
+ " x = np.linspace(9,40,100)\n",
+ " y = x\n",
+ " plt.plot(x, y, '-r', )\n",
+ "\n",
+ " #fig.savefig('out_ZINC_PE_viz/ZINC_valset_'+PE+'.pdf', bbox_inches='tight') \n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6f1de5dc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_init_PE_comparison(num_nodes, num_unique_RWPEs, 'RWPE')\n",
+ "plot_init_PE_comparison(num_nodes, num_unique_LapPEs, 'LapPE')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "607a71d7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graph_ids = [212,672] # not equal\n",
+ "graph_ids = [91,967] # equal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "960aa77f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for idx in graph_ids:\n",
+ " fig = plt.figure(dpi=100, figsize=(8, 6))\n",
+ " g_zinc_trial = zinc_d.val[idx][0]\n",
+ " g = g_zinc_trial.to_networkx(node_attrs=['pos_enc'])\n",
+ " #groups = set(nx.get_node_attributes(g,'pos_enc').values())\n",
+ " groups = torch.unique(g_zinc_trial.ndata['pos_enc'],dim=0)\n",
+ " mapping = dict(zip(groups,count()))\n",
+ " nodes = g.nodes()\n",
+ " colors = []\n",
+ " for n in nodes:\n",
+ " for key in mapping.keys():\n",
+ " if torch.equal(key, g.nodes[n]['pos_enc']):\n",
+ " color = mapping[key]\n",
+ " colors.append(color)\n",
+ " \n",
+ " pos = nx.spring_layout(g)\n",
+ " ec = nx.draw_networkx_edges(g, pos, alpha=0.2)\n",
+ " nc = nx.draw_networkx_nodes(g, pos, nodelist=nodes, node_color=colors, \n",
+ " node_size=100, cmap=plt.cm.jet)\n",
+ " plt.colorbar(nc)\n",
+ " #plt.xlabel(\"ZINC Val id: \" +str(idx), fontsize=16)\n",
+ " plt.title(\"nodes: \"+ str(num_nodes[idx]) + \" | unique RWPEs: \"+str(num_unique_RWPEs[idx]), fontsize=16)\n",
+ " #fig.savefig('out_ZINC_PE_viz/ZINC_valset_graph_'+str(idx)+'.pdf', bbox_inches='tight') \n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "59af93f9",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0e552577",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}