diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..62e43e1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +#common +**/*.DS_Store +**/*.ipynb_checkpoints/ +**/__pycache__ +out/ + +#ZINC dataset +data/molecules/*.pkl +data/molecules/*.pickle +data/molecules/*.zip +data/molecules/zinc_full/*.pkl +data/molecules/zinc_full/*.pickle + + +#OGB +dataset/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..44754d3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Vijay Prakash Dwivedi, Anh Tuan Luu, Thomas Laurent, Yoshua Bengio and Xavier Bresson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..e288bfb --- /dev/null +++ b/README.md @@ -0,0 +1,53 @@ + + +# Graph Neural Networks with
Learnable Structural and Positional Representations + +
+ +Source code for the paper "**Graph Neural Networks with Learnable Structural and Positional Representations**" by Vijay Prakash Dwivedi, Anh Tuan Luu, Thomas Laurent, Yoshua Bengio and Xavier Bresson. + +We propose a novel GNN architecture in which the structural and positional representations are decoupled and learnt separately, so that the network captures both of these essential graph properties. The architecture, named **MPGNNs-LSPE** (MPGNNs with **L**earnable **S**tructural and **P**ositional **E**ncodings), is generic enough to be applied to any GNN model of interest that fits the popular 'message-passing framework', including Transformers. + +![MPGNNs-LSPE](./docs/gnn-lspe.png) + +
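At a glance, each LSPE layer carries two node states that are message-passed separately: the usual structural state `h`, whose update is conditioned on the positional state, and a learnable positional state `p` (initialized with RWPE or LapPE; see the `pe_init` field in the configs below). The following is a minimal PyTorch sketch of that decoupling only, with hypothetical names (`ToyLSPELayer`, `W_h`, `W_p`) and a dense adjacency for brevity; the actual layers in this repo are built on DGL and differ in detail.

```python
# Minimal sketch of the LSPE decoupling (NOT this repo's implementation):
# two node states per layer, h (structural) and p (positional), each with its
# own message-passing update; the h-update is conditioned on p via concatenation.
import torch
import torch.nn as nn

class ToyLSPELayer(nn.Module):
    def __init__(self, h_dim, p_dim):
        super().__init__()
        self.W_h = nn.Linear(h_dim + p_dim, h_dim)  # structural update sees [h || p]
        self.W_p = nn.Linear(p_dim, p_dim)          # positional update sees only p

    def forward(self, h, p, adj):
        # adj: (N, N) row-normalized adjacency; h: (N, h_dim); p: (N, p_dim)
        h = torch.relu(self.W_h(adj @ torch.cat([h, p], dim=-1)))
        p = torch.tanh(self.W_p(adj @ p))           # tanh keeps the positional state bounded
        return h, p

# toy usage on a random 4-node graph; in the repo, p would come from the
# random-walk PE computed in data/molecules.py (init_positional_encoding)
N, h_dim, p_dim = 4, 8, 3
adj = torch.rand(N, N).round()
adj = adj / adj.sum(dim=1, keepdim=True).clamp(min=1)
h, p = torch.randn(N, h_dim), torch.randn(N, p_dim)
h, p = ToyLSPELayer(h_dim, p_dim)(h, p, adj)
```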
+ + +## 1. Repo installation + +[Follow these instructions](./docs/01_repo_installation.md) to install the repo and set up the environment. + + +
+ +## 2. Download datasets + +[Proceed as follows](./docs/02_download_datasets.md) to download the benchmark datasets. + + +
+ +## 3. Reproducibility + +[Use this page](./docs/03_run_codes.md) to run the code and reproduce the published results. + + +
+ +## 4. Reference + +TODO +[arXiv paper](https://arxiv.org/pdf/2110.xxxxx.pdf) +``` +@article{dwivedi2021graph, + title={Graph Neural Networks with Learnable Structural and Positional Representations}, + author={Dwivedi, Vijay Prakash and Luu, Anh Tuan and Laurent, Thomas and Bengio, Yoshua and Bresson, Xavier}, + journal={arXiv preprint arXiv:2110.xxxxx}, + year={2021} +} +``` + + + +
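The config files under `configs/` (added below) share a common schema: a `gpu` block, a `params` block with training settings (seed, epochs, batch_size, init_lr, lr schedule, max_time), and a `net_params` block with model hyperparameters (L, hidden_dim, pe_init, pos_enc_dim, use_lapeig_loss, ...). A minimal sketch of reading one with the standard library, for orientation only (this is not the repo's own loading code):

```python
# Illustrative only: load one of the JSON configs below and pull out the
# three blocks it is organized into (gpu / params / net_params).
import json

with open("configs/GatedGCN_ZINC_LSPE.json") as f:
    cfg = json.load(f)

use_gpu    = cfg["gpu"]["use"]
params     = cfg["params"]        # seed, epochs, batch_size, init_lr, ...
net_params = cfg["net_params"]    # L, hidden_dim, pe_init, pos_enc_dim, ...

print(use_gpu, params["init_lr"], net_params["pe_init"], net_params["pos_enc_dim"])
```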


+ diff --git a/configs/GatedGCN_MOLPCBA_LSPE.json b/configs/GatedGCN_MOLPCBA_LSPE.json new file mode 100644 index 0000000..0e79994 --- /dev/null +++ b/configs/GatedGCN_MOLPCBA_LSPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "OGBG-MOLPCBA", + + "out_dir": "out/GatedGCN_MOLPCBA_LSPE_noLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-4, + "weight_decay": 0.0, + "print_epoch_interval": 1, + "max_time": 96 + }, + + "net_params": { + "L": 8, + "hidden_dim": 118, + "out_dim": 118, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.1, + "batch_norm": true, + "pos_enc_dim": 20, + "pe_init": "rand_walk", + "use_lapeig_loss": false, + "alpha_loss": 1e-3, + "lambda_loss": 100 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_MOLPCBA_LapPE.json b/configs/GatedGCN_MOLPCBA_LapPE.json new file mode 100644 index 0000000..70a9a0b --- /dev/null +++ b/configs/GatedGCN_MOLPCBA_LapPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "OGBG-MOLPCBA", + + "out_dir": "out/GatedGCN_MOLPCBA_LapPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-4, + "weight_decay": 0.0, + "print_epoch_interval": 1, + "max_time": 48 + }, + + "net_params": { + "L": 8, + "hidden_dim": 154, + "out_dim": 154, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.1, + "batch_norm": true, + "pos_enc_dim": 3, + "pe_init": "lap_pe", + "use_lapeig_loss": false, + "alpha_loss": 1e-3, + "lambda_loss": 100 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_MOLPCBA_NoPE.json b/configs/GatedGCN_MOLPCBA_NoPE.json new file mode 100644 index 0000000..141092e --- /dev/null +++ b/configs/GatedGCN_MOLPCBA_NoPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "OGBG-MOLPCBA", + + "out_dir": "out/GatedGCN_MOLPCBA_NoPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-4, + "weight_decay": 0.0, + "print_epoch_interval": 1, + "max_time": 48 + }, + + "net_params": { + "L": 8, + "hidden_dim": 154, + "out_dim": 154, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.1, + "batch_norm": true, + "pos_enc_dim": 3, + "pe_init": "no_pe", + "use_lapeig_loss": false, + "alpha_loss": 1e-3, + "lambda_loss": 100 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_MOLTOX21_LSPE.json b/configs/GatedGCN_MOLTOX21_LSPE.json new file mode 100644 index 0000000..456bc36 --- /dev/null +++ b/configs/GatedGCN_MOLTOX21_LSPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "OGBG-MOLTOX21", + + "out_dir": "out/GatedGCN_MOLTOX21_LSPE_noLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-5, + "weight_decay": 0.0, + "print_epoch_interval": 1, + "max_time": 48 + }, + + "net_params": { + "L": 8, + "hidden_dim": 118, + "out_dim": 118, + "residual": true, + "edge_feat": true, + "readout": "mean", + 
"in_feat_dropout": 0.0, + "dropout": 0.4, + "batch_norm": true, + "pos_enc_dim": 16, + "pe_init": "rand_walk", + "use_lapeig_loss": false, + "alpha_loss": 1e-3, + "lambda_loss": 100 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_MOLTOX21_LapPE.json b/configs/GatedGCN_MOLTOX21_LapPE.json new file mode 100644 index 0000000..4b06e81 --- /dev/null +++ b/configs/GatedGCN_MOLTOX21_LapPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "OGBG-MOLTOX21", + + "out_dir": "out/GatedGCN_MOLTOX21_LapPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-5, + "weight_decay": 0.0, + "print_epoch_interval": 1, + "max_time": 48 + }, + + "net_params": { + "L": 8, + "hidden_dim": 154, + "out_dim": 154, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.4, + "batch_norm": true, + "pos_enc_dim": 3, + "pe_init": "lap_pe", + "use_lapeig_loss": false, + "alpha_loss": 1e-3, + "lambda_loss": 100 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_MOLTOX21_NoPE.json b/configs/GatedGCN_MOLTOX21_NoPE.json new file mode 100644 index 0000000..e019c71 --- /dev/null +++ b/configs/GatedGCN_MOLTOX21_NoPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "OGBG-MOLTOX21", + + "out_dir": "out/GatedGCN_MOLTOX21_NoPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-5, + "weight_decay": 0.0, + "print_epoch_interval": 1, + "max_time": 48 + }, + + "net_params": { + "L": 8, + "hidden_dim": 154, + "out_dim": 154, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.4, + "batch_norm": true, + "pos_enc_dim": 3, + "pe_init": "no_pe", + "use_lapeig_loss": false, + "alpha_loss": 1e-3, + "lambda_loss": 100 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_ZINC_LSPE.json b/configs/GatedGCN_ZINC_LSPE.json new file mode 100644 index 0000000..4875bd6 --- /dev/null +++ b/configs/GatedGCN_ZINC_LSPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "ZINC", + + "out_dir": "out/GatedGCN_ZINC_LSPE_noLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 12 + }, + + "net_params": { + "L": 16, + "hidden_dim": 59, + "out_dim": 59, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "batch_norm": true, + "pos_enc_dim": 20, + "pe_init": "rand_walk", + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json b/configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json new file mode 100644 index 0000000..c5f6122 --- /dev/null +++ b/configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "ZINC", + + "out_dir": "out/GatedGCN_ZINC_LSPE_withLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 
1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 12 + }, + + "net_params": { + "L": 16, + "hidden_dim": 59, + "out_dim": 59, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "batch_norm": true, + "pos_enc_dim": 20, + "pe_init": "rand_walk", + "use_lapeig_loss": true, + "alpha_loss": 1, + "lambda_loss": 1e-1 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_ZINC_LapPE.json b/configs/GatedGCN_ZINC_LapPE.json new file mode 100644 index 0000000..4f82e16 --- /dev/null +++ b/configs/GatedGCN_ZINC_LapPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "ZINC", + + "out_dir": "out/GatedGCN_ZINC_LapPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 12 + }, + + "net_params": { + "L": 16, + "hidden_dim": 78, + "out_dim": 78, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "batch_norm": true, + "pe_init": "lap_pe", + "pos_enc_dim": 8, + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1 + } +} \ No newline at end of file diff --git a/configs/GatedGCN_ZINC_NoPE.json b/configs/GatedGCN_ZINC_NoPE.json new file mode 100644 index 0000000..cdd8cfc --- /dev/null +++ b/configs/GatedGCN_ZINC_NoPE.json @@ -0,0 +1,41 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GatedGCN", + "dataset": "ZINC", + + "out_dir": "out/GatedGCN_ZINC_NoPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 12 + }, + + "net_params": { + "L": 16, + "hidden_dim": 78, + "out_dim": 78, + "residual": true, + "edge_feat": true, + "readout": "mean", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "batch_norm": true, + "pe_init": "no_pe", + "pos_enc_dim": 16, + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1 + } +} \ No newline at end of file diff --git a/configs/GraphiT_MOLTOX21_LSPE.json b/configs/GraphiT_MOLTOX21_LSPE.json new file mode 100644 index 0000000..5fc8be5 --- /dev/null +++ b/configs/GraphiT_MOLTOX21_LSPE.json @@ -0,0 +1,47 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "GraphiT", + "dataset": "OGBG-MOLTOX21", + "out_dir":"out/GraphiT_MOLTOX21_LSPE_noLapEigLoss/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.0007, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "full_graph": true, + + "L": 10, + "hidden_dim": 64, + "out_dim": 64, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.5, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "rand_walk", + "pos_enc_dim": 16, + "adaptive_edge_PE": true, + "p_steps": 16, + "gamma": 0.25 + } +} \ No newline at end of file diff --git a/configs/GraphiT_MOLTOX21_NoPE.json b/configs/GraphiT_MOLTOX21_NoPE.json new file mode 100644 index 0000000..4bb2c58 --- /dev/null +++ b/configs/GraphiT_MOLTOX21_NoPE.json @@ -0,0 +1,47 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "GraphiT", 
+ "dataset": "OGBG-MOLTOX21", + "out_dir":"out/GraphiT_MOLTOX21_NoPE/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.0007, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "full_graph": true, + + "L": 10, + "hidden_dim": 88, + "out_dim": 88, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.5, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "no_pe", + "pos_enc_dim": 12, + "adaptive_edge_PE": true, + "p_steps": 16, + "gamma": 0.25 + } +} \ No newline at end of file diff --git a/configs/GraphiT_ZINC_LSPE.json b/configs/GraphiT_ZINC_LSPE.json new file mode 100644 index 0000000..cb0235d --- /dev/null +++ b/configs/GraphiT_ZINC_LSPE.json @@ -0,0 +1,49 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GraphiT", + "dataset": "ZINC", + + "out_dir": "out/GraphiT_ZINC_LSPE_noLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 32, + "init_lr": 0.0007, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 48 + }, + + "net_params": { + "full_graph": true, + + "L": 10, + "hidden_dim": 48, + "out_dim": 48, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "rand_walk", + "pos_enc_dim": 16, + "adaptive_edge_PE": true, + "p_steps": 16, + "gamma": 0.25 + } +} \ No newline at end of file diff --git a/configs/GraphiT_ZINC_NoPE.json b/configs/GraphiT_ZINC_NoPE.json new file mode 100644 index 0000000..744ab5f --- /dev/null +++ b/configs/GraphiT_ZINC_NoPE.json @@ -0,0 +1,49 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "GraphiT", + "dataset": "ZINC", + + "out_dir": "out/GraphiT_ZINC_NoPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 32, + "init_lr": 0.0003, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 48 + }, + + "net_params": { + "full_graph": true, + + "L": 10, + "hidden_dim": 64, + "out_dim": 64, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "no_pe", + "pos_enc_dim": 16, + "adaptive_edge_PE": true, + "p_steps": 16, + "gamma": 0.25 + } +} \ No newline at end of file diff --git a/configs/PNA_MOLPCBA_LSPE.json b/configs/PNA_MOLPCBA_LSPE.json new file mode 100644 index 0000000..0043783 --- /dev/null +++ b/configs/PNA_MOLPCBA_LSPE.json @@ -0,0 +1,47 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "PNA", + "dataset": "OGBG-MOLPCBA", + "out_dir":"out/PNA_MOLPCBA_LSPE_noLapEigLoss/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 512, + "init_lr": 0.0005, + "lr_reduce_factor": 0.8, + "lr_schedule_patience": 10, + "min_lr": 2e-5, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "L": 4, + "hidden_dim": 322, + "out_dim": 322, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.4, + 
"dropout_2": 0.1, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean sum max", + "scalers": "identity", + "gru": false, + "edge_dim": 16, + "pretrans_layers" : 1, + "posttrans_layers" : 1, + "use_lapeig_loss": false, + "alpha_loss": 1e-2, + "lambda_loss": 1e-2, + "pe_init": "rand_walk", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/PNA_MOLPCBA_NoPE.json b/configs/PNA_MOLPCBA_NoPE.json new file mode 100644 index 0000000..d58c8d2 --- /dev/null +++ b/configs/PNA_MOLPCBA_NoPE.json @@ -0,0 +1,47 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "PNA", + "dataset": "OGBG-MOLPCBA", + "out_dir":"out/PNA_MOLPCBA_NoPE/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 512, + "init_lr": 0.0005, + "lr_reduce_factor": 0.8, + "lr_schedule_patience": 4, + "min_lr": 2e-5, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "L": 4, + "hidden_dim": 510, + "out_dim": 510, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.2, + "dropout_2": 0.0, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean sum max", + "scalers": "identity", + "gru": false, + "edge_dim": 16, + "pretrans_layers" : 1, + "posttrans_layers" : 1, + "use_lapeig_loss": false, + "alpha_loss": 1e-2, + "lambda_loss": 1e-2, + "pe_init": "no_pe", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/PNA_MOLTOX21_LSPE.json b/configs/PNA_MOLTOX21_LSPE.json new file mode 100644 index 0000000..8a7e9f9 --- /dev/null +++ b/configs/PNA_MOLTOX21_LSPE.json @@ -0,0 +1,48 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "PNA", + "dataset": "OGBG-MOLTOX21", + "out_dir":"out/PNA_MOLTOX21_LSPE_noLapEigLoss/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.0005, + "lr_reduce_factor": 0.8, + "lr_schedule_patience": 10, + "min_lr": 2e-5, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "L": 8, + "hidden_dim": 140, + "out_dim": 140, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.4, + "dropout_2": 0.1, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean max min std", + "scalers": "identity amplification attenuation", + "gru": false, + "edge_dim": 50, + "pretrans_layers" : 1, + "posttrans_layers" : 1, + "lpe_variant": "native_lpe", + "use_lapeig_loss": false, + "alpha_loss": 1e-2, + "lambda_loss": 1e-2, + "pe_init": "rand_walk", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json b/configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json new file mode 100644 index 0000000..58ef016 --- /dev/null +++ b/configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json @@ -0,0 +1,48 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "PNA", + "dataset": "OGBG-MOLTOX21", + "out_dir":"out/PNA_MOLTOX21_LSPE_withLapEigLoss/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.0005, + "lr_reduce_factor": 0.8, + "lr_schedule_patience": 10, + "min_lr": 2e-5, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "L": 8, + "hidden_dim": 140, + "out_dim": 140, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.4, + "dropout_2": 0.1, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean max min std", + "scalers": "identity amplification 
attenuation", + "gru": false, + "edge_dim": 50, + "pretrans_layers" : 1, + "posttrans_layers" : 1, + "lpe_variant": "native_lpe", + "use_lapeig_loss": true, + "alpha_loss": 1e-1, + "lambda_loss": 100, + "pe_init": "rand_walk", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/PNA_MOLTOX21_NoPE.json b/configs/PNA_MOLTOX21_NoPE.json new file mode 100644 index 0000000..8619000 --- /dev/null +++ b/configs/PNA_MOLTOX21_NoPE.json @@ -0,0 +1,48 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "PNA", + "dataset": "OGBG-MOLTOX21", + "out_dir":"out/PNA_MOLTOX21_NoPE/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 256, + "init_lr": 0.0005, + "lr_reduce_factor": 0.8, + "lr_schedule_patience": 10, + "min_lr": 2e-5, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "L": 8, + "hidden_dim": 206, + "out_dim": 206, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.4, + "dropout_2": 0.1, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean max min std", + "scalers": "identity amplification attenuation", + "gru": false, + "edge_dim": 50, + "pretrans_layers" : 1, + "posttrans_layers" : 1, + "lpe_variant": "native_lpe", + "use_lapeig_loss": false, + "alpha_loss": 1e-2, + "lambda_loss": 1e-2, + "pe_init": "no_pe", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/PNA_ZINC_LSPE.json b/configs/PNA_ZINC_LSPE.json new file mode 100644 index 0000000..9a4a674 --- /dev/null +++ b/configs/PNA_ZINC_LSPE.json @@ -0,0 +1,51 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "PNA", + "dataset": "ZINC", + + "out_dir": "out/PNA_ZINC_LSPE_noLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 48 + }, + + "net_params": { + "L": 16, + "hidden_dim": 55, + "out_dim": 55, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean max min std", + "scalers": "identity amplification attenuation", + "towers": 5, + "divide_input_first": true, + "divide_input_last": true, + "gru": false, + "edge_dim": 40, + "pretrans_layers" : 1, + "posttrans_layers" : 1, + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "rand_walk", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/PNA_ZINC_NoPE.json b/configs/PNA_ZINC_NoPE.json new file mode 100644 index 0000000..67ef1ab --- /dev/null +++ b/configs/PNA_ZINC_NoPE.json @@ -0,0 +1,51 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "PNA", + "dataset": "ZINC", + + "out_dir": "out/PNA_ZINC_NoPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.001, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 3e-6, + "print_epoch_interval": 5, + "max_time": 48 + }, + + "net_params": { + "L": 16, + "hidden_dim": 70, + "out_dim": 70, + "residual": true, + "edge_feat": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "graph_norm": true, + "batch_norm": true, + "aggregators": "mean max min std", + "scalers": "identity amplification attenuation", + "towers": 5, + "divide_input_first": true, + "divide_input_last": true, + "gru": false, + "edge_dim": 40, + 
"pretrans_layers" : 1, + "posttrans_layers" : 1, + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "no_pe", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/SAN_MOLTOX21_LSPE.json b/configs/SAN_MOLTOX21_LSPE.json new file mode 100644 index 0000000..0c15005 --- /dev/null +++ b/configs/SAN_MOLTOX21_LSPE.json @@ -0,0 +1,45 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "SAN", + "dataset": "OGBG-MOLTOX21", + "out_dir":"out/SAN_MOLTOX21_LSPE_noLapEigLoss/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.0007, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "full_graph": true, + "init_gamma": 0.1, + + "L": 10, + "hidden_dim": 64, + "out_dim": 64, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.5, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "rand_walk", + "pos_enc_dim": 12 + } +} \ No newline at end of file diff --git a/configs/SAN_MOLTOX21_NoPE.json b/configs/SAN_MOLTOX21_NoPE.json new file mode 100644 index 0000000..9922214 --- /dev/null +++ b/configs/SAN_MOLTOX21_NoPE.json @@ -0,0 +1,45 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + "model": "SAN", + "dataset": "OGBG-MOLTOX21", + "out_dir":"out/SAN_MOLTOX21_NoPE/", + + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 128, + "init_lr": 0.0007, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 96 + }, + "net_params": { + "full_graph": true, + "init_gamma": 0.1, + + "L": 10, + "hidden_dim": 88, + "out_dim": 88, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.5, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "no_pe", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/SAN_ZINC_LSPE.json b/configs/SAN_ZINC_LSPE.json new file mode 100644 index 0000000..9e8dd59 --- /dev/null +++ b/configs/SAN_ZINC_LSPE.json @@ -0,0 +1,47 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "SAN", + "dataset": "ZINC", + + "out_dir": "out/SAN_ZINC_LSPE_noLapEigLoss/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 32, + "init_lr": 0.0007, + "lr_reduce_factor": 0.5, + "lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 48 + }, + + "net_params": { + "full_graph": true, + "init_gamma": 0.1, + + "L": 10, + "hidden_dim": 48, + "out_dim": 48, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "rand_walk", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/configs/SAN_ZINC_NoPE.json b/configs/SAN_ZINC_NoPE.json new file mode 100644 index 0000000..2dd8619 --- /dev/null +++ b/configs/SAN_ZINC_NoPE.json @@ -0,0 +1,47 @@ +{ + "gpu": { + "use": true, + "id": 0 + }, + + "model": "SAN", + "dataset": "ZINC", + + "out_dir": "out/SAN_ZINC_NoPE/", + + "params": { + "seed": 41, + "epochs": 1000, + "batch_size": 32, + "init_lr": 0.0003, + "lr_reduce_factor": 0.5, + 
"lr_schedule_patience": 25, + "min_lr": 1e-6, + "weight_decay": 0.0, + "print_epoch_interval": 5, + "max_time": 48 + }, + + "net_params": { + "full_graph": true, + "init_gamma": 0.1, + + "L": 10, + "hidden_dim": 64, + "out_dim": 64, + "n_heads": 8, + + "residual": true, + "readout": "sum", + "in_feat_dropout": 0.0, + "dropout": 0.0, + "layer_norm": false, + "batch_norm": true, + + "use_lapeig_loss": false, + "alpha_loss": 1e-4, + "lambda_loss": 1000, + "pe_init": "no_pe", + "pos_enc_dim": 16 + } +} \ No newline at end of file diff --git a/data/data.py b/data/data.py new file mode 100644 index 0000000..3871142 --- /dev/null +++ b/data/data.py @@ -0,0 +1,21 @@ +""" + File to load dataset based on user control from main file +""" + +from data.molecules import MoleculeDataset +from data.ogb_mol import OGBMOLDataset + +def LoadData(DATASET_NAME): + """ + This function is called in the main.py file + returns: + ; dataset object + """ + + # handling for (ZINC) molecule dataset + if DATASET_NAME == 'ZINC' or DATASET_NAME == 'ZINC-full': + return MoleculeDataset(DATASET_NAME) + + # handling for MOLPCBA and MOLTOX21 dataset + if DATASET_NAME in ['OGBG-MOLPCBA', 'OGBG-MOLTOX21']: + return OGBMOLDataset(DATASET_NAME) \ No newline at end of file diff --git a/data/molecules.py b/data/molecules.py new file mode 100644 index 0000000..a80eb40 --- /dev/null +++ b/data/molecules.py @@ -0,0 +1,331 @@ +import torch +import torch.nn.functional as F +import pickle +import torch.utils.data +import time +import os +import numpy as np + +import csv + +import dgl + +from scipy import sparse as sp +import numpy as np +import networkx as nx + +# The dataset pickle and index files are in ./data/molecules/ dir +# [.pickle and .index; for split 'train', 'val' and 'test'] + + + + +class MoleculeDGL(torch.utils.data.Dataset): + def __init__(self, data_dir, split, num_graphs=None): + self.data_dir = data_dir + self.split = split + self.num_graphs = num_graphs + + with open(data_dir + "/%s.pickle" % self.split,"rb") as f: + self.data = pickle.load(f) + + if self.num_graphs in [10000, 1000]: + # loading the sampled indices from file ./zinc_molecules/.index + with open(data_dir + "/%s.index" % self.split,"r") as f: + data_idx = [list(map(int, idx)) for idx in csv.reader(f)] + self.data = [ self.data[i] for i in data_idx[0] ] + + assert len(self.data)==num_graphs, "Sample num_graphs again; available idx: train/val/test => 10k/1k/1k" + + """ + data is a list of Molecule dict objects with following attributes + + molecule = data[idx] + ; molecule['num_atom'] : nb of atoms, an integer (N) + ; molecule['atom_type'] : tensor of size N, each element is an atom type, an integer between 0 and num_atom_type + ; molecule['bond_type'] : tensor of size N x N, each element is a bond type, an integer between 0 and num_bond_type + ; molecule['logP_SA_cycle_normalized'] : the chemical property to regress, a float variable + """ + + self.graph_lists = [] + self.graph_labels = [] + self.n_samples = len(self.data) + self._prepare() + + def _prepare(self): + print("preparing %d graphs for the %s set..." 
% (self.num_graphs, self.split.upper())) + + for molecule in self.data: + node_features = molecule['atom_type'].long() + + adj = molecule['bond_type'] + edge_list = (adj != 0).nonzero() # converting adj matrix to edge_list + + edge_idxs_in_adj = edge_list.split(1, dim=1) + edge_features = adj[edge_idxs_in_adj].reshape(-1).long() + + # Create the DGL Graph + g = dgl.DGLGraph() + g.add_nodes(molecule['num_atom']) + g.ndata['feat'] = node_features + + for src, dst in edge_list: + g.add_edges(src.item(), dst.item()) + g.edata['feat'] = edge_features + + self.graph_lists.append(g) + self.graph_labels.append(molecule['logP_SA_cycle_normalized']) + + def __len__(self): + """Return the number of graphs in the dataset.""" + return self.n_samples + + def __getitem__(self, idx): + """ + Get the idx^th sample. + Parameters + --------- + idx : int + The sample index. + Returns + ------- + (dgl.DGLGraph, int) + DGLGraph with node feature stored in `feat` field + And its label. + """ + return self.graph_lists[idx], self.graph_labels[idx] + + +class MoleculeDatasetDGL(torch.utils.data.Dataset): + def __init__(self, name='Zinc'): + t0 = time.time() + self.name = name + + self.num_atom_type = 28 # known meta-info about the zinc dataset; can be calculated as well + self.num_bond_type = 4 # known meta-info about the zinc dataset; can be calculated as well + + data_dir='./data/molecules' + + if self.name == 'ZINC-full': + data_dir='./data/molecules/zinc_full' + self.train = MoleculeDGL(data_dir, 'train', num_graphs=220011) + self.val = MoleculeDGL(data_dir, 'val', num_graphs=24445) + self.test = MoleculeDGL(data_dir, 'test', num_graphs=5000) + else: + self.train = MoleculeDGL(data_dir, 'train', num_graphs=10000) + self.val = MoleculeDGL(data_dir, 'val', num_graphs=1000) + self.test = MoleculeDGL(data_dir, 'test', num_graphs=1000) + print("Time taken: {:.4f}s".format(time.time()-t0)) + + + +def add_eig_vec(g, pos_enc_dim): + """ + Graph positional encoding v/ Laplacian eigenvectors + This func is for eigvec visualization, same code as positional_encoding() func, + but stores value in a diff key 'eigvec' + """ + + # Laplacian + A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + L = sp.eye(g.number_of_nodes()) - N * A * N + + # Eigenvectors with numpy + EigVal, EigVec = np.linalg.eig(L.toarray()) + idx = EigVal.argsort() # increasing order + EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) + g.ndata['eigvec'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float() + + # zero padding to the end if n < pos_enc_dim + n = g.number_of_nodes() + if n <= pos_enc_dim: + g.ndata['eigvec'] = F.pad(g.ndata['eigvec'], (0, pos_enc_dim - n + 1), value=float('0')) + + return g + + +def lap_positional_encoding(g, pos_enc_dim): + """ + Graph positional encoding v/ Laplacian eigenvectors + """ + + # Laplacian + A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + L = sp.eye(g.number_of_nodes()) - N * A * N + + # Eigenvectors with numpy + EigVal, EigVec = np.linalg.eig(L.toarray()) + idx = EigVal.argsort() # increasing order + EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) + g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float() + g.ndata['eigvec'] = g.ndata['pos_enc'] + + # # Eigenvectors with scipy + # EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR') + # EigVec = EigVec[:, EigVal.argsort()] # 
increasing order + # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float() + + return g + + +def init_positional_encoding(g, pos_enc_dim, type_init): + """ + Initializing positional encoding with RWPE + """ + + n = g.number_of_nodes() + + if type_init == 'rand_walk': + # Geometric diffusion features with Random Walk + A = g.adjacency_matrix(scipy_fmt="csr") + Dinv = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -1.0, dtype=float) # D^-1 + RW = A * Dinv + M = RW + + # Iterate + nb_pos_enc = pos_enc_dim + PE = [torch.from_numpy(M.diagonal()).float()] + M_power = M + for _ in range(nb_pos_enc-1): + M_power = M_power * M + PE.append(torch.from_numpy(M_power.diagonal()).float()) + PE = torch.stack(PE,dim=-1) + g.ndata['pos_enc'] = PE + + return g + + +def make_full_graph(g, adaptive_weighting=None): + + full_g = dgl.from_networkx(nx.complete_graph(g.number_of_nodes())) + + #Here we copy over the node feature data and laplace encodings + full_g.ndata['feat'] = g.ndata['feat'] + + try: + full_g.ndata['pos_enc'] = g.ndata['pos_enc'] + except: + pass + + try: + full_g.ndata['eigvec'] = g.ndata['eigvec'] + except: + pass + + #Populate edge features w/ 0s + full_g.edata['feat']=torch.zeros(full_g.number_of_edges(), dtype=torch.long) + full_g.edata['real']=torch.zeros(full_g.number_of_edges(), dtype=torch.long) + + #Copy real edge data over + full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['feat'] = g.edata['feat'] + full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['real'] = torch.ones(g.edata['feat'].shape[0], dtype=torch.long) + + + # This code section only apply for GraphiT -------------------------------------------- + if adaptive_weighting is not None: + p_steps, gamma = adaptive_weighting + + n = g.number_of_nodes() + A = g.adjacency_matrix(scipy_fmt="csr") + + # Adaptive weighting k_ij for each edge + if p_steps == "qtr_num_nodes": + p_steps = int(0.25*n) + elif p_steps == "half_num_nodes": + p_steps = int(0.5*n) + elif p_steps == "num_nodes": + p_steps = int(n) + elif p_steps == "twice_num_nodes": + p_steps = int(2*n) + + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + I = sp.eye(n) + L = I - N * A * N + + k_RW = I - gamma*L + k_RW_power = k_RW + for _ in range(p_steps - 1): + k_RW_power = k_RW_power.dot(k_RW) + + k_RW_power = torch.from_numpy(k_RW_power.toarray()) + + # Assigning edge features k_RW_eij for adaptive weighting during attention + full_edge_u, full_edge_v = full_g.edges() + num_edges = full_g.number_of_edges() + + k_RW_e_ij = [] + for edge in range(num_edges): + k_RW_e_ij.append(k_RW_power[full_edge_u[edge], full_edge_v[edge]]) + + full_g.edata['k_RW'] = torch.stack(k_RW_e_ij,dim=-1).unsqueeze(-1).float() + # -------------------------------------------------------------------------------------- + + return full_g + + +class MoleculeDataset(torch.utils.data.Dataset): + + def __init__(self, name): + """ + Loading ZINC datasets + """ + start = time.time() + print("[I] Loading dataset %s..." 
% (name)) + self.name = name + data_dir = 'data/molecules/' + with open(data_dir+name+'.pkl',"rb") as f: + f = pickle.load(f) + self.train = f[0] + self.val = f[1] + self.test = f[2] + self.num_atom_type = f[3] + self.num_bond_type = f[4] + print('train, test, val sizes :',len(self.train),len(self.test),len(self.val)) + print("[I] Finished loading.") + print("[I] Data load time: {:.4f}s".format(time.time()-start)) + + + # form a mini batch from a given list of samples = [(graph, label) pairs] + def collate(self, samples): + # The input samples is a list of pairs (graph, label). + graphs, labels = map(list, zip(*samples)) + labels = torch.tensor(np.array(labels)).unsqueeze(1) + tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] + tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] + snorm_n = torch.cat(tab_snorm_n).sqrt() + batched_graph = dgl.batch(graphs) + + return batched_graph, labels, snorm_n + + + def _add_lap_positional_encodings(self, pos_enc_dim): + + # Graph positional encoding v/ Laplacian eigenvectors + self.train.graph_lists = [lap_positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists] + self.val.graph_lists = [lap_positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists] + self.test.graph_lists = [lap_positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists] + + def _add_eig_vecs(self, pos_enc_dim): + + # This is used if we visualize the eigvecs + self.train.graph_lists = [add_eig_vec(g, pos_enc_dim) for g in self.train.graph_lists] + self.val.graph_lists = [add_eig_vec(g, pos_enc_dim) for g in self.val.graph_lists] + self.test.graph_lists = [add_eig_vec(g, pos_enc_dim) for g in self.test.graph_lists] + + def _init_positional_encodings(self, pos_enc_dim, type_init): + + # Initializing positional encoding randomly with l2-norm 1 + self.train.graph_lists = [init_positional_encoding(g, pos_enc_dim, type_init) for g in self.train.graph_lists] + self.val.graph_lists = [init_positional_encoding(g, pos_enc_dim, type_init) for g in self.val.graph_lists] + self.test.graph_lists = [init_positional_encoding(g, pos_enc_dim, type_init) for g in self.test.graph_lists] + + def _make_full_graph(self, adaptive_weighting=None): + self.train.graph_lists = [make_full_graph(g, adaptive_weighting) for g in self.train.graph_lists] + self.val.graph_lists = [make_full_graph(g, adaptive_weighting) for g in self.val.graph_lists] + self.test.graph_lists = [make_full_graph(g, adaptive_weighting) for g in self.test.graph_lists] + + + + diff --git a/data/molecules/test.index b/data/molecules/test.index new file mode 100644 index 0000000..428bc26 --- /dev/null +++ b/data/molecules/test.index @@ -0,0 +1 @@ 
+912,204,2253,2006,1828,1143,839,4467,712,4837,3456,260,244,767,1791,1905,4139,4931,217,4597,1628,4464,3436,1805,3679,4827,2278,53,1307,3462,2787,2276,1273,1763,2757,837,759,3112,792,2940,2817,4945,2166,355,3763,4392,1022,3100,645,4522,2401,2962,4729,1575,569,375,1866,2370,653,1907,827,3113,2277,3714,2988,1332,3032,2910,1716,2187,584,4990,1401,4375,2005,1338,3786,3108,2211,4562,1799,2656,458,1876,262,2584,3286,2193,542,1728,4646,2577,1741,4089,3241,3758,1170,2169,2020,4598,4415,2152,4788,3509,4780,3271,2965,1796,1133,4174,4042,744,385,898,1252,1310,3458,4885,520,3152,3126,4881,3834,4334,2059,4532,94,938,4398,2185,2786,913,2404,3561,1295,3716,26,2157,4100,1463,4158,871,2444,4988,1629,3063,1323,4418,4344,4,4906,2655,4002,159,916,2973,2519,1961,474,1973,4647,701,3981,566,4363,1030,1051,3893,4503,1352,2171,4322,4969,3466,1735,4417,1647,2553,3268,3059,3588,4239,3698,991,2030,1840,524,2769,172,4819,4537,1885,4820,1804,58,581,482,1875,552,257,2706,580,4211,1949,2281,3976,1755,1083,4677,4720,3872,1990,3874,3334,1559,772,794,3531,2902,3469,3367,3825,443,806,496,3298,2779,895,2036,1569,1558,4393,3675,1148,1503,3789,2046,617,3630,4508,802,414,4428,120,764,1936,1362,3329,3978,3943,1751,3285,480,1348,3104,17,3198,2172,3727,2336,3465,4552,3986,1268,1555,2430,1783,479,4744,4441,499,2569,468,410,4785,3905,4119,4350,1289,465,4160,656,1522,561,4874,556,1926,3307,982,4666,2016,4742,4870,325,671,3434,4781,4630,4282,2591,2136,1673,2573,1955,2175,3242,1072,2457,3745,2590,594,76,3754,4612,819,600,4404,1746,4144,1085,2859,563,2001,3027,2334,1292,3589,4450,2478,4333,64,4543,2452,848,1100,945,876,2231,2308,4954,1725,2808,1667,2162,4140,2057,416,756,2266,361,29,2732,1071,2145,3619,4519,3503,4594,79,616,1221,4469,295,3024,4771,4526,1213,3520,1044,342,2525,2987,326,2931,1720,2044,842,2897,4586,1266,1939,1331,1450,3377,203,1469,2721,3372,2032,1304,885,3133,317,3855,1822,1634,3770,2864,2500,1864,1826,193,1582,3264,2689,2282,568,2286,2876,4173,3274,2712,226,944,2139,1462,4756,2174,313,888,4887,3559,2831,3574,4966,4189,947,3155,4723,1557,2086,363,3572,13,4259,4410,1614,2983,3533,573,2704,2571,1020,2460,4154,2533,3345,2672,3296,2422,4541,1042,1571,3444,3105,1425,4662,2465,3326,4488,3,2489,2350,1721,3521,4751,2639,3809,3622,1750,4187,3876,1390,694,2324,4222,2745,765,1924,2542,1631,1207,200,378,3892,596,3730,3395,4715,1592,3145,4049,3273,1998,1208,45,873,3482,1792,1440,4243,3805,411,4566,2041,994,3739,1092,3806,4351,4578,4877,2599,3625,4135,3495,3652,1303,3888,3686,2123,2025,2271,4270,3969,1959,2249,3603,634,2340,1920,2225,2751,2619,4424,660,1235,1894,3137,1251,1752,526,3398,3339,2710,4445,3816,3406,510,1694,3441,3190,4784,160,4716,3116,3907,48,2881,2446,3194,3432,4409,4473,4941,1806,3999,1797,2235,3570,237,3185,2753,3312,3828,1045,220,3227,4848,4623,222,687,3511,1111,3782,1488,2131,2681,1733,3724,2677,2764,2279,3453,2066,670,3852,158,426,2866,1836,562,329,254,1633,166,1248,1954,1034,3879,937,4620,1785,2099,3021,1374,4963,4974,1341,2548,4740,210,2555,3074,3249,1624,622,4850,1989,834,2470,4918,4636,336,2844,4364,3035,564,2795,103,4015,864,3551,2967,3766,1253,3567,1442,4274,2212,4408,3960,3808,3568,4853,2198,2640,2011,709,2284,3692,1997,4668,4999,2755,235,2662,1489,3994,1737,2906,2116,2788,2290,4883,2263,4553,83,4232,1565,1977,4548,1968,3900,4020,3671,141,762,2410,1815,1993,2508,4764,3023,4534,4349,2816,3485,2709,2882,3717,2219,2511,1888,988,1577,979,4389,1516,1772,3966,2265,4829,4297,4888,2318,823,1590,2426,1863,2956,2476,115,1036,2247,372,446,4533,2393,4021,840,100,4702,2329,3845,3921,3608,2791,1510,420,2068,3913,934,535,3282
,4028,606,4726,439,1242,1222,4610,697,2033,970,4571,3409,1848,4280,3690,3626,2435,4821,3512,2501,4658,493,4994,812,1702,2167,665,1286,1964,1423,4521,614,1282,21,3346,4864,3849,2385,267,1896,2360,2316,3719,583,1912,4831,1620,940,4461,1841,1220,2176,1165,488,1359,2364,3597,1018,3839,2491,3297,2230,4099,4423,4045,3586,658,4899,3539,2051,211,748,4712,4809,169,2207,1435,3854,4251,1486,4795,747,3850,2850,2730,2630,856,1317,2701,4058,2361,3280,4506,300,3725,721,2576,2067,2648,949,3311,4215,9,4444,3784,3385,444,1536,4247,2963,4083,3621,422,4499,1073,2359,3970,236,4987,1960,1297,2545,4512,112,4524,3342,763,929,3780,962,1261,4082,2390,4168,2239,3403,3952,3868,1996,3741,4515,1184,3142,1561,4910,4163,1118,571,3399,2784,4159,2188,2317,2445,4808,4750,4011,1217,3658,4412,3967,2827,2723,4451,3090,2636,1545,1956,4684,1913,3365,357,2606,3123,3162,1246,4057,303,4114,4834,2719,822,3606,816,4308,3743,125,1180,3358,1264,612,3846,2773,3256,657,2691,4371,2594,3997,4432,294,560,1923,2354,740,3555,3634,2453,376,2657,459,2403,2936,3070,3528,1192,2000,3375,1475,1392,1434,646,4992,1972,4076,4777,1172,1902,3777,2080,3764,2091,3811,2356,4477,1294,605,3618,2830,4813,2450,3475,2048,3742,2474,3151,3958,1943,3124,4685,4708,2423,2418,179,2248 diff --git a/data/molecules/train.index b/data/molecules/train.index new file mode 100644 index 0000000..c4b804d --- /dev/null +++ b/data/molecules/train.index @@ -0,0 +1 @@ +167621,29184,6556,194393,72097,64196,58513,36579,193061,26868,177392,194161,142964,22790,154794,110604,8331,7811,24561,57314,60990,132475,157815,6956,147127,52124,187700,170363,183848,142853,109974,57787,117757,154472,72926,212187,1703,198916,211240,41853,183013,110785,89194,72842,40758,56443,200145,88236,26793,24312,99595,25353,94104,90165,158263,69342,211583,11390,191294,120435,140568,32722,99230,20656,144714,76854,217423,164794,162141,94800,151349,50407,184699,18233,12012,173346,59742,202655,75861,20916,61024,26476,99647,72869,118858,166640,218657,95638,42638,97040,93132,54921,175682,69986,183977,179187,169878,18717,159680,166455,44862,140021,191136,64175,42834,121178,99471,70765,167772,180397,146001,57570,179467,85008,201408,203423,14663,60043,215430,8414,211037,82694,105162,70186,17350,55307,148682,188196,82490,55738,171819,130870,103712,168519,120285,37452,69436,36603,64651,195294,147159,141289,68876,195825,153245,112311,152969,104700,94895,57493,36262,133569,129372,23831,198123,12351,28743,40066,164481,41938,207638,178384,110666,156345,16653,100864,100039,156208,122696,138704,65906,145024,3009,178332,188932,30029,178706,140763,196838,69946,201483,168024,89174,29242,76939,113971,41460,118940,850,189292,188659,69045,131225,199743,46832,133085,27894,163918,78235,167496,133080,159637,52143,40065,98019,199887,42349,141394,204112,139029,149,157009,84975,128085,5105,29325,95153,218016,211491,80612,62770,15184,63143,148729,20645,22453,191865,127399,213915,18143,199387,139645,200758,32967,33657,172949,124592,144127,43287,69483,138326,159014,110923,55521,141373,197988,191347,180844,52730,186895,81714,104593,176078,170361,97889,114845,135679,118354,31720,64986,58903,16784,88627,5514,154221,145207,60323,154256,57728,1885,18610,185556,165439,15433,60015,17668,8234,86619,18574,134782,62391,73001,175368,127248,56160,141356,34684,189622,149695,151050,123907,63700,205683,123987,211680,106708,49914,24726,25409,172748,112997,92876,111038,107767,122427,191122,14200,176518,171299,169392,25799,15889,105544,190896,88946,209870,28644,65183,50224,49862,140584,117601,36747,110593,48100,73018,121275,65485,19761,116164,211818,144264,256
66,13261,170955,141711,219159,3868,24448,197542,61965,43597,106539,127307,126185,56032,105130,15370,43158,99345,565,102346,69521,205539,205816,119277,74776,110888,182607,191497,205353,145691,173505,188326,127577,40579,49780,77780,57068,15331,151828,192869,142133,15979,196077,82209,14985,13144,153138,124987,131819,139231,41270,14910,133124,21000,48712,17962,155984,17815,177002,61657,105847,31427,149336,64543,151760,155849,10418,162367,21491,109897,172326,153006,148170,137044,82934,68358,53545,175564,187745,82361,62570,69629,103752,34308,176079,169214,78642,119858,82883,197096,19016,2441,120136,162833,147585,26209,19204,140937,55877,132614,69520,34722,91490,18033,64037,96869,74707,41352,114867,218561,142401,184428,79303,160347,211576,171435,138658,2050,175076,214198,145385,78480,173903,27154,35203,69328,30258,28058,194620,40749,71394,73860,158552,55215,188117,89884,53371,180223,166261,69201,132489,128065,65829,13316,24195,166273,111037,217408,72530,11557,929,87439,202144,34293,167015,68670,42357,194309,115824,144619,184986,112115,147038,2534,29327,19724,181146,39073,143023,9444,218784,96787,152701,144841,38821,112666,33409,10965,80808,95591,208698,10458,93797,55070,178799,65412,174832,26946,92714,204502,146770,106529,162702,196471,40515,62059,42598,209685,212538,46413,108080,6497,47018,193085,87080,205097,107928,210301,175612,192690,212533,65055,69941,41732,206405,183835,28336,100281,10151,123388,58309,52316,214063,120665,91660,80003,215098,208495,59662,58438,6203,173023,50627,104455,86051,73034,18198,202722,73170,92050,168160,133537,104773,178131,140565,86808,7235,30236,68475,46810,152198,69590,10028,28417,156387,113918,90619,190983,206157,82228,114398,158914,134066,30315,100976,151149,49828,66773,11635,185803,114309,443,136293,211421,141151,180055,188594,194497,193210,175801,51651,95478,113062,18343,174125,86559,163356,82291,32669,188679,78727,132939,81077,174821,107057,85506,105486,182769,77504,145335,33367,50289,110217,174307,99390,177554,196118,45620,161353,149187,78892,106450,143638,218556,106,79658,75212,55098,112692,205981,152039,159032,171627,84475,121893,115811,115909,177111,56020,134001,124042,208072,208673,192927,44483,172713,22228,74392,135122,174025,165921,162335,87867,24480,214544,196906,61569,176369,81374,58888,211435,52200,38627,6402,12114,64184,124554,160241,201454,19091,119384,108643,165089,150908,50970,188310,182545,100657,129598,104766,63959,38684,171981,180256,1453,196860,201862,27941,204058,111449,57367,46107,210792,182429,135779,121778,13164,146120,65325,31813,119658,34954,210086,121803,175001,139233,146518,156094,83177,197984,116017,160603,213649,188553,132324,111867,217728,143621,116893,41722,194944,124433,117983,67945,197073,64812,167159,72695,200753,203862,136655,127034,164298,62716,71984,115309,20311,187051,74901,61471,71228,88040,83809,141597,21122,36273,39539,60623,100410,181914,40056,185183,56086,16837,108755,106849,86738,142242,122139,108992,16322,54220,218337,110138,102099,201797,153112,182327,5120,200696,150913,99714,125037,1545,92211,78279,197518,102232,219081,109843,141091,195956,192579,143165,209680,158139,57812,127986,57520,71548,114251,127308,7608,101936,88114,175341,178032,209228,105989,189839,43265,122523,33456,163120,140017,7069,103290,155161,147951,173801,7104,22006,168492,112358,35572,121031,47639,13181,68198,99379,85813,55485,119196,85680,88473,199551,99385,72943,197134,218083,110510,66131,218866,21471,123288,5081,196352,141405,13653,91740,58778,170431,17987,204795,170853,10553,197717,8134,64823,52261,219998,5342,162879,39946,62533,33088,124141,17
5494,29986,147841,57138,121905,183360,67172,201037,96703,43984,158831,159186,196064,188313,30024,203893,214774,42930,81536,28337,151698,6731,81777,150941,177562,98384,103980,187436,51990,19922,155214,181040,217729,164427,63661,26712,182764,202501,79058,179374,157394,211164,31733,208726,148355,205163,10766,91017,139655,112296,173413,97142,18076,132634,169749,89451,3316,110115,215569,128503,27666,113645,94945,166614,217240,120517,185416,40105,114160,46173,192360,136772,170513,70800,161457,211864,141078,203067,126745,121863,114183,216461,191634,155313,70358,84490,64355,217771,22718,73119,118175,63927,196732,121820,149382,159994,175161,99349,88184,7523,129579,85200,47668,127808,55605,93015,209146,67725,89217,73310,156278,183811,72422,145697,2661,135430,50084,22442,63270,188763,106542,128077,145536,198748,62999,181039,124804,169319,186605,128665,117486,207862,4520,24393,77132,58090,106011,181347,63780,80270,174053,152451,96736,124062,145089,139177,90113,111543,195542,144281,86714,92225,184249,118946,71019,80379,65903,60434,31629,189080,50484,82719,31340,194741,140473,199803,180922,48535,50211,56723,193620,126929,72482,189945,154556,199283,137530,156445,74186,26352,218268,50886,77659,59633,94602,47039,79237,3708,185602,140020,33182,71909,11931,14293,145059,76581,182823,33103,167213,197339,128680,26892,3215,150487,74537,123049,125492,115466,89314,48329,13468,66185,125233,29907,215512,17129,105043,128908,19420,151262,165005,179949,14053,39774,39111,212636,147545,79648,22329,65061,31051,146295,200394,109097,158942,156271,207287,162114,59162,203343,136989,99713,118099,116056,77949,154278,112415,80053,149062,162798,15789,159811,194009,26013,199946,54470,163975,55318,69375,173127,21282,41171,62879,45564,144701,19677,41034,701,107090,118096,180714,155664,123184,76351,8556,60680,75524,185324,74112,184283,119021,18659,180193,61190,69351,206524,207420,163855,154609,173325,210739,51857,111447,30087,142753,58932,169773,39056,69633,216696,37286,18719,15633,43495,207782,80638,155987,196334,216069,149214,75657,115121,32598,122866,180532,79715,183430,105515,71367,131195,141552,129445,114755,21087,156771,10449,113253,192553,84494,158259,65632,6780,23940,60011,176763,219141,150784,153910,5438,200477,176234,215330,70650,151059,10546,200044,198251,45925,123338,136043,170789,115927,72917,47576,153440,114264,166405,213449,128902,23918,123210,91215,107046,87374,84163,175671,27420,42159,86456,107910,181842,129884,75554,173692,104976,213271,199360,144204,9619,119229,23084,82448,66164,84744,30388,202518,105952,134898,216249,301,172401,142237,121104,108330,14209,49173,135905,94838,163219,198297,130676,163947,115878,199226,13529,53361,70007,143974,34344,75505,114849,183042,127064,31831,7567,165156,159612,209535,62730,186065,41517,81461,144399,3584,144769,106952,24434,29743,120965,30793,169793,218141,40362,130646,187853,76509,133397,184919,71676,108918,218817,126484,123789,63892,119743,144510,37917,100554,49967,157123,133232,195636,35784,18300,72416,202552,207095,108774,89096,206495,133100,70037,215102,672,74144,190329,78264,219539,153862,152023,172999,128356,38953,117059,141185,126967,90472,87147,144681,199982,142456,98883,119365,84352,49454,182845,62601,149893,100393,61226,203304,107682,11441,83408,195219,184871,212705,99937,101208,173982,207721,215154,170922,39873,129847,9704,33094,131672,154712,87030,26309,115423,26138,137874,119780,4023,189384,37789,107473,171646,40464,19615,123074,204872,69474,88751,163383,181588,104197,170350,21054,86131,176778,139885,99617,83010,164311,188407,199072,127912,141824,9410,161863,17936,61
543,165455,179411,75334,59634,195760,23692,113763,25806,199326,166133,184520,26347,116307,43611,181928,78503,7588,12056,85032,208704,14711,76904,93969,98262,112901,38160,64014,139242,108016,148354,178729,207754,47200,44560,45893,20701,159774,162453,179079,63132,130460,152871,37517,60866,120887,167222,66578,120473,66932,174808,2463,210929,121963,75400,177631,143287,41412,19362,115796,90587,154028,78421,167493,111230,180961,65561,119756,79199,52223,100845,126670,27958,62182,99972,149926,94098,150683,77561,183311,77392,5751,217589,172550,103758,71953,2122,179778,204029,195210,12856,158965,195342,130214,218323,75024,203355,209416,60324,159138,210221,92358,57410,166885,49841,162762,65700,177671,198070,188987,201187,172800,178482,219765,35855,164691,25470,164484,169434,10334,80984,206729,115559,8746,151931,95646,191983,34448,23627,77361,85647,195947,108921,46042,52637,34644,206179,141402,95907,139159,131542,71441,217703,43134,67363,216187,126313,211416,77369,195706,88792,210824,30191,122770,19739,36898,197696,59143,177294,189849,176791,104179,210917,146096,95885,23672,207272,103435,3648,69326,140659,32398,119219,96622,176377,196348,176250,68724,153238,99883,215591,167391,97379,28402,176904,61295,123594,6560,162410,147169,85984,159929,58030,169778,16571,166567,215968,121704,183783,79217,170194,107032,30585,36641,11887,9754,79787,129138,30440,25478,61551,140914,35563,101881,118919,97258,175763,194809,182579,141608,109871,153967,194581,190473,40507,108759,171683,25957,218547,128279,161389,106985,73332,8576,180952,97132,56955,116230,116574,61895,95075,26045,179746,96294,142728,169037,94024,15872,104369,72320,49757,32023,216056,119201,24031,173740,55602,168218,167639,156538,5598,13258,206253,87427,63850,33010,206301,148000,53793,17985,217531,200581,145282,54305,153716,56610,213136,61079,86129,202996,38680,206738,156233,743,72685,37929,34076,141614,65707,209325,45743,28816,173292,6758,34551,3895,93911,207089,206682,62372,154364,84874,4136,45677,69564,13736,33228,194436,110352,137910,29792,195471,16661,124845,117512,203952,94906,134542,155625,28587,118490,132078,58077,161246,11367,190642,205318,172727,136695,79070,120073,168647,8165,15945,125562,105281,111745,179856,28301,128521,186751,116277,19265,21178,84439,159461,38884,17218,33080,72093,163661,165957,153449,143748,186686,85245,99851,156602,139082,77305,118938,132527,158709,112774,25999,207905,183967,29992,171629,170633,201578,144529,188963,56367,112740,118372,59898,108478,88848,216902,118882,104524,109049,191263,24926,81932,111873,81926,174354,66817,98120,40013,180046,124326,17598,23914,218044,22377,24439,113213,25313,195189,193670,97687,212800,34108,145849,15723,153738,147216,147241,86414,175639,32042,107691,92693,174415,196682,110870,189021,13486,157393,81906,92181,27158,151497,133015,55768,40561,172159,126403,58784,28368,91780,145821,96353,30116,199912,73024,150496,59286,211608,112489,147053,201177,214545,162981,160844,176941,168479,145945,6882,159644,172446,217438,181789,70109,7589,47294,71636,184208,199861,80998,89082,92018,1600,47554,37552,148457,172318,105063,18241,37191,194240,165982,8036,24053,195588,139062,56395,98617,110056,118920,89363,41263,97007,81693,189162,85035,203642,148791,156296,22270,13791,40784,41264,197770,161962,13044,176676,21386,71330,116159,173523,111152,127313,159141,115879,108575,71609,56510,197903,134308,29836,90484,112699,29069,74251,177790,177861,155486,127567,138132,174884,80850,11905,57807,103615,157109,14360,2015,53588,79015,55373,201156,35975,200320,66982,75876,86015,31453,2026,130386,112893,46057,33865,99669,1396
13,184463,60322,131140,146466,218458,175124,211340,92842,18894,104089,194521,11089,114352,4912,120521,20411,82066,150931,112534,150313,106017,185990,167809,109495,75894,30192,106176,5467,85136,45055,210061,162037,120611,218110,180777,94857,23102,114481,27751,63783,114209,154369,104975,137353,20616,103747,81336,195508,89001,58083,87311,204241,44074,20022,133820,165996,29891,139094,133680,50830,203483,91586,92038,190678,214682,169199,213657,38716,61936,26949,38389,67100,51713,45481,157915,40073,199285,199007,171837,19753,46437,202597,164613,129529,121622,197773,147779,199197,151904,117677,178555,147978,168556,166539,163714,84727,164421,82873,39572,115281,17900,122924,115922,165513,79382,208753,72004,155037,14725,92259,132995,19448,81371,121043,118466,9858,14913,96662,218024,75255,20114,168995,23670,161220,155734,132924,100789,121297,152129,145310,207570,193703,10753,117894,212288,149856,170875,49345,84286,158598,124683,131430,39547,16219,118101,27132,212590,219513,90030,187253,22106,132264,169378,45235,10260,64944,185498,115191,137381,137035,159856,41614,95393,97740,74158,101558,107156,203021,88700,177968,156732,13727,206891,165383,169687,87726,17278,86426,24774,146266,177833,101335,74487,66059,189722,172114,157994,39419,87388,21365,152761,173961,37074,91703,81325,171980,183102,173574,102749,33803,155986,185816,22208,81155,98744,168687,207529,86108,213171,33500,175666,184242,217306,193665,179677,137990,24498,169351,175756,110993,133297,94851,4779,95052,80967,47254,56149,89583,200843,127473,50333,59386,36061,40617,20228,77540,206600,26513,133087,202080,141497,218921,193631,138027,9895,173520,88280,200882,162013,34353,156564,98754,40436,42532,47407,217932,181625,202138,163771,212085,43395,189053,114733,11453,107704,95498,177304,188635,62267,116444,160063,74692,197234,196224,205468,117683,61342,139987,62696,81125,211975,205617,122949,218919,50860,96427,177814,149532,115499,121026,201703,73864,203980,100105,131833,138270,109749,42482,214143,52349,210045,158508,36280,65536,13678,168081,125983,97311,145369,26893,186489,135242,32683,74727,21987,200046,42024,71510,117786,134572,38626,217853,114680,24042,58205,214093,118260,91641,6991,108762,13960,103928,131598,98031,61807,101231,21392,98281,58861,7389,83541,25974,219985,187355,170383,87894,207455,38285,36070,10040,75203,217517,123851,182471,217870,36383,198925,184952,122974,117598,161309,1376,20768,4981,67088,56532,218880,39195,143865,190640,159614,138352,110953,29146,203468,62295,78944,31941,12517,62505,110054,167497,208115,163266,119829,16464,29060,219332,131027,156431,140523,4308,165663,135102,150745,63422,188362,37637,76355,112522,410,161168,92488,63062,149596,109200,49103,174160,175168,22443,137228,94530,17741,137903,142627,132992,206153,133070,145267,5330,102355,123243,11412,166637,101418,97858,66492,195904,4257,93610,206705,17710,90387,63209,192136,172255,164693,27174,202244,152542,192724,198358,87167,34972,11622,92353,143134,88748,213256,168505,45898,217680,204247,179485,121798,182292,125444,165605,47783,212729,35350,16541,187706,203473,119977,9695,76924,52845,11483,207384,52289,10985,82731,81284,135104,104392,213753,142352,124098,66404,9599,197470,169581,50096,75002,93578,204692,12531,171896,87012,71622,32630,209542,96474,114544,104863,194853,115274,101356,88876,48973,130076,181421,130422,96297,209005,136125,69919,210016,21656,190600,111280,20690,112881,157929,215571,47305,142992,77009,84203,26895,20989,85962,173256,77500,80346,116894,157973,188131,111703,43684,180817,116367,92155,117212,11103,190530,92421,161179,114017,71973,167633,207481,1
5036,19671,175946,167120,106485,95239,134502,210002,196581,178103,41939,8163,37428,159271,177756,204816,114864,9122,33090,17624,61864,204083,169125,96030,94935,100362,148699,8489,158621,40222,178012,117944,97246,97521,116376,200038,20155,150450,36095,138794,96193,104342,82414,170278,73056,65434,29739,6793,192852,48771,130821,135723,101446,147239,30870,68623,203152,68233,184559,116987,56248,160399,74862,181922,128765,52475,32148,35580,19412,118500,45260,187005,116683,23013,212471,178763,83778,175070,91075,186001,17023,144178,142126,76116,78631,41289,186596,186007,183558,167169,45636,208062,94763,133367,58783,31814,52685,207635,36405,62058,207191,129504,6890,94606,145233,150019,96707,122527,210675,144610,34018,160418,22600,17235,81078,104375,187939,188447,125479,137807,107687,201631,107330,215833,150706,19333,32853,83106,168385,19441,117958,122115,178286,135622,90312,33642,217331,204541,144520,167718,154084,47700,201269,33825,113408,131822,14526,217944,32548,135817,40112,79724,43163,42458,84583,185947,59105,90707,136025,74461,20672,65694,51459,166476,71945,32812,163845,79383,161068,139848,24491,131746,168031,44206,155139,152212,40421,44879,172682,163703,188909,158633,88504,147749,10798,216330,7437,21292,11917,168153,202118,151197,69377,170689,55262,201053,149982,109205,162033,167507,7939,130557,164407,143019,75935,168272,79192,126601,64193,210944,211641,179566,106454,77970,118893,19129,180515,15705,41439,115239,108952,126943,121750,53483,89179,37661,81941,188314,83706,192466,90536,104521,34279,199403,97134,134998,147233,27835,83654,63370,122282,32118,70116,117858,64990,36930,25378,13265,76077,100702,161317,109611,65082,41864,213248,85835,151437,189039,81976,49763,200047,41770,130617,134932,122372,130746,80882,130382,6075,23605,103086,132499,119860,63105,56398,152920,92509,12755,13248,73755,129733,156635,171510,176297,123361,74907,140690,2143,28179,112965,35110,69329,190711,95886,200240,105710,95928,11860,105010,13410,149495,147276,51027,95033,75655,19275,101279,132187,118057,200383,144141,73327,216350,163379,178362,160181,31145,33728,25344,103259,97831,208377,88889,146232,95841,197909,37818,52181,157932,133463,105221,131092,10552,11876,10201,35923,187063,87304,210754,124189,136139,119796,39043,158927,135136,36642,85974,160712,83500,103010,161611,193808,78442,155641,88178,132989,217131,133618,139549,128410,185791,147567,78588,124490,213843,4361,96562,86832,176634,28732,109193,153014,80655,208864,190139,180366,164956,6985,156476,124410,69607,171859,205233,203466,151643,151423,59649,189110,13462,152967,125946,44698,137424,164895,162460,203058,99664,38738,215393,178448,63490,8276,150095,183637,28808,50012,4964,115615,82232,109780,39696,108255,181010,53448,107577,131548,203154,160295,123628,192895,190684,16326,185014,36197,135969,54353,147033,85209,173598,125402,137750,98782,82248,45398,120498,139756,89738,143202,92874,177174,202372,188986,178453,168639,210672,181747,69300,159927,126924,50400,64623,73139,146299,78249,58947,78017,202158,75760,184785,54336,180942,184808,128189,83110,125777,91436,146928,208510,189179,71685,75442,31948,150333,177372,142394,99534,103424,214532,90474,202654,210925,38397,76140,11027,187207,20717,90797,115966,171959,67225,195910,125587,56095,52970,217042,141184,71075,147329,182497,71203,35985,28633,161396,194220,153834,62726,63558,13307,175592,139209,59158,167233,61036,13760,26266,108344,86584,188051,123837,26357,178570,202051,36036,1358,144271,41403,106668,171151,124701,125160,170534,52286,198157,75299,84192,74819,169318,15509,201516,23451,170996,150505,60875,140286,1936
42,189560,9777,45892,109546,219573,46081,9526,219205,104069,206559,129861,48840,196346,75882,9821,2429,78221,148887,158015,28128,87829,74596,119104,168116,142422,137446,129476,35179,132176,122688,71522,50544,212025,29529,86686,42596,191658,120301,169944,67440,188376,48810,3692,193099,88371,207294,77338,148840,176790,198517,50636,46004,160158,167438,106296,216790,112119,135077,85964,22737,105076,175441,25004,48439,36860,125161,84830,65024,1790,68363,100551,61740,117109,197592,69990,86409,79153,152688,189499,150176,2997,68564,171359,94254,181605,61945,16285,175056,30991,122000,80427,41868,106277,179850,131732,184433,201809,81521,180955,30779,167434,77342,96349,161120,57894,57443,35014,125352,40156,119269,196014,158790,97943,109011,184009,143978,123424,198534,140907,210238,174088,216427,57255,199716,64871,178347,197850,156189,206404,21461,137758,117112,138457,184488,94839,20474,147893,29414,16189,217258,143562,132510,52994,150130,140682,39282,43132,86014,136404,115825,30471,178335,53844,187935,152868,128143,23839,133791,116780,211993,14584,118864,34588,134544,108901,119746,147825,15135,146485,121209,176362,211528,80784,189603,5706,103795,66704,214028,790,57121,151600,19144,11872,111069,90316,183601,16714,141862,15795,18092,123750,8319,75259,107182,47181,201581,35501,200757,168171,191178,169313,110195,98162,100264,117581,98958,98473,21037,178973,173484,141374,34862,171320,91165,31098,46785,140839,103037,138632,33383,190803,58376,218843,910,198227,5959,78214,121373,176553,188506,142723,111182,139443,99425,215838,60205,64883,120750,90743,40664,72259,49409,190060,200668,29581,8445,212199,173066,109901,161198,200816,4101,63089,54023,17639,26449,155698,8801,117031,156705,176166,184362,12760,64155,194128,11596,105451,115071,61430,141494,197757,203440,14796,36735,132093,75828,61384,213853,191762,151023,83598,151381,156740,202516,176300,214534,84060,62102,79105,37554,173070,136642,57945,108412,78774,72129,145940,155101,192731,45910,164100,177930,111988,145712,129924,12290,90286,168641,175988,99917,205993,137484,83554,182533,109137,107022,39116,78613,98686,48202,197765,141014,124130,63168,59072,78836,185487,37840,211229,121379,15105,147426,108180,109185,146089,138934,35151,101821,63663,66845,53303,86576,169530,20727,117908,97275,24238,140413,49933,13510,70375,98872,176599,158589,158052,10367,19079,49325,210934,199674,154031,189872,176016,146499,56889,125670,54759,87284,79478,4019,55524,49813,194373,30753,195789,197743,125546,63594,182354,158440,53611,103974,62739,144724,84309,203296,74202,99788,122224,139893,170058,94177,80851,68566,94278,134431,130358,122160,25849,210453,189376,199656,83947,53248,97178,82030,108509,11984,147462,58005,194314,38252,4251,68360,145047,153100,151818,188746,109465,77382,39953,51430,86472,60298,99572,149366,218461,64416,130888,144347,171601,179528,88490,67456,200110,215375,128089,189676,168216,193880,128606,120803,44101,192091,208398,92320,44371,36689,188825,143273,128009,48247,142117,168246,15396,137348,8866,219956,19485,214432,12717,200366,1673,108106,36042,165334,60668,17817,185380,39591,2396,57336,132574,119540,97832,15846,161961,167176,174785,161616,126669,172827,127888,4123,1733,139516,107829,3106,4384,138783,189318,72021,140440,75153,4493,131652,213164,182939,176693,112872,211332,46993,28053,25248,137464,39032,50313,162471,138007,66169,213703,92914,70077,208387,104044,20714,97793,106437,120323,148111,63755,182776,59207,78636,179317,217071,21133,171372,170365,199120,8305,24510,106241,99421,144887,124781,14710,166919,2477,184246,44927,21675,131034,113746,169058,205619,86
840,148335,24998,138453,11154,60231,55421,181479,148579,124649,71065,12205,19489,179418,73498,147720,172519,8625,47016,82440,4104,153887,37798,197198,215139,186833,215194,104410,20168,78453,42585,147838,63038,148555,218798,102015,177480,141705,86960,100622,198078,193919,36687,207224,180698,189835,20559,131279,195416,90664,14076,25607,114580,60855,219853,19927,89284,158641,201590,160921,156135,103993,202782,85557,7766,166231,71651,205402,118109,128608,59805,93273,144975,98415,113206,48690,178177,153583,173191,99633,22464,202419,162166,77431,209205,64605,186937,19269,21505,70247,40428,99614,186561,205492,166173,40442,193896,102137,83015,94555,27931,23991,1361,80810,116584,94350,199582,70717,26799,34735,22884,49123,112978,117661,146010,145421,134333,106966,27077,6884,23493,92776,145285,24396,156480,156773,205635,85047,100987,3235,76506,108477,101659,204254,22203,190013,146679,63633,149935,136604,44455,179651,99806,44433,36396,70495,78939,70361,129269,38255,16585,43832,113947,72341,110389,78578,126938,205171,20164,94666,65981,64642,188979,163848,129971,155662,161656,51261,120025,28042,35517,79864,1774,103728,87088,162579,99585,210783,86387,115401,87749,112883,213772,215884,170665,155816,35684,78757,158128,182266,52652,125569,46566,104396,83660,76401,192642,182179,165911,128714,150963,204989,63550,85429,98580,73479,214039,206010,103005,95686,29855,147810,52376,155248,143033,47166,201226,144205,7055,190634,121011,185694,54958,114938,209108,76370,217915,181622,18129,214686,208576,208416,107254,176417,130889,36484,166025,79385,63444,66238,172664,40407,187494,110920,206851,98825,19389,117720,156744,125628,152631,105068,140081,132805,181327,109655,142513,9637,207282,94248,183929,212938,140746,155666,167295,22277,28670,201576,65072,172386,174245,93063,43604,169607,160609,11612,147644,169091,177595,169539,104757,197311,86926,208727,112841,27772,2765,67826,58196,133870,195783,135999,146274,152179,180267,150794,57942,116858,98073,102533,121389,202464,178141,154295,180993,131194,39516,90431,6438,126555,27406,77445,109045,22506,30426,217924,191404,37139,91312,81760,89932,119362,205591,54229,136607,126596,91304,124823,25591,114862,189266,182652,118606,83600,17609,78712,11713,212385,186326,30168,5920,89935,169861,28668,177508,209728,43084,194114,63940,135264,45731,144679,41978,86766,146761,111770,121124,60779,209126,106180,165087,48352,48860,167897,172309,113281,103699,7652,193301,160984,51727,117853,155242,112604,101930,1166,184742,56307,53945,72966,197217,184486,212003,208668,16400,151348,26673,209584,140742,49010,95890,85454,51661,119944,29872,68774,175820,128315,138337,167538,82157,156492,101879,160250,102903,153879,29945,91206,92270,161837,45227,213316,176577,185070,209875,77973,161083,154709,22217,176189,35113,82099,30896,62540,30657,47984,97859,181780,37056,133909,101833,109638,156349,36141,150681,100587,110897,48694,166592,140870,181779,169151,45412,145467,43742,75716,36777,49046,82465,118227,162374,94119,2512,127037,35791,51151,100648,147215,132497,171198,129753,107252,179489,108955,186159,116353,128109,43844,21671,148215,8089,208635,199506,57810,76569,8570,72051,58941,140930,75673,44094,204666,119884,148281,194646,200781,129843,144107,133904,29677,150133,29938,70225,203191,142372,215818,96170,142093,215642,198386,10907,199230,188476,115701,143223,57289,111451,26847,192935,217598,171356,196960,64888,78042,8449,68895,91377,214100,22802,115099,30772,203327,212831,62113,55018,212520,193360,154120,180652,91915,186319,159976,165877,112241,43511,161520,36499,205743,54164,213261,54248,211141,15560,149441,9
1964,139534,73684,156462,141234,44139,84843,184563,76712,75846,150139,70082,134850,178074,213694,24962,35543,196632,205964,107897,15541,72698,171939,33262,183874,34784,65451,38336,186487,84435,217181,64894,199615,177270,102944,128308,37269,151368,165780,70442,164577,108586,98349,118486,19796,166101,208299,201558,24404,105604,135164,196556,72935,182112,96749,119195,127937,85796,152556,648,202934,192079,24416,191609,120179,166419,174101,182790,93905,16601,208582,140096,104258,57102,112628,216228,55535,129686,70395,84637,217145,74342,88540,143474,151501,34234,148410,127228,206649,90109,178926,199820,12690,12043,26102,164149,205491,219336,120476,4268,32568,41414,115680,119673,177,112145,53036,180510,34580,170392,79132,41674,71832,24468,171596,94485,65638,21606,97476,174196,170108,43391,13575,103898,163998,80087,190313,183354,197979,61379,112520,171831,23907,184470,24951,284,55947,124992,20431,34786,155629,59294,136169,178253,116671,2434,2113,182346,90065,214200,31612,110707,181936,34648,125226,18567,60017,100419,23417,191565,211276,26850,27053,82143,96354,78447,35947,100328,202758,213219,197435,34903,168185,177507,37448,17933,139069,147663,2221,159991,170398,43022,115369,92061,189381,55851,164647,196154,39681,107927,161500,179979,115968,56949,22724,26336,36703,197249,32284,154553,191367,100671,92152,112561,82445,207376,36655,64955,72605,168983,22085,145069,157548,157288,189052,159276,74690,204202,181004,7031,172597,79075,53856,135724,159102,133568,49452,196042,102695,77474,169389,14304,208084,62741,129731,101157,29555,63198,130962,167269,155695,18615,138487,3163,94659,82961,34598,101407,219931,148997,110326,95790,143224,179029,45290,199944,123473,201947,19134,4375,153973,17703,3615,68730,56555,10440,15851,206907,104326,132804,75101,164776,186811,131346,201093,108903,183975,105454,21607,167913,140657,140943,161493,40024,72736,21914,81240,20715,134035,53267,211828,40731,140611,85655,102427,153730,166276,200074,218766,170351,178534,168599,17331,81309,183153,114684,191354,62557,15179,64088,113897,30405,118848,160509,159601,14425,174608,194430,172504,194159,46044,31311,3320,185571,35969,183893,2975,42947,130131,90886,137102,135586,205255,215615,188944,67960,43768,97483,33274,196549,204766,70454,192134,30947,202225,7695,87821,211578,112544,71298,137325,16743,68930,186731,151165,164503,20203,130062,119327,133898,94404,14533,130969,148605,43746,96387,41237,66877,202723,26975,150447,176716,189472,30234,59874,194191,133385,673,11544,3221,64107,12053,123670,95414,100395,39343,46913,9270,144930,210531,187176,171888,109935,58855,84358,65080,109119,188819,83869,71108,216530,20417,149824,97729,31300,131607,176508,13807,47014,58506,210852,135207,12039,105100,18068,121921,74005,202775,81447,85799,22645,144989,119733,2089,96709,52604,76439,147798,205653,79668,193875,163241,63591,121615,96540,155108,129331,202535,203393,51228,192010,143277,199355,208659,64258,39470,1671,107488,6267,61078,141962,90812,166708,181786,2515,87822,383,198948,172308,98638,193483,190635,27249,53790,138887,62434,110052,128982,15701,37322,186555,73354,24386,11508,60679,136251,108280,183722,97646,120182,196057,22256,152245,25054,132494,34999,167042,214948,103376,19616,148850,16327,114028,173655,34264,62068,76461,68417,82577,209929,183691,218077,192768,85524,83376,118687,71291,61239,19578,53139,35680,204455,201889,153277,28566,40739,27697,43089,117851,122173,82321,106564,31650,140489,93922,201939,54430,118467,80821,121463,69470,31434,23703,41317,211706,179845,79748,216632,184734,182682,158152,10572,56366,219132,85929,38882,23910,186425,65069
,93255,103873,134600,12832,177519,78339,68201,214981,45578,8147,106008,118541,145739,196104,144210,65497,24831,121121,26714,210736,208092,36645,31838,3134,16166,217677,204723,58271,34266,51939,215921,97580,179905,165228,169000,22030,152652,153647,67848,213056,19657,5714,16900,51839,169341,116242,212152,33761,24111,217410,212044,86913,32119,11192,120847,13274,44037,150440,113549,215399,189905,103546,130056,7728,100330,179100,111847,45300,92784,56286,49060,72084,73362,116919,38962,9083,162313,161736,64431,169680,77360,130781,108001,144359,125575,15985,22995,73341,100405,109899,168467,208980,137839,65312,165348,210283,142516,208364,5119,100269,204974,94139,125746,142742,204848,127055,90186,148282,131699,84306,101811,70893,47137,6832,83591,156530,57414,7681,212517,204089,73644,15359,211983,123649,138787,93698,201023,153089,61017,41844,26249,64974,172193,63216,69868,139903,213984,207418,193215,14656,200291,57517,150914,209031,101822,93009,219997,45537,46163,62010,155325,83362,211343,193489,182878,94052,155074,7394,184405,183647,92191,148806,147746,37251,147615,49344,219335,212441,129034,142114,81016,46021,128564,9951,23660,14724,60980,157722,57359,5224,138005,125301,179,86962,160991,52812,211404,34207,89327,187940,46520,210350,84826,15612,5744,38721,153776,184731,37133,202995,29077,137965,95653,19013,97934,184690,173575,103703,153620,26484,88227,84568,35769,41288,191981,114522,207576,167579,127716,170178,83173,45666,185383,147318,181812,160880,206305,93539,58520,174465,154389,213504,45967,99191,80431,192412,180485,77205,46585,191058,471,184027,150165,102655,202068,148762,8488,47731,157912,83418,212023,160269,57906,167679,148255,27291,130071,36868,86813,195387,20396,62389,90914,83759,44010,166772,23258,185851,175642,167536,191177,88014,117081,2987,69523,54967,65415,181711,17771,91876,67281,206312,28320,188852,335,12621,101171,114876,192042,110009,108532,129258,99277,92073,143136,98395,26588,217266,125695,206760,151533,170647,198768,180134,59135,42639,118302,19322,209504,8758,77308,5263,83549,68453,27453,19451,89602,44592,98424,42180,191885,19661,143404,24406,88591,155721,160863,126546,184681,8124,113068,171346,43253,159204,113823,41183,14039,26544,86794,53883,49803,107290,182765,145257,189277,191358,141590,69537,173171,73799,77961,62271,25143,13011,102996,149886,144090,128699,40262,14288,94589,1000,112094,22971,77214,215979,173544,163863,157342,125915,52343,25424,6746,54251,44782,167297,76496,21479,124143,30279,80832,205750,104193,123455,128213,173287,69469,23890,168261,142639,101587,48465,97260,99993,96836,48911,116857,11696,69021,115275,121912,69048,59308,219435,70401,147913,16158,39617,197193,182575,174972,25350,22467,175149,90122,211859,141937,108524,198300,154593,59641,145123,15986,101477,214540,136991,109741,140381,180110,124472,61772,124849,78342,20697,103495,187250,9103,131622,149741,135480,209833,149811,177611,162600,38652,31531,198492,117713,45319,43821,55545,211933,50521,33078,11187,110985,20543,179077,114135,52931,215630,164423,38765,218656,155988,67603,83820,190605,211640,17566,22813,101737,145847,104164,144717,85492,72404,136233,120103,3241,206189,183608,163330,154029,136986,204914,109994,29496,39227,38849,148529,154123,152745,198804,26267,27691,147504,204352,26474,205978,74997,90143,107712,69241,100412,171648,127293,149952,159635,124006,45141,72198,105637,37710,162069,159555,178471,181839,105308,10166,103754,86899,181014,62279,12921,195493,124575,70407,97633,5303,88569,214272,208051,79444,81513,73491,128875,182500,179529,25747,59757,35228,78769,194962,116158,202907,216484,84113,706
26,191050,109254,158591,169228,23406,49266,115793,55414,215998,106737,194234,126998,199096,135397,98018,215224,14734,135074,207484,41673,17201,81559,186009,132853,105744,35559,204607,137438,149074,45544,50531,218347,52619,219408,15106,65402,8669,119655,13253,94470,184715,52195,71821,96561,215877,121692,132415,104056,166568,32472,178565,8011,97639,128132,155724,116943,46552,124408,153847,144528,91687,90242,42864,219585,67806,196335,217807,180578,23695,74729,7195,13015,42705,150014,204427,56182,58025,169706,59011,177653,54193,70538,170809,107349,134553,5256,204551,201214,178834,2332,124281,34685,169704,45403,157585,209941,2218,58659,66162,159473,211048,80074,186840,185956,168490,72188,111098,98690,91412,57202,122342,79205,177522,136482,216088,163474,104111,153219,26820,1926,134090,173632,98102,147182,169696,155653,159658,74102,78243,179414,28338,126284,16734,88154,72478,167928,85232,72583,68458,171997,186408,170778,78326,49377,39303,135605,200664,62993,14883,211138,210334,212344,157295,105931,173154,83777,178456,36681,6672,171809,130383,77490,68007,109724,106063,101783,193129,9258,152836,185616,149957,213913,49785,89959,190460,58291,179914,140151,167330,205756,124711,176126,93977,78694,172389,146341,47483,77013,25586,123412,33539,194227,67804,190103,211902,146546,199994,47990,178182,87984,171556,23409,58262,93135,58544,195090,80500,110224,206593,86852,97434,68304,153503,77083,121572,30875,123348,14071,166303,152980,152125,162700,18668,50998,131197,31409,178440,100052,138893,78212,108293,14172,39525,36009,52506,89828,107267,149265,119949,37545,82938,187608,47279,204194,21184,127816,87196,165421,46353,82159,166330,15207,1235,118503,71703,54888,199542,202779,44198,151068,200404,41532,128339,205838,200165,204407,23079,33922,161082,113694,166064,112567,105747,112466,125147,100484,1223,9889,139927,52275,191537,97806,3690,84466,198403,218425,139207,49958,4696,177405,1081,164282,192309,65227,59314,181803,90269,81813,33798,26924,100545,131946,154316,80246,43783,17114,215965,78171,76238,119518,190796,136390,156334,137394,88561,113682,180035,35072,89845,128497,94175,196638,49394,43223,212863,105746,5013,59878,58804,190622,34394,55362,192433,154584,132286,44298,32286,96099,186413,170535,207651,9794,98354,167856,67053,202178,140177,163351,12736,155693,12572,173531,28341,171328,219183,5175,12644,180813,30353,109255,116474,99681,32399,143466,66253,124360,183990,40141,53869,182240,179762,165232,2439,78935,110328,173451,26471,173665,136395,70772,163262,158763,35956,109792,217750,27502,134108,188420,161925,30863,74109,29887,205720,27856,130479,52074,159646,52363,203937,68942,217972,137139,52867,93827,184401,107415,77424,42040,10329,143691,130116,54445,180408,126003,86665,61474,717,3342,176064,23331,28948,218951,150510,175380,129526,39097,23614,204348,134140,19673,189677,189681,26552,67080,168909,60597,119734,77600,69238,121707,13128,26164,213169,45797,10340,76170,94428,210133,177410,81974,111672,193603,30382,24903,198758,12029,2535,36279,173894,169659,86156,93723,114744,163002,70335,192169,24366,97680,213637,89842,47938,29840,105061,104923,119678,71259,219075,101113,180853,125369,209688,110556,170091,208905,43720,29637,34737,182496,190424,15373,41784,27187,109059,154861,126530,210308,149145,177836,114861,47516,160254,99549,204253,94004,160263,7554,179428,219817,190651,34250,126636,127958,29605,107815,115523,12051,16603,68580,171697,82375,3240,184698,177714,137577,201203,194936,189732,148114,58716,175595,88683,136078,137128,180418,183652,162067,24967,113468,173464,190487,65162,126334,90636,173987,177286,168140,
100893,38968,219947,216645,146366,160671,3052,168663,41137,132531,123021,129980,43359,19625,196982,128882,84052,63148,86742,13199,132888,58381,144309,99257,105716,61603,20664,120348,116441,150217,117466,23073,130667,117196,82853,31118,195359,171277,4658,26907,105759,108092,9393,146048,143942,44,216507,23827,207561,158388,162284,167126,79848,133257,147412,54224,171598,176022,119592,89352,95248,14227,58436,119133,87024,146840,201173,162465,157273,124029,187867,172272,95235,187672,118628,29890,25542,199051,172729,58129,575,89484,93097,217549,166110,209692,76541,143573,136835,98013,184050,27871,12000,43032,129085,36817,215440,187919,194644,189152,107272,28039,66660,160131,51359,49163,31652,102995,56515,119811,216151,204844,49959,187901,88587,28250,198482,108650,12189,166000,219536,176197,154353,32301,202413,118097,119901,172610,154485,132172,35294,130764,1137,139031,192994,12732,113458,219890,153385,126597,133256,204755,46452,181355,151664,45521,191494,213711,33654,27146,101556,172365,157429,158901,84369,132149,100837,109750,181538,199851,65074,71721,104299,89878,77115,118710,210972,34962,36223,108198,157948,183504,196195,134294,78358,183892,143604,82289,186613,193152,145118,165827,56975,52661,54974,163621,187271,76173,179991,90516,176258,34779,186226,170416,46612,132092,142210,182009,166953,84510,186214,30252,199432,92337,147398,125926,152779,149589,174891,186154,109439,176748,196049,139959,74954,110726,211892,3347,137030,27910,6624,99842,38096,13910,175703,14838,52403,196656,205922,42188,75134,179146,65683,38385,16261,181035,205614,53477,197550,141865,8468,94026,213392,118644,26864,161509,177085,193853,147294,58800,194619,169638,147133,101944,62284,133935,205601,208885,209406,182485,75402,181649,174193,12315,100029,101595,206598,179607,81960,144736,14103,3398,194948,65887,53238,162676,204714,116146,217575,59126,166143,184714,186003,169772,96401,197946,154241,143291,193580,158174,51503,203462,190031,50724,77116,117654,46511,178954,19521,49014,45559,192696,134701,31287,98880,10652,112330,73167,145908,69268,34096,42186,153088,66850,1479,86655,122245,186040,40581,10470,40292,84575,156468,13331,202810,163936,161378,173510,78945,128594,207069,150123,144221,93869,19105,205529,185514,83324,137729,57661,48200,136473,17872,206112,132419,41778,109033,142365,140780,105753,24102,91593,57985,55929,173889,87001,86787,94580,76958,57264,197304,162078,139067,124480,147150,215924,197167,170210,2307,30073,175358,90703,116828,63005,168365,160878,162825,63917,10453,84131,99974,208720,29720,101155,67263,141619,189785,74141,6476,135493,215510,201980,98293,134823,133361,116800,127612,11495,77708,170794,50465,83968,134714,21586,25154,43881,202720,141026,139776,1921,164346,17163,55255,175724,165023,202789,55882,179674,111203,219064,27243,53891,140969,217251,184219,194347,203215,114094,169492,214107,18905,195792,179842,40305,6226,120383,187922,86957,9931,22943,19201,14226,46659,208198,66036,146298,18465,150722,215382,59675,67803,106858,101720,118259,165207,106023,113514,82943,4427,102105,176501,31013,140203,1546,168145,188818,162096,16577,198170,195546,153513,183732,19028,93768,131815,27822,75978,206278,179798,78997,159007,22676,73798,200612,114734,99985,103576,215514,173204,8079,124253,40979,139845,57636,35751,195757,102496,208868,209055,143058,77326,166211,37768,78133,177547,208408,191188,169853,97813,2922,145103,145508,37060,32305,10950,1318,150134,156409,103306,140811,143530,22315,79872,55064,196161,194361,190300,90138,55505,108174,169327,134362,37343,42193,48232,57878,161712,218824,66254,50941,30047,47541,175892,20
9626,148712,14257,196881,143878,120917,168997,144091,19174,76192,17658,65716,25743,51751,190208,152199,127714,86730,93621,33610,175976,171828,61481,26226,74746,162402,18299,83305,127365,118134,86073,161570,172863,80130,181785,153136,39949,148256,148599,95549,211733,82924,112278,44207,1266,82107,64533,205956,58071,186698,193938,167991,114327,72529,94237,35207,178147,192144,86917,125817,121332,116747,94041,203100,80077,127656,144617,28529,46077,169227,159023,21488,73561,216647,36700,141195,51493,69005,192294,164264,20044,190690,19765,4883,132953,165577,8177,213029,152539,151755,106203,201641,154674,55773,9609,144597,68319,142467,141473,123759,171418,203213,38963,95897,102891,58538,154560,192211,76283,35096,121818,131366,133801,23738,104014,95385,131640,2776,173717,62621,211429,160965,179348,43474,198497,134777,36797,116974,41739,46961,148467,148324,172861,38325,192916,191890,126747,94056,10595,57362,127832,60751,16478,69299,96845,61346,11221,175265,54204,135760,207476,94699,101736,123122,119182,12173,10460,170006,84953,27059,184480,136674,166085,217487,71695,218856,194487,69026,151684,142698,149464,45573,98739,171357,99226,95760,154056,171268,219013,20404,135440,65951,98652,58431,135363,106389,92510,214233,90448,127627,207790,126382,158390,1729,131553,36429,115316,43412,60290,20208,134883,168683,73612,55447,39384,48996,41250,94954,157816,188036,214159,31019,173305,60383,108688,212644,86566,188810,29711,126226,177007,125518,125572,55234,154872,42269,169675,107714,5884,63876,11310,33524,150336,155280,42239,35317,165030,134386,10273,150016,38499,43605,67701,47195,135434,106287,157117,147745,167148,4881,188625,190834,173472,215299,74305,22478,56778,117755,158997,123578,126266,44713,57504,217674,107680,157506,37411,208303,152737,215076,201192,47679,163371,172922,43967,189841,179723,89829,118874,156377,154821,18046,60244,99145,101046,34658,28461,4341,53044,137025,132214,179402,110046,216665,39206,158731,25244,46510,124746,161152,157237,10233,205486,139311,131944,197232,30806,24429,126416,172855,34164,158197,141370,216929,314,203127,177880,170192,113114,136850,110279,97092,11387,182882,138888,112448,60158,210318,126385,206977,201502,99938,91886,154415,210926,194351,199632,219464,182637,151026,25615,111539,66720,125079,200765,60951,21972,74256,156838,142535,107124,67582,49550,1433,185899,194942,206942,1864,151515,136980,30888,136252,98311,171585,53849,180293,210839,82811,107176,94966,27940,40574,111692,170806,153451,66488,68498,140536,111034,148024,95851,160400,72236,178416,104190,197911,53397,108325,144730,205465,146161,126476,46850,97049,146563,145466,126814,66526,80387,151427,90270,202950,184941,104367,192880,36193,31217,196035,60233,117072,205287,183320,42413,26287,56218,133259,105004,206066,88719,209939,144330,160963,182729,179454,185280,13645,42136,102001,166691,1386,210860,23878,215811,123645,157350,114131,23105,16798,68967,58135,97212,20981,159653,9897,189921,100748,105026,160947,172012,136578,75096,186543,156623,15063,202796,215018,45946,167848,129595,115073,2807,52876,172198,210740,135830,74974,51051,8810,205510,157159,133837,135217,183094,60731,161268,44428,176691,9453,94303,62719,171709,75,41273,22765,72335,204386,104709,92570,175547,103584,106286,153637,52510,4028,193180,184450,23636,144072,74578,190352,146568,152812,91251,91088,213782,72484,59322,158954,50993,119212,105894,150708,1219,68447,45884,106727,132707,27909,146441,18975,131902,77136,89120,30673,210486,134555,67552,141008,213161,207982,150704,65573,181978,116069,98849,147631,126780,201842,63772,198567,89040,203318,120284,178675,
120775,153833,7169,126345,30479,67585,17708,100315,206741,57771,177448,183809,60274,209757,82766,108114,92562,135398,140417,2548,183715,218150,69035,65079,75311,24381,104577,176605,196290,215239,89651,19930,145372,46788,124926,82279,92677,57295,150395,70536,47915,195600,179072,49851,130809,193229,52341,61567,192788,63453,69574,54205,60085,53154,218874,65456,219240,162161,93330,40893,27967,31425,20575,177160,173912,129149,160911,1736,200342,141944,11079,126429,97306,162704,82684,129064,150606,72019,128661,200346,133990,73470,104301,58220,196652,151169,33618,14194,137818,209523,110933,125388,141759,8536,131720,91039,103210,179919,203398,45822,119851,31553,183040,135844,120604,82746,180672,2206,3098,176712,55247,99464,152307,26410,85311,94791,160636,68078,143802,68904,62164,70623,123560,160716,89940,95924,48211,130749,111134,55189,109862,183687,197183,97598,95703,105012,73303,118155,47156,149791,34377,150966,179833,59917,213435,26180,188957,210786,61814,145579,113303,99846,51396,38232,209725,193604,39203,123990,211201,2553,46296,110499,130966,35250,165321,135530,139980,188890,205870,127564,78652,65791,147361,148854,12506,104438,25945,181021,42819,15738,56483,213337,124921,165114,117926,195190,87728,68096,94624,94452,146892,146143,212340,67558,90214,5199,107940,31781,92036,178149,158748,52155,151218,150967,170358,170165,141148,45754,204553,74608,93089,85220,176473,133809,127793,203094,193075,117160,20554,99895,160868,75110,121569,165599,37971,43928,167924,211738,88444,110576,165484,160513,149851,106108,18702,20738,42467,90079,59825,82177,83067,171183,76377,186235,158698,168989,103666,71106,116627,204727,95201,196056,153857,136560,118040,43861,48204,6055,32778,190781,61506,173545,65967,195002,183683,196539,20988,55200,43754,104083,27003,25048,16722,143768,14110,6911,202095,104074,137358,200539,21823,27244,68800,162031,40704,22101,171660,102195,81099,60124,64853,193355,74494,116171,164228,33138,34413,138082,219812,44723,7489,9215,173945,93543,85249,124821,122471,141710,71411,143797,116613,35419,182877,141371,46409,154882,156653,119680,165425,196322,166684,147845,114026,137221,167716,167648,139928,81175,91008,126262,141470,59540,210949,23476,115831,80705,95887,108056,68106,40820,79579,2767,24,137184,213746,198782,35730,90147,75173,122997,140272,2180,180294,128956,198042,125060,78541,2394,200768,112836,209181,76787,156500,60611,208867,1892,143710,97948,49432,107934,146945,114584,103481,65508,42800,187416,193231,169342,102163,217167,99184,57826,68938,20764,110687,163924,62799,135426,200419,209704,160750,211700,181908,179474,71267,74626,169989,194753,139445,212024,70470,108240,49283,17947,47816,35815,75915,28929,166957,161934,118690,71338,183624,163251,128973,49227,84054,5084,39885,13010,140239,5801,48742,26810,73954,70567,187484,114928,190606,2150,61029,23644,39500,3860,129388,56934,89630,110164,81677,205973,209574,130651,97359,121552,182697,207467,86572,21030,196757,38296,14690,69862,203332,102871,18713,123246,47606,50382,65556,214237,100028,217121,3307,31971,201364,187349,65263,100872,202172,116119,64792,193362,8784,52213,120805,192760,202111,24631,136613,174589,54643,123948,102006,178636,81817,84689,41579,164931,190902,11058,195225,144069,136557,32018,126151,167384,78332,52094,94483,13654,174770,57091,29603,61857,165913,113812,85069,3949,43872,208244,72377,153274,31365,206869,63724,2340,11008,152325,205383,181076,124589,180461,156293,40630,202059,23456,65201,41339,131050,23960,179816,162174,88715,199114,163874,15742,15669,93196,168971,40348,140601,207478,28152,32984,127885,60994,83031,146368,2067
16,168990,198363,91047,211163,115007,192201,50584,5597,180682,99029,93891,189896,35319,65718,10351,127219,185237,114688,146544,77257,142274,182724,128474,49748,148557,48923,142506,91138,84850,191994,122786,72174,217933,156940,145364,42033,18055,139234,40638,117803,77508,214805,141978,96236,119034,188,203402,167254,132633,99363,192323,48636,4907,162459,95440,203428,162581,62596,84068,189006,16604,172566,141284,214708,101578,49773,179103,130429,200732,144563,40404,160704,74220,84567,107427,168009,84416,28957,25233,137153,172724,17479,11475,218940,34019,46829,8198,52973,28063,127236,50231,25887,8340,149687,168231,219230,101868,172470,128426,137840,163570,183845,154743,67092,12867,96037,57861,158847,188178,18579,3430,54719,171900,128819,218124,20960,122698,209982,167488,15912,94614,146281,194898,109039,215489,208207,185711,15283,215913,161725,69681,138679,10656,192492,18451,218747,79666,20430,56596,136179,25459,26535,190099,60559,139371,101012,49665,50625,137938,172196,65211,149429,163925,46897,146387,196396,146313,108838,195587,27659,200876,167411,151142,93837,176814,120250,184169,174389,111955,27289,43728,207739,47289,144170,105265,132302,88476,172368,41465,139958,157264,116132,92359,82615,144200,85106,97444,108093,200326,112876,216571,190281,84790,159491,63103,147625,143032,56079,98320,200725,144241,195186,66607,94374,118594,117309,104094,51,190398,9086,74575,156175,4273,135688,35411,79125,1721,78730,58166,56314,82550,5688,202250,215466,180805,215570,72340,145139,134026,145836,149597,149844,40557,24716,9892,120683,93233,40922,36282,69465,96228,152608,65498,212618,78722,24092,149549,104804,149176,56999,192483,20237,189682,85116,207855,126493,45325,127998,216453,180773,184247,151055,214831,24999,156036,168363,78746,9796,191572,211901,151069,100020,20572,204504,14659,202087,87658,91207,214208,113850,109188,31598,201953,59432,195999,171122,157889,166169,203795,165725,100266,14189,79897,137463,10795,141539,148088,39211,85676,131701,182915,57880,40159,13243,176589,191778,102991,116815,197352,156010,150799,132311,58218,113229,75964,206141,103369,38800,34652,88124,21638,124742,149799,195379,12926,143694,104645,68827,22302,104471,131549,113899,202853,182942,34177,199762,35355,126070,146616,51932,127889,103630,170983,149770,51406,147106,161479,192850,137428,64551,63521,18781,69797,203626,112739,84780,179499,150999,23486,29499,176094,108418,111028,95752,91306,49935,84079,90315,155681,171811,6345,201767,79906,44369,91676,161538,215373,210448,163363,38225,115464,11764,214478,61207,52346,179662,97994,40892,198635,190968,22495,29457,193876,14296,63713,162792,38712,218220,214993,75414,3600,89986,181874,210837,29170,79984,117980,123519,5963,86999,54790,193989,141034,55541,53425,143241,137171,196701,67805,143864,126622,163326,191345,37053,182208,183174,175180,4820,98725,152831,173939,138097,11398,66966,44648,15757,74452,52611,206881,182590,161995,71114,13915,121886,171640,2184,17458,101880,173112,71122,173796,173041,94429,28283,5001,76675,170239,101324,60599,99798,133310,26110,127051,156527,139345,23074,48868,52133,131242,120058,146980,186769,5633,10420,64852,79934,87443,128432,103127,195137,41157,4816,90906,179388,92251,93160,163900,79253,61919,131400,72319,212832,16255,164047,67549,101494,213610,7724,145474,73451,39769,89597,102982,172210,35982,23014,191336,100193,136503,177732,160549,163977,179967,59618,109401,63009,150181,40042,111795,81889,158693,86320,117747,142143,34543,54200,171041,36327,152313,141912,179319,209742,53034,36204,81002,180494,184847,137312,35455,112738,201863,38050,172027,148070,41431,83
229,17777,146103,139712,188597,31034,48269,83256,38147,190599,200654,83542,108538,107807,79520,116130,195457,215403,166945,193551,71774,39834,204238,49265,27989,34483,48430,6867,199871,160427,161728,69956,158755,148646,160527,60532,156332,60865,180191,6837,94178,20834,180392,204426,124938,35150,131681,158521,122883,186651,170280,48228,101570,128829,124216,36048,42935,50661,59316,8448,107109,3935,70741,108277,58341,194606,175987,53171,198646,53624,17337,46419,115155,161758,128435,180319,85169,119942,64692,112081,106133,203480,10964,159365,192762,216092,99712,188796,82024,103324,160302,194633,149929,140550,137686,168137,184578,206286,10095,136485,28656,118386,119773,216795,11119,64414,135974,140999,103707,192720,10289,146197,101554,135299,200582,165320,12922,49768,76330,115226,104343,79837,202814,216955,74950,120939,43643,110745,77180,160729,14299,180867,135520,37287,121524,217551,169371,118409,23708,189437,131244,29667,47125,204139,153471,182482,143237,35469,103040,84793,118197,6530,46891,118929,154747,186812,52935,119192,23640,65557,84954,181329,170020,186507,208794,42990,188317,192220,32386,23905,29927,185133,117807,166218,95143,22362,137315,185408,100686,190545,207373,134130,160451,41916,99268,144269,99773,11509,201394,44796,87868,148240,172891,39345,42499,37907,46888,201629,186658,13450,109990,80756,81138,191407,201491,116983,75702,164220,47368,30816,34490,53176,926,163200,200643,66351,3200,120230,130450,130716,142475,211669,1507,159379,129359,62358,56852,154366,172876,124705,118944,107518,115834,20815,157605,157313,1723,64150,143482,72905,162467,164793,146158,214989,216334,12823,48454,38864,14434,171252,91493,9626,38109,8501,174849,151543,219680,158147,78234,86798,22990,12853,28544,109302,46726,42293,176184,74548,160193,173257,181276,183460,155967,150020,206387,147371,46509,119364,196433,151772,94421,9395,24412,201460,78458,84625,96618,99336,215027,17738,198821,112020,144302,17019,76764,123421,35270,31987,182090,71937,86300,138163,108464,137532,55009,182407,54801,160446,15252,28652,149780,191293,164882,138660,22982,126789,189811,180526,158538,206617,185284,75284,111949,159604,82965,68068,137728,150770,92760,34003,161032,147491,130127,136237,191151,165637,129107,211807,47161,159752,61006,185751,76149,50739,414,30987,6287,121866,163960,208858,30913,193322,54350,209600,161299,108693,11925,87501,91658,146909,211299,167117,55663,2961,167934,52740,180883,73826,66694,207145,17782,80680,78717,132825,144177,84914,201072,1607,29841,179787,87368,89485,103821,149139,161478,7632,218115,46040,86642,218582,88588,176284,125595,124387,208995,9321,169253,35587,126221,40733,164416,120748,140762,87846,121233,205518,213428,108515,194664,14685,218917,136596,78471,148096,190045,177785,82707,140479,4906,54594,151361,179697,86958,209593,208609,92692,17357,39139,73520,8171,88906,98367,164236,140056,43507,149679,158777,102122,49120,156025,30630,205086,197011,80814,184881,82444,186080,8291,71171,108493,22753,130299,62296,97408,118464,217437,79332,56659,137651,79250,45878,29420,98098,205768,9802,52978,24275,116332,132643,6536,209459,126527,131086,40390,89159,126739,192005,137265,81358,203301,85809,37563,201714,25756,215518,135304,177397,34791,53788,189941,9582,120420,101603,187086,136929,30365,1077,11542,180106,155788,22826,87369,122388,12095,27027,136466,200849,195442,81914,5893,151403,198834,81826,167470,84116,52574,50474,75121,72522,139532,83663,202581,115635,175931,30841,71474,39143,41286,71732,30005,3999,163301,146702,134757,26837,54153,142220,91833,144012,83488,71054,37108,184642,210764,150492,64179,44898,186
75,81625,160763,88798,141709,116299,13297,34098,7600,123400,177077,13226,27768,100558,22575,102738,46919,184727,156269,155387,165011,219198,180912,183334,116873,161563,93224,81904,87118,214287,51943,186882,198810,175864,212153,139849,20590,198113,200390,122079,148842,34549,24357,213084,210331,183133,122755,25496,119428,81367,179431,168930,186178,200931,123350,60737,58666,37675,161727,67646,95276,102780,48064,183399,174373,86593,206773,85134,118825,127796,7933,17135,124606,29592,92128,125475,120911,44774,73992,201311,91387,158450,150648,33395,133155,26983,106348,214316,187242,208865,28044,85424,145791,170249,18592,207731,79198,139656,120841,105071,154805,108381,62541,162758,63882,143641,41686,126451,100615,17317,89316,54588,213363,181943,154919,113022,205603,114125,148691,206392,89705,142744,181962,61906,45373,173699,175532,20419,193820,186622,150520,219348,29969,155232,156257,121348,85295,145512,25599,215488,116710,46182,80749,106376,18247,156830,28506,179701,213402,156344,202203,214922,99329,85828,129576,169080,173058,16117,23732,57045,214180,36416,152244,197915,166758,212294,48,198463,38463,173246,118744,139265,76415,102231,161368,117220,51292,53912,178655,82994,78573,149494,208233,178142,79835,113961,188164,90246,152758,2128,160298,133804,105957,33437,77118,16360,96789,143886,13927,221,184380,146808,215155,94173,28756,73979,8174,61667,113171,184083,154707,103355,163896,83395,186587,100346,106439,100024,114082,142899,55369,173667,45458,124484,36456,49342,145600,100819,94447,109375,40440,75620,175653,32911,35200,3041,18807,169971,4117,37747,189094,127585,172471,114487,35592,22194,35971,168360,211272,109014,119545,206812,158447,111595,163380,47132,119116,133540,111181,144716,214976,117976,100850,93141,59176,48704,111351,158485,1947,151503,130588,191604,62062,186680,131525,70582,107698,143174,151087,96198,8771,9926,148937,126685,184780,96620,29665,152113,126230,181009,37247,191413,191704,137994,196760,204176,73248,121880,135515,20508,137194,159926,23185,64285,123850,165324,194428,173994,175300,20468,179325,166117,141347,209216,151946,81919,13834,90058,140255,204086,2826,122911,119022,134391,123006,150461,192365,86904,110366,92514,209249,77232,75814,43911,9260,182203,156666,196457,79710,190571,205788,147772,68273,16100,164967,124036,151556,44323,150171,13531,172050,76481,212047,193299,163660,7156,75187,127076,21629,33245,132538,20148,205458,72836,66073,34744,161427,73655,72146,67359,3723,158288,45773,155128,193715,39619,172793,117087,96932,211770,45979,68485,205134,183703,61917,127118,47398,209550,138251,8674,89875,22125,38029,156118,215409,40409,219575,130009,197331,108251,100444,62376,173189,115391,190316,167136,43852,55274,151455,192508,90459,14241,168364,69309,198045,214835,30407,13768,43935,121638,104528,124841,79001,124850,195598,197244,139597,61252,99562,179178,184970,102254,68845,139052,113962,3109,75705,209009,117655,100140,177862,82755,139002,179011,10752,25523,139275,167096,92753,113190,160035,170935,71457,87809,70864,70015,210322,25762,66936,171532,102538,120876,202676,177023,196369,63385,192495,176131,154330,201981,51098,136302,88449,177659,47164,82121,143553,2772,176370,218022,17813,3135,107806,25611,137694,27238,179699,116901,92147,57518,190120,74264,66033,50614,8923,6212,162371,24101,14911,155749,172125,219010,56839,157125,31926,185617,31451,158514,34054,98555,172069,169844,148391,210554,48435,29057,182361,48303,66410,77219,122409,173140,19848,167452,118867,72452,201682,111931,181461,201819,201882,84778,67459,204664,19397,72689,203713,16354,100348,101866,87203,15744,67695,2130
70,69380,46643,115494,18192,128901,159986,102926,46777,169209,128989,28424,37715,141723,201512,113050,57248,27852,115757,187537,67894,27076,9138,19115,18458,149774,190721,65692,203748,116885,4919,130469,39021,171144,167860,227,9918,174222,102362,69844,39742,183515,86463,102584,150638,970,185494,202110,128628,125341,105883,190779,204717,212933,205883,18376,141591,59204,867,100774,180685,214427,126370,31772,140326,94088,87409,178739,79384,105018,14468,165746,169626,23767,95775,140149,89145,21091,120491,199348,200935,76010,169814,53246,46299,145352,10310,100706,173207,50070,78347,182421,102263,27988,165312,180575,127533,22246,165882,114814,49666,36685,9358,50618,6649,188675,155437,62517,128249,101913,16513,207811,141702,168665,80669,130989,139969,218693,212802,117783,4303,9710,88941,14694,88530,206511,189008,11223,106301,30982,55490,211577,171402,174932,137132,120159,177754,140656,39047,194262,151872,61075,18883,9667,192796,37117,128160,171398,4997,123527,162102,12006,63617,183840,170075,34219,3765,45154,165848,13096,80566,215151,25504,119586,160229,44745,218982,154938,16766,214381,214294,27874,29931,130323,155969,209498,161383,57821,16543,210882,53121,111728,88432,212735,219125,193443,177837,105384,84996,31667,179329,150986,143100,36000,84302,219264,207906,55897,200337,11950,155473,45468,87950,16573,71799,37024,198587,108830,44053,182529,96974,145072,43255,114992,63869,152564,10668,26205,207369,165726,160547,98600,203859,41354,4097,127655,21160,1664,1893,47032,23367,185267,155642,74378,109110,181797,211658,207681,202522,124783,181497,33264,54043,37509,167726,71027,137124,118074,108011,62718,45756,160048,181318,9862,156309,47186,188009,131042,160445,164414,216657,205149,78957,91187,109782,1198,146200,61771,50306,30803,138988,25082,162246,94870,199525,141485,133048,57481,121691,201937,164223,172838,54554,72463,13911,50782,29538,208827,179632,124364,130420,46222,41880,175673,11302,14163,169046,95471,54942,113946,65898,87633,161695,211066,198812,118789,11355,27373,194095,13751,140865,41430,205912,7619,127155,134978,30425,18575,61682,189111,116607,155374,151981,85451,109855,52101,195844,177432,85376,60309,204203,28870,85493,59711,25706,62159,193396,108159,52417,98443,57845,123467,153378,126984,14244,114259,37995,56666,159759,119273,46677,101536,97906,35757,119993,188696,50896,165680,94002,70698,53008,88051,33347,8413,181458,149072,15637,73558,35783,29126,122037,120213,109634,199318,90066,57633,77374,158239,41101,206974,208290,167247,160308,44547,151551,144901,60215,22668,107923,136778,128244,56739,155079,124122,219411,121310,73182,3802,191369,112226,192984,64658,64967,85798,113835,151903,58302,155129,39051,150392,1331,109516,11593,10501,123378,192276,201485,132177,8813,111594,139230,103368,125637,212262,209195,182623,215296,25689,118604,129836,63048,21781,146404,197415,3256,181707,191833,106592,57149,34789,93953,77539,62668,169774,23784,4175,125772,175383,79911,42415,184367,120398,149521,195896,153002,6399,119233,202331,36339,184322,106262,48482,207056,97103,164314,121712,184523,166042,357,164444,18533,87269,8575,77749,111647,197365,70723,182990,160563,142196,150792,208267,54302,206488,179480,21222,162648,88000,192875,14368,134935,176856,95417,13516,29089,100389,95837,121989,162302,10714,95832,96058,45945,169914,146483,184778,17565,146999,185948,202788,53664,87799,62556,141461,29768,170141,125434,197204,152074,54996,197439,195301,199461,59097,84007,92461,70086,132006,178317,110868,72651,14276,16611,97730,90499,118596,127461,131452,138066,217464,32167,123250,186076,139435,216368,139684,76615,118483
,162778,96507,209689,176503,181724,212020,17203,49375,120028,22887,21016,102415,46874,75595,20171,28669,43192,72336,29366,182735,38785,130109,39452,53544,136495,164436,5948,76852,109192,178887,141933,34610,124231,90885,110134,116355,193386,5631,76419,49561,16501,5776,210236,60903,218876,136215,168996,88048,107922,151703,89860,171903,169358,206261,85539,127440,155545,174549,182678,138979,143777,116721,120063,79713,148701,34591,206922,80345,45019,206901,68159,59449,126941,97491,75535,201785,12404,38886,92568,74549,162170,42150,186597,140330,74775,45969,185627,181628,174681,208262,40235,42271,62365,30649,10787,83001,150608,33154,115576,68264,96907,48364,53780,147452,185371,9093,174470,83084,90137,38782,97757,117552,141522,42623,139891,79052,31847,6533,106053,130919,201462,38111,120072,83943,40908,1125,193592,143168,133014,86843,169671,184929,77478,124232,75976,197698,150399,174820,38557,209604,148172,115895,88784,180655,29442,191795,10506,36548,177735,144430,30638,80761,6500,105284,161879,57767,116909,93688,24243,85555,14426,159423,51157,152943,138001,58233,204808,201089,72856,51305,5910,21912,141141,116774,177488,176044,22905,134162,103395,50973,128572,83298,158117,158617,134101,14692,174788,94650,104217,38783,187685,112103,96896,101217,97735,3195,74619,90418,92105,80762,116805,147071,22963,56543,190802,116449,194387,108949,210681,115999,75663,42936,211202,190841,111105,190999,108578,36586,219024,153925,156745,22564,37654,54469,79490,124416,137704,87927,154842,198210,98506,23186,1735,107337,167459,108662,45847,145703,121141,186186,148875,197982,35894,52234,48760,218178,100078,75881,145192,128604,11071,215691,216596,79165,218781,153408,79669,129525,146612,193948,121491,58856,7035,207413,641,40376,68859,78687,167917,218691,166870,123873,32552,171393,174565,75415,109355,47021,128945,180313,169026,64159,7456,31288,200466,72670,129250,17491,156995,195445,94833,152322,125956,148261,114312,67609,164625,179807,124791,181992,214507,31348,134355,192140,150801,4599,68036,42689,152477,15338,38308,83802,120859,181739,35519,93686,42810,13864,71478,94676,109973,166788,93319,12770,75984,76457,138834,167495,20139,134707,50807,200918,164634,48037,200458,156961,197297,43813,142530,84569,82892,216746,179560,209683,78924,71161,179684,182967,180445,25122,140095,152055,92177,215827,37674,3381,35900,135835,73771,97783,204159,127113,89211,42737,26848,57114,131958,57963,217992,49392,208432,146448,160806,180788,217950,137796,139485,186789,144470,184094,198261,210391,29875,109404,120715,20027,30556,214655 diff --git a/data/molecules/val.index b/data/molecules/val.index new file mode 100644 index 0000000..78db245 --- /dev/null +++ b/data/molecules/val.index @@ -0,0 +1 @@ 
+20952,3648,819,24299,9012,8024,7314,4572,24132,3358,22174,24270,17870,2848,19349,13825,1041,976,3070,7164,7623,16559,19726,869,18390,6515,23462,21295,22981,17856,13746,7223,14719,19309,9115,212,5231,22876,13848,11149,9105,5094,7055,11029,3349,3039,12449,3169,11763,11270,19782,8667,1423,23911,15054,17571,4090,12403,2582,18089,9606,20599,20267,11850,18918,6300,23087,2279,1501,21668,7467,9482,2614,7628,3309,12455,9108,14857,20830,11954,5329,12130,11641,6865,21960,8748,22997,22398,21234,2339,19960,20806,5607,17502,23892,8021,5354,15147,12433,8845,20971,22549,18250,7196,22433,10626,1832,7505,1051,10336,13145,8773,2168,6913,18585,23524,10311,6967,21477,16358,12964,21064,15035,4681,8679,4575,8081,24411,18394,17661,8609,19155,14038,19121,13087,11861,7186,4532,16696,16171,2978,1543,3592,5008,20560,5242,22298,13833,19543,2081,12608,12504,19526,15337,17338,8238,18128,376,22291,23616,3753,22338,17595,8743,21003,11146,3655,9617,14246,5182,14867,106,23661,23582,8630,16403,5854,16635,3486,20489,9779,20937,19954,6517,12252,5293,17674,17378,18,19626,10621,16010,638,3665,11894,10076,7846,1898,7892,18591,2580,2806,23983,15924,2267,17455,4120,4207,21618,15574,18015,5410,8685,17290,19876,13865,6940,17671,23918,22605,6591,23361,10214,13074,22009,12236,14355,16959,14794,3965,8123,7362,2098,11078,689,19277,18150,7540,19282,7216,235,2326,23194,20679,1929,7501,2208,1029,10827,2321,16847,7798,9125,21921,15906,7020,17669,4335,23702,18711,18881,15488,7962,15498,13338,6239,3090,3176,21593,14124,11609,13879,13470,15303,23890,1775,22064,21412,21174,3224,1986,13193,23862,11118,3580,8147,6278,6232,17573,14700,4593,13824,6012,9127,15159,8185,2470,14520,18033,3208,1657,21369,17713,483,3056,7745,5449,13317,15913,15773,7004,13141,1921,5394,12418,70,12793,8690,14909,9347,13861,22825,23937,18211,21688,23540,15947,5072,6222,9722,7133,1916,18978,24108,17766,1997,10276,1873,1643,19142,15623,16477,17403,5158,1863,16640,2625,6089,2245,19498,2226,22125,7707,13230,3928,18667,8067,18970,19481,1302,20295,2686,13737,21540,19125,18521,17130,10366,8544,6693,21945,23468,10295,7821,8703,12969,4288,21151,9830,14982,10360,2377,305,15017,20354,18448,3276,2400,17617,6984,16576,4340,11436,2254,8004,12108,9338,5169,14358,17800,23053,9912,20043,21429,17332,256,21884,18173,9810,21737,3394,4400,8666,3782,3507,24327,5093,8924,9232,19819,6901,23514,11235,6671,22527,20782,8650,16561,16008,8228,1664,3024,20784,9066,1444,116,10929,4286,20876,8583,5294,24288,14478,18077,23123,14014,18379,316,2465,22643,4884,17877,1180,12098,19087,18105,4852,14083,4176,1370,10101,11948,1307,11724,6883,22349,8176,21854,3368,11589,18346,13316,20337,5064,7757,5324,5801,13510,812,5877,24135,10885,13491,21951,24086,8131,8742,5216,22979,3542,12535,1268,15423,7288,6539,15083,11457,10000,7457,7304,775,21627,6328,13056,10756,9129,2274,9146,11506,21020,16692,13096,22266,17570,10851,904,3779,8559,5851,19024,8698,1253,3552,19548,14239,11327,23872,10278,14299,19864,16758,3789,12622,18893,6228,8346,1454,23225,14288,55,17036,17643,22506,23574,24312,24151,21975,6456,11934,14132,2292,21765,10819,20419,10286,4083,23584,9840,16617,10134,21852,13382,10688,13185,22846,9688,18166,4170,6286,13777,21788,12423,22194,5702,20169,18648,9861,13306,17954,13,9957,9401,6887,14086,19004,19879,21453,10559,15236,14476,14488,22138,7002,16750,15505,24115,5560,21589,2778,9299,16890,21753,20740,20291,10983,3060,7696,22046,10171,7361,6525,4828,800,1514,8023,15569,20030,2386,14923,13580,20636,18863,6371,23538,22818,12582,16199,13095,7994,4835,21497,22532,181,3492,13931,7170,5763,22803,16972,15222,1645,18265,8165,39
76,14957,4369,15225,21875,17404,18314,19511,10397,14502,20075,23569,16540,13983,17952,14611,5215,24368,15554,14747,8493,8101,20894,9086,17081,15879,20537,7839,8998,14413,2538,23381,9362,7683,8903,11005,10476,17699,2640,4534,4942,7577,12551,22739,5007,23147,7010,2104,13594,13356,10842,17780,15267,13624,2040,6777,13767,12762,19139,22790,640,18864,12464,15629,193,11526,9784,12779,13730,17636,24072,17895,19767,7226,15998,7190,8943,14281,951,12742,11014,21917,22254,13248,23729,5408,15315,4182,20390,883,12911,19395,18493,21725,888,2750,21061,14044,4446,15128,5954,1647,8524,12422,10726,6935,14899,10710,11059,9117,13813,8266,2683,15411,635,17675,1706,11467,7347,21303,2248,21356,1319,1016,8102,6532,667,20359,4993,7816,4136,15517,21936,3748,18480,7142,15238,22920,8396,12087,5498,19853,19898,23539,5366,10192,18962,841,10222,18867,22195,12298,12997,23429,6498,2490,19401,22630,20553,7957,3339,22845,9882,22421,19674,3966,18544,1345,11377,17456,14037,21676,12142,2259,16579,21218,11181,414,13764,16062,3458,14205,11868,20826,15064,23177,5013,14270,5771,24045,17096,21314,8850,20182,17634,15843,15232,14272,23954,19414,8794,10561,8044,2839,9139,14771,7990,15227,18672,19999,21895,11023,940,16197,10650,5958,15976,6950,11626,8465,11152,9163,19534,22976,9052,18212,332,16928,6260,2805,7908,23595,16009,18192,7874,22629,15600,21164,23325,16083,14685,565,3049,9641,7261,13251,22668,7972,10033,21756,19056,12092,15507,18136,17397,11264,13942,24442,18035,10839,11528,23031,14868,8877,10047,8237,7554,3953,23635,6310,10339,3917,24342,17559,22615,6066,6276,7090,24202,15866,9060,23743,19319,17191,19555,9273,3294,6360,9707,7454,11825,5879,9904,463,23200,4147,8988,1491,1786,18132,9572,22852,4137,20901,16085,3361,401,18810,9317,15381,15686,14433,11164,6041,1683,8273,15654,3738,2141,13130,16113,2427,18907,20625,22493,1756,4971,4888,18443,9956,2791,8132,3881,18286,13637,19867,19533,20264,7395,17123,14762,14507,9743,19284,14051,10006,18632,20349,1973,19976,24251,3251,6808,20496,6914,8671,21640 diff --git a/data/ogb_mol.py b/data/ogb_mol.py new file mode 100644 index 0000000..c20ec4c --- /dev/null +++ b/data/ogb_mol.py @@ -0,0 +1,254 @@ +import time +import dgl +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + +from ogb.graphproppred import DglGraphPropPredDataset, Evaluator + +from scipy import sparse as sp +import numpy as np +import networkx as nx +from tqdm import tqdm + +class OGBMOLDGL(torch.utils.data.Dataset): + def __init__(self, data, split): + self.split = split + self.data = [g for g in data[self.split]] + self.graph_lists = [] + self.graph_labels = [] + for g in self.data: + if g[0].number_of_nodes() > 5: + self.graph_lists.append(g[0]) + self.graph_labels.append(g[1]) + self.n_samples = len(self.graph_lists) + + def __len__(self): + """Return the number of graphs in the dataset.""" + return self.n_samples + + def __getitem__(self, idx): + """ + Get the idx^th sample. + Parameters + --------- + idx : int + The sample index. + Returns + ------- + (dgl.DGLGraph, int) + DGLGraph with node feature stored in `feat` field + And its label. 
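+ Example (illustrative, added by the editor; not part of the original commit; + `split_dataset` stands for any OGBMOLDGL instance such as `dataset.train`): + >>> g, label = split_dataset[0] # one dgl.DGLGraph and its label tensor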
+ """ + return self.graph_lists[idx], self.graph_labels[idx] + +def add_eig_vec(g, pos_enc_dim): + """ + Graph positional encoding v/ Laplacian eigenvectors + This func is for eigvec visualization, same code as positional_encoding() func, + but stores value in a diff key 'eigvec' + """ + + # Laplacian + A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + L = sp.eye(g.number_of_nodes()) - N * A * N + + # Eigenvectors with numpy + EigVal, EigVec = np.linalg.eig(L.toarray()) + idx = EigVal.argsort() # increasing order + EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) + g.ndata['eigvec'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float() + + # zero padding to the end if n < pos_enc_dim + n = g.number_of_nodes() + if n <= pos_enc_dim: + g.ndata['eigvec'] = F.pad(g.ndata['eigvec'], (0, pos_enc_dim - n + 1), value=float('0')) + + return g + + +def lap_positional_encoding(g, pos_enc_dim): + """ + Graph positional encoding v/ Laplacian eigenvectors + """ + + # Laplacian + A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + L = sp.eye(g.number_of_nodes()) - N * A * N + + # Eigenvectors with numpy + EigVal, EigVec = np.linalg.eig(L.toarray()) + idx = EigVal.argsort() # increasing order + EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) + g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float() + + return g + + +def init_positional_encoding(g, pos_enc_dim, type_init): + """ + Initializing positional encoding with RWPE + """ + + n = g.number_of_nodes() + + if type_init == 'rand_walk': + # Geometric diffusion features with Random Walk + A = g.adjacency_matrix(scipy_fmt="csr") + Dinv = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -1.0, dtype=float) # D^-1 + RW = A * Dinv + M = RW + + # Iterate + nb_pos_enc = pos_enc_dim + PE = [torch.from_numpy(M.diagonal()).float()] + M_power = M + for _ in range(nb_pos_enc-1): + M_power = M_power * M + PE.append(torch.from_numpy(M_power.diagonal()).float()) + PE = torch.stack(PE,dim=-1) + g.ndata['pos_enc'] = PE + + return g + + +def make_full_graph(graph, adaptive_weighting=None): + g, label = graph + + full_g = dgl.from_networkx(nx.complete_graph(g.number_of_nodes())) + + # Copy over the node feature data and laplace eigvecs + full_g.ndata['feat'] = g.ndata['feat'] + + try: + full_g.ndata['pos_enc'] = g.ndata['pos_enc'] + except: + pass + + try: + full_g.ndata['eigvec'] = g.ndata['eigvec'] + except: + pass + + # Initalize fake edge features w/ 0s + full_g.edata['feat'] = torch.zeros(full_g.number_of_edges(), 3, dtype=torch.long) + full_g.edata['real'] = torch.zeros(full_g.number_of_edges(), dtype=torch.long) + + # Copy real edge data over, and identify real edges! 
+ full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['feat'] = g.edata['feat'] + full_g.edges[g.edges(form='uv')[0].tolist(), g.edges(form='uv')[1].tolist()].data['real'] = torch.ones( + g.edata['feat'].shape[0], dtype=torch.long) # This indicates real edges + + # This code section only apply for GraphiT -------------------------------------------- + if adaptive_weighting is not None: + p_steps, gamma = adaptive_weighting + + n = g.number_of_nodes() + A = g.adjacency_matrix(scipy_fmt="csr") + + # Adaptive weighting k_ij for each edge + if p_steps == "qtr_num_nodes": + p_steps = int(0.25*n) + elif p_steps == "half_num_nodes": + p_steps = int(0.5*n) + elif p_steps == "num_nodes": + p_steps = int(n) + elif p_steps == "twice_num_nodes": + p_steps = int(2*n) + + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + I = sp.eye(n) + L = I - N * A * N + + k_RW = I - gamma*L + k_RW_power = k_RW + for _ in range(p_steps - 1): + k_RW_power = k_RW_power.dot(k_RW) + + k_RW_power = torch.from_numpy(k_RW_power.toarray()) + + # Assigning edge features k_RW_eij for adaptive weighting during attention + full_edge_u, full_edge_v = full_g.edges() + num_edges = full_g.number_of_edges() + + k_RW_e_ij = [] + for edge in range(num_edges): + k_RW_e_ij.append(k_RW_power[full_edge_u[edge], full_edge_v[edge]]) + + full_g.edata['k_RW'] = torch.stack(k_RW_e_ij,dim=-1).unsqueeze(-1).float() + # -------------------------------------------------------------------------------------- + + return full_g, label + +class OGBMOLDataset(Dataset): + def __init__(self, name, features='full'): + + start = time.time() + print("[I] Loading dataset %s..." % (name)) + self.name = name.lower() + + self.dataset = DglGraphPropPredDataset(name=self.name) + + if features == 'full': + pass + elif features == 'simple': + print("[I] Retaining only simple features...") + # only retain the top two node/edge features + for g in self.dataset.graphs: + g.ndata['feat'] = g.ndata['feat'][:, :2] + g.edata['feat'] = g.edata['feat'][:, :2] + + split_idx = self.dataset.get_idx_split() + + self.train = OGBMOLDGL(self.dataset, split_idx['train']) + self.val = OGBMOLDGL(self.dataset, split_idx['valid']) + self.test = OGBMOLDGL(self.dataset, split_idx['test']) + + self.evaluator = Evaluator(name=self.name) + + print("[I] Finished loading.") + print("[I] Data load time: {:.4f}s".format(time.time()-start)) + + # form a mini batch from a given list of samples = [(graph, label) pairs] + def collate(self, samples): + # The input samples is a list of pairs (graph, label). 
+ graphs, labels = map(list, zip(*samples)) + batched_graph = dgl.batch(graphs) + labels = torch.stack(labels) + tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] + tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] + snorm_n = torch.cat(tab_snorm_n).sqrt() + + return batched_graph, labels, snorm_n + + def _add_lap_positional_encodings(self, pos_enc_dim): + + # Graph positional encoding v/ Laplacian eigenvectors + self.train = [(lap_positional_encoding(g, pos_enc_dim), label) for g, label in self.train] + self.val = [(lap_positional_encoding(g, pos_enc_dim), label) for g, label in self.val] + self.test = [(lap_positional_encoding(g, pos_enc_dim), label) for g, label in self.test] + + def _add_eig_vecs(self, pos_enc_dim): + + # Graph positional encoding v/ Laplacian eigenvectors + self.train = [(add_eig_vec(g, pos_enc_dim), label) for g, label in self.train] + self.val = [(add_eig_vec(g, pos_enc_dim), label) for g, label in self.val] + self.test = [(add_eig_vec(g, pos_enc_dim), label) for g, label in self.test] + + + def _init_positional_encodings(self, pos_enc_dim, type_init): + + # Initializing positional encoding randomly with l2-norm 1 + self.train = [(init_positional_encoding(g, pos_enc_dim, type_init), label) for g, label in self.train] + self.val = [(init_positional_encoding(g, pos_enc_dim, type_init), label) for g, label in self.val] + self.test = [(init_positional_encoding(g, pos_enc_dim, type_init), label) for g, label in self.test] + + def _make_full_graph(self, adaptive_weighting=None): + self.train = [make_full_graph(graph, adaptive_weighting) for graph in self.train] + self.val = [make_full_graph(graph, adaptive_weighting) for graph in self.val] + self.test = [make_full_graph(graph, adaptive_weighting) for graph in self.test] + + + \ No newline at end of file diff --git a/data/script_download_ZINC.sh b/data/script_download_ZINC.sh new file mode 100644 index 0000000..33ba6d2 --- /dev/null +++ b/data/script_download_ZINC.sh @@ -0,0 +1,21 @@ + + +# Command to download dataset: +# bash script_download_ZINC.sh + + +DIR=molecules/ +cd $DIR + + +FILE=ZINC.pkl +if test -f "$FILE"; then + echo -e "$FILE already downloaded." +else + echo -e "\ndownloading $FILE..." + curl https://data.dgl.ai/dataset/benchmarking-gnns/ZINC.pkl -o ZINC.pkl -J -L -k +fi + + + + diff --git a/docs/01_repo_installation.md b/docs/01_repo_installation.md new file mode 100644 index 0000000..2749acc --- /dev/null +++ b/docs/01_repo_installation.md @@ -0,0 +1,84 @@ +# Repo installation + + + +
+ +## 1. Setup Conda + +``` +# Conda installation + +# For Linux +curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh + +# For OSX +curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + +chmod +x ~/miniconda.sh +./miniconda.sh + +source ~/.bashrc # For Linux +source ~/.bash_profile # For OSX +``` + + +
+ +## 2. Setup Python environment for CPU + +``` +# Clone GitHub repo +conda install git +git clone https://github.com/vijaydwivedi75/gnn-lspe.git +cd gnn-lspe + +# Install python environment +conda env create -f environment_cpu.yml + +# Activate environment +conda activate gnn_lspe +``` + + + +
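+To verify the CPU environment, the pinned frameworks should import cleanly; a minimal check (run with `python` inside the activated environment, version strings follow environment_cpu.yml): + +``` +# Quick import check for the gnn_lspe environment +import dgl +import torch + +print(torch.__version__)   # expected: 1.6.0 +print(dgl.__version__)     # expected: 0.6.1 +``` + +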
+ +## 3. Setup Python environment for GPU + +The GPU environment (environment_gpu.yml) uses DGL 0.6.1 built against CUDA **10.2**, so CUDA 10.2 must be installed first. + +For Ubuntu **18.04**: + +``` +# Setup CUDA 10.2 on Ubuntu 18.04 +sudo apt-get --purge remove "*cublas*" "cuda*" +sudo apt --purge remove "nvidia*" +sudo apt autoremove +wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-repo-ubuntu1804_10.2.89-1_amd64.deb +sudo dpkg -i cuda-repo-ubuntu1804_10.2.89-1_amd64.deb +sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub +sudo apt update +sudo apt install -y cuda-10-2 +sudo reboot +cat /usr/local/cuda/version.txt # Check CUDA version is 10.2 + +# Clone GitHub repo +conda install git +git clone https://github.com/vijaydwivedi75/gnn-lspe.git +cd gnn-lspe + +# Install python environment +conda env create -f environment_gpu.yml + +# Activate environment +conda activate gnn_lspe +``` + + + + + +
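+Once the GPU environment is activated, the snippet below (run with `python`) confirms that PyTorch sees the CUDA 10.2 toolkit; the exact output depends on the local driver and hardware: + +``` +# Quick CUDA visibility check for the gnn_lspe environment +import torch + +print(torch.cuda.is_available())   # should be True on a correctly configured machine +print(torch.version.cuda)          # should report 10.2 +``` + +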


+ diff --git a/docs/02_download_datasets.md b/docs/02_download_datasets.md new file mode 100644 index 0000000..e120e4e --- /dev/null +++ b/docs/02_download_datasets.md @@ -0,0 +1,18 @@ +# Download datasets + +OGBG-MOL* datasets are automatically downloaded from OGB. For ZINC, use the following script. + +
+ +## 1. ZINC molecular dataset +The ZINC dataset download is 58.9MB. + +``` +# At the root of the project +cd data/ +bash script_download_ZINC.sh +``` +The download script used above is [script_download_ZINC.sh](../data/script_download_ZINC.sh). + + +
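+The OGBG-MOL* datasets need no manual download. As a rough sketch (class and method names are taken from `data/ogb_mol.py`; the import path and encoding dimension are illustrative assumptions), they can be loaded and prepared with random-walk positional encodings as follows: + +``` +# Run from the root of the project +from data.ogb_mol import OGBMOLDataset + +dataset = OGBMOLDataset('OGBG-MOLTOX21')               # first call triggers the automatic OGB download +dataset._init_positional_encodings(16, 'rand_walk')    # RWPE initialization; dimension chosen for illustration +print(len(dataset.train), len(dataset.val), len(dataset.test)) +``` + +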


diff --git a/docs/03_run_codes.md b/docs/03_run_codes.md new file mode 100644 index 0000000..ea22d44 --- /dev/null +++ b/docs/03_run_codes.md @@ -0,0 +1,87 @@ +# Reproducibility + + +
+ +## 1. Usage + + +
+ +### In terminal + +``` +# Run the main file (at the root of the project) +python main_ZINC_graph_regression.py --config 'configs/GatedGCN_ZINC_LSPE.json' # for CPU +python main_ZINC_graph_regression.py --gpu_id 0 --config 'configs/GatedGCN_ZINC_LSPE.json' # for GPU +``` +The training and network parameters for each experiment are stored in a JSON file in the [`configs/`](../configs) directory. + + + + +
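+Since each experiment is fully described by one of these JSON files, a quick smoke test can be prepared by copying a config and shrinking its budget; a minimal sketch (assumes the ZINC config follows the same `params` layout as the other configs in `configs/`, and the output file name is hypothetical): + +``` +# Create a reduced copy of a config for a quick test run +import json + +with open('configs/GatedGCN_ZINC_LSPE.json') as f: +    cfg = json.load(f) +cfg['params']['epochs'] = 5 +with open('configs/GatedGCN_ZINC_LSPE_smoke.json', 'w') as f: +    json.dump(cfg, f, indent=4) +``` + +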
+ +## 2. Output, checkpoints and visualizations + +Output results are located in the folder defined by the variable `out_dir` in the corresponding config file (e.g. the [`configs/GatedGCN_ZINC_LSPE.json`](../configs/GatedGCN_ZINC_LSPE.json) file). + +If `out_dir = 'out/GatedGCN_ZINC_LSPE_noLapEigLoss/'`, then + +#### 2.1 To see checkpoints and results +1. Go to `out/GatedGCN_ZINC_LSPE_noLapEigLoss/results` to view all result text files. +2. Directory `out/GatedGCN_ZINC_LSPE_noLapEigLoss/checkpoints` contains model checkpoints. + +#### 2.2 To see the training logs in TensorBoard on a local machine +1. Go to the logs directory, i.e. `out/GatedGCN_ZINC_LSPE_noLapEigLoss/logs/`. +2. Run the commands +``` +source activate gnn_lspe +tensorboard --logdir='./' --port 6006 +``` +3. Open `http://localhost:6006` in your browser. Note that the port information (here 6006, but it may change) appears on the terminal immediately after starting TensorBoard. + + +#### 2.3 To see the training logs in TensorBoard on a remote machine +1. Go to the logs directory, i.e. `out/GatedGCN_ZINC_LSPE_noLapEigLoss/logs/`. +2. Run the [script](../scripts/TensorBoard/script_tensorboard.sh) with `bash script_tensorboard.sh`. +3. On your local machine, run the command `ssh -N -f -L localhost:6006:localhost:6006 user@xx.xx.xx.xx`. +4. Open `http://localhost:6006` in your browser. Note that `user@xx.xx.xx.xx` corresponds to your user login and the IP of the remote machine. + + +
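+Saved checkpoints can also be inspected outside of training. The sketch below assumes they are ordinary PyTorch `state_dict` files stored under the `checkpoints/` directory; the file name is hypothetical and depends on the run: + +``` +# Peek at a saved checkpoint +import torch + +state = torch.load('out/GatedGCN_ZINC_LSPE_noLapEigLoss/checkpoints/epoch_100.pkl', map_location='cpu')  # hypothetical path +print(list(state.keys())[:5])  # first few parameter names +``` + +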
+ +## 3. Reproduce results + + +``` +# At the root of the project + +bash scripts/ZINC/script_ZINC_all.sh +bash scripts/OGBMOL/script_MOLTOX21_all.sh +bash scripts/OGBMOL/script_MOLPCBA_all.sh + +``` + +Scripts are [located](../scripts/) at the `scripts/` directory of the repository. + + + + + + + + + + + + + + + + + + + +
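+Each run writes its summary to a text file under its `results/` directory (see Section 2.1 above), so once the scripts finish, the summaries can be collected in one pass; a small sketch assuming the `out_dir` convention used in the configs: + +``` +# Collect all result summaries written by the runs +import glob + +for path in sorted(glob.glob('out/*/results/*.txt')): +    print('====', path) +    with open(path) as f: +        print(f.read()) +``` + +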


\ No newline at end of file diff --git a/docs/gnn-lspe.png b/docs/gnn-lspe.png new file mode 100644 index 0000000..c962007 Binary files /dev/null and b/docs/gnn-lspe.png differ diff --git a/environment_cpu.yml b/environment_cpu.yml new file mode 100644 index 0000000..9986997 --- /dev/null +++ b/environment_cpu.yml @@ -0,0 +1,44 @@ +name: gnn_lspe +channels: +- pytorch +- dglteam +- conda-forge +- anaconda +- defaults +dependencies: +- python=3.7.4 +- python-dateutil=2.8.0 +- pip=19.2.3 +- pytorch=1.6.0 +- torchvision==0.7.0 +- pillow==6.1 +- dgl=0.6.1 +- numpy=1.19.2 +- matplotlib=3.1.0 +- tensorboard=1.14.0 +- tensorboardx=1.8 +- future=0.18.2 +- absl-py +- networkx=2.3 +- scikit-learn=0.21.2 +- scipy=1.3.0 +- notebook=6.0.0 +- h5py=2.9.0 +- mkl=2019.4 +- ipykernel=5.1.2 +- ipython=7.7.0 +- ipython_genutils=0.2.0 +- ipywidgets=7.5.1 +- jupyter=1.0.0 +- jupyter_client=5.3.1 +- jupyter_console=6.0.0 +- jupyter_core=4.5.0 +- plotly=4.1.1 +- scikit-image=0.15.0 +- requests==2.22.0 +- tqdm==4.43.0 +- pip: + - tensorflow==2.1.0 + - tensorflow-estimator==2.1.0 + - tensorboard==2.1.1 + - ogb==1.3.1 \ No newline at end of file diff --git a/environment_gpu.yml b/environment_gpu.yml new file mode 100644 index 0000000..5192a66 --- /dev/null +++ b/environment_gpu.yml @@ -0,0 +1,47 @@ +name: gnn_lspe +channels: +- pytorch +- dglteam +- conda-forge +- fragcolor +- anaconda +- defaults +dependencies: +- cudatoolkit=10.2 +- cudnn=7.6.5 +- python=3.7.4 +- python-dateutil=2.8.0 +- pip=19.2.3 +- pytorch=1.6.0 +- torchvision==0.7.0 +- pillow==6.1 +- dgl-cuda10.2=0.6.1 +- numpy=1.19.2 +- matplotlib=3.1.0 +- tensorboard=1.14.0 +- tensorboardx=1.8 +- future=0.18.2 +- absl-py +- networkx=2.3 +- scikit-learn=0.21.2 +- scipy=1.3.0 +- notebook=6.0.0 +- h5py=2.9.0 +- mkl=2019.4 +- ipykernel=5.1.2 +- ipython=7.7.0 +- ipython_genutils=0.2.0 +- ipywidgets=7.5.1 +- jupyter=1.0.0 +- jupyter_client=5.3.1 +- jupyter_console=6.0.0 +- jupyter_core=4.5.0 +- plotly=4.1.1 +- scikit-image=0.15.0 +- requests==2.22.0 +- tqdm==4.43.0 +- pip: + - tensorflow-gpu==2.1.0 + - tensorflow-estimator==2.1.0 + - tensorboard==2.1.1 + - ogb==1.3.1 \ No newline at end of file diff --git a/layers/gatedgcn_layer.py b/layers/gatedgcn_layer.py new file mode 100644 index 0000000..258fea4 --- /dev/null +++ b/layers/gatedgcn_layer.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl.function as fn + +""" + GatedGCN: Residual Gated Graph ConvNets + An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent) + https://arxiv.org/pdf/1711.07553v2.pdf +""" + +class GatedGCNLayer(nn.Module): + """ + Param: [] + """ + def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False, graph_norm=True): + super().__init__() + self.in_channels = input_dim + self.out_channels = output_dim + self.dropout = dropout + self.batch_norm = batch_norm + self.graph_norm = graph_norm + self.residual = residual + + if input_dim != output_dim: + self.residual = False + + self.A = nn.Linear(input_dim, output_dim, bias=True) + self.B = nn.Linear(input_dim, output_dim, bias=True) + self.C = nn.Linear(input_dim, output_dim, bias=True) + self.D = nn.Linear(input_dim, output_dim, bias=True) + self.E = nn.Linear(input_dim, output_dim, bias=True) + self.bn_node_h = nn.BatchNorm1d(output_dim) + self.bn_node_e = nn.BatchNorm1d(output_dim) + + def forward(self, g, h, p=None, e=None, snorm_n=None): + + h_in = h # for residual connection + e_in = e # for residual connection + + g.ndata['h'] 
= h + g.ndata['Ah'] = self.A(h) + g.ndata['Bh'] = self.B(h) + g.ndata['Dh'] = self.D(h) + g.ndata['Eh'] = self.E(h) + g.edata['e'] = e + g.edata['Ce'] = self.C(e) + + g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh')) + g.edata['e'] = g.edata['DEh'] + g.edata['Ce'] + g.edata['sigma'] = torch.sigmoid(g.edata['e']) + g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h')) + g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma')) + g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6) + + h = g.ndata['h'] # result of graph convolution + e = g.edata['e'] # result of graph convolution + + # GN from benchmarking-gnns-v1 + if self.graph_norm: + h = h * snorm_n + + if self.batch_norm: + h = self.bn_node_h(h) # batch normalization + e = self.bn_node_e(e) # batch normalization + + h = F.relu(h) # non-linear activation + e = F.relu(e) # non-linear activation + + if self.residual: + h = h_in + h # residual connection + e = e_in + e # residual connection + + h = F.dropout(h, self.dropout, training=self.training) + e = F.dropout(e, self.dropout, training=self.training) + + return h, None, e + + def __repr__(self): + return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels) + + \ No newline at end of file diff --git a/layers/gatedgcn_lspe_layer.py b/layers/gatedgcn_lspe_layer.py new file mode 100644 index 0000000..78fcfe0 --- /dev/null +++ b/layers/gatedgcn_lspe_layer.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl.function as fn + +import dgl + +""" + GatedGCNLSPE: GatedGCN with LSPE +""" + +class GatedGCNLSPELayer(nn.Module): + """ + Param: [] + """ + def __init__(self, input_dim, output_dim, dropout, batch_norm, use_lapeig_loss=False, residual=False): + super().__init__() + self.in_channels = input_dim + self.out_channels = output_dim + self.dropout = dropout + self.batch_norm = batch_norm + self.residual = residual + self.use_lapeig_loss = use_lapeig_loss + + if input_dim != output_dim: + self.residual = False + + self.A1 = nn.Linear(input_dim*2, output_dim, bias=True) + self.A2 = nn.Linear(input_dim*2, output_dim, bias=True) + self.B1 = nn.Linear(input_dim, output_dim, bias=True) + self.B2 = nn.Linear(input_dim, output_dim, bias=True) + self.B3 = nn.Linear(input_dim, output_dim, bias=True) + self.C1 = nn.Linear(input_dim, output_dim, bias=True) + self.C2 = nn.Linear(input_dim, output_dim, bias=True) + + self.bn_node_h = nn.BatchNorm1d(output_dim) + self.bn_node_e = nn.BatchNorm1d(output_dim) + # self.bn_node_p = nn.BatchNorm1d(output_dim) + + def message_func_for_vij(self, edges): + hj = edges.src['h'] # h_j + pj = edges.src['p'] # p_j + vij = self.A2(torch.cat((hj, pj), -1)) + return {'v_ij': vij} + + def message_func_for_pj(self, edges): + pj = edges.src['p'] # p_j + return {'C2_pj': self.C2(pj)} + + def compute_normalized_eta(self, edges): + return {'eta_ij': edges.data['sigma_hat_eta'] / (edges.dst['sum_sigma_hat_eta'] + 1e-6)} # sigma_hat_eta_ij/ sum_j' sigma_hat_eta_ij' + + def forward(self, g, h, p, e, snorm_n): + + with g.local_scope(): + + # for residual connection + h_in = h + p_in = p + e_in = e + + # For the h's + g.ndata['h'] = h + g.ndata['A1_h'] = self.A1(torch.cat((h, p), -1)) + # self.A2 being used in message_func_for_vij() function + g.ndata['B1_h'] = self.B1(h) + g.ndata['B2_h'] = self.B2(h) + + # For the p's + g.ndata['p'] = p + g.ndata['C1_p'] = self.C1(p) + # self.C2 being used in message_func_for_pj() function + + 
# For the e's + g.edata['e'] = e + g.edata['B3_e'] = self.B3(e) + + #--------------------------------------------------------------------------------------# + # Calculation of h + g.apply_edges(fn.u_add_v('B1_h', 'B2_h', 'B1_B2_h')) + g.edata['hat_eta'] = g.edata['B1_B2_h'] + g.edata['B3_e'] + g.edata['sigma_hat_eta'] = torch.sigmoid(g.edata['hat_eta']) + g.update_all(fn.copy_e('sigma_hat_eta', 'm'), fn.sum('m', 'sum_sigma_hat_eta')) # sum_j' sigma_hat_eta_ij' + g.apply_edges(self.compute_normalized_eta) # sigma_hat_eta_ij/ sum_j' sigma_hat_eta_ij' + g.apply_edges(self.message_func_for_vij) # v_ij + g.edata['eta_mul_v'] = g.edata['eta_ij'] * g.edata['v_ij'] # eta_ij * v_ij + g.update_all(fn.copy_e('eta_mul_v', 'm'), fn.sum('m', 'sum_eta_v')) # sum_j eta_ij * v_ij + g.ndata['h'] = g.ndata['A1_h'] + g.ndata['sum_eta_v'] + + # Calculation of p + g.apply_edges(self.message_func_for_pj) # p_j + g.edata['eta_mul_p'] = g.edata['eta_ij'] * g.edata['C2_pj'] # eta_ij * C2_pj + g.update_all(fn.copy_e('eta_mul_p', 'm'), fn.sum('m', 'sum_eta_p')) # sum_j eta_ij * C2_pj + g.ndata['p'] = g.ndata['C1_p'] + g.ndata['sum_eta_p'] + + #--------------------------------------------------------------------------------------# + + # passing towards output + h = g.ndata['h'] + p = g.ndata['p'] + e = g.edata['hat_eta'] + + # GN from benchmarking-gnns-v1 + h = h * snorm_n + + # batch normalization + if self.batch_norm: + h = self.bn_node_h(h) + e = self.bn_node_e(e) + # No BN for p + + # non-linear activation + h = F.relu(h) + e = F.relu(e) + p = torch.tanh(p) + + # residual connection + if self.residual: + h = h_in + h + p = p_in + p + e = e_in + e + + # dropout + h = F.dropout(h, self.dropout, training=self.training) + p = F.dropout(p, self.dropout, training=self.training) + e = F.dropout(e, self.dropout, training=self.training) + + return h, p, e + + def __repr__(self): + return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels) \ No newline at end of file diff --git a/layers/graphit_gt_layer.py b/layers/graphit_gt_layer.py new file mode 100644 index 0000000..972e40e --- /dev/null +++ b/layers/graphit_gt_layer.py @@ -0,0 +1,273 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl +import dgl.function as fn +import numpy as np + +""" + GraphiT-GT + +""" + +""" + Util functions +""" +def src_dot_dst(src_field, dst_field, out_field): + def func(edges): + return {out_field: (edges.src[src_field] * edges.dst[dst_field])} + return func + + +def scaling(field, scale_constant): + def func(edges): + return {field: ((edges.data[field]) / scale_constant)} + return func + +# Improving implicit attention scores with explicit edge features, if available +def imp_exp_attn(implicit_attn, explicit_edge): + """ + implicit_attn: the output of K Q + explicit_edge: the explicit edge features + """ + def func(edges): + return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])} + return func + + +def exp(field): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))} + return func + +def adaptive_edge_PE(field, adaptive_weight): + def func(edges): + # initial shape was: adaptive_weight: [edges,1]; data: [edges, num_heads, 1] + # repeating adaptive_weight to have: [edges, num_heads, 1] + edges.data['tmp'] = edges.data[adaptive_weight].repeat(1, edges.data[field].shape[1]).unsqueeze(-1) + return {'score_soft': edges.data['tmp'] * 
edges.data[field]} + return func + + +""" + Single Attention Head +""" + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, adaptive_edge_PE, attention_for): + super().__init__() + + + self.out_dim = out_dim + self.num_heads = num_heads + self.gamma = gamma + self.full_graph=full_graph + self.attention_for = attention_for + self.adaptive_edge_PE = adaptive_edge_PE + + if self.attention_for == "h": + if use_bias: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + else: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + def propagate_attention(self, g): + + + if self.full_graph: + real_ids = torch.nonzero(g.edata['real']).squeeze() + fake_ids = torch.nonzero(g.edata['real']==0).squeeze() + + else: + real_ids = g.edges(form='eid') + + g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids) + + if self.full_graph: + g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids) + + + # scale scores by sqrt(d) + g.apply_edges(scaling('score', np.sqrt(self.out_dim))) + + # Use available edge features to modify the scores for edges + g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids) + + if self.full_graph: + g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids) + + g.apply_edges(exp('score')) + + # Adaptive weighting with k_RW_eij + # Only applicable to full graph, For NOW + if self.adaptive_edge_PE and self.full_graph: + g.apply_edges(adaptive_edge_PE('score_soft', 'k_RW')) + del g.edata['tmp'] + + # Send weighted values to target nodes + eids = g.edges() + g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV')) + g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z')) + + + def forward(self, g, h, e): + + Q_h = self.Q(h) + K_h = self.K(h) + E = self.E(e) + + if self.full_graph: + Q_2h = self.Q_2(h) + K_2h = self.K_2(h) + E_2 = self.E_2(e) + + V_h = self.V(h) + + + # Reshaping into [num_nodes, num_heads, feat_dim] to + # get projections for multi-head attention + g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim) + g.edata['E'] = E.view(-1, self.num_heads, self.out_dim) + + + if self.full_graph: + g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim) + g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim) + + g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim) + + self.propagate_attention(g) + + h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6)) + + del g.ndata['wV'] + del g.ndata['z'] + del g.ndata['Q_h'] + del g.ndata['K_h'] + del 
g.edata['E'] + + if self.full_graph: + del g.ndata['Q_2h'] + del g.ndata['K_2h'] + del g.edata['E_2'] + + return h_out + + +class GraphiT_GT_Layer(nn.Module): + """ + Param: + """ + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0, + layer_norm=False, batch_norm=True, residual=True, adaptive_edge_PE=False, use_bias=False): + super().__init__() + + self.in_channels = in_dim + self.out_channels = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads, + full_graph, use_bias, adaptive_edge_PE, attention_for="h") + + self.O_h = nn.Linear(out_dim, out_dim) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm1_h = nn.BatchNorm1d(out_dim) + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2) + self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim) + + + def forward(self, g, h, p, e, snorm_n): + h_in1 = h # for first residual connection + + # [START] For calculation of h ----------------------------------------------------------------- + + # multi-head attention out + h_attn_out = self.attention_h(g, h, e) + + #Concat multi-head outputs + h = h_attn_out.view(-1, self.out_channels) + + h = F.dropout(h, self.dropout, training=self.training) + + h = self.O_h(h) + + if self.residual: + h = h_in1 + h # residual connection + + # # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm1_h(h) + + if self.batch_norm: + h = self.batch_norm1_h(h) + + h_in2 = h # for second residual connection + + # FFN for h + h = self.FFN_h_layer1(h) + h = F.relu(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + h = h_in2 + h # residual connection + + # # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + # [END] For calculation of h ----------------------------------------------------------------- + + + return h, None + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual) \ No newline at end of file diff --git a/layers/graphit_gt_lspe_layer.py b/layers/graphit_gt_lspe_layer.py new file mode 100644 index 0000000..9c81a75 --- /dev/null +++ b/layers/graphit_gt_lspe_layer.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl +import dgl.function as fn +import numpy as np + +""" + GraphiT-GT-LSPE: GraphiT-GT with LSPE + +""" + +""" + Util functions +""" +def src_dot_dst(src_field, dst_field, out_field): + def func(edges): + return {out_field: (edges.src[src_field] * edges.dst[dst_field])} + return func + + +def scaling(field, scale_constant): + def func(edges): + return {field: ((edges.data[field]) / scale_constant)} + return func + +# Improving implicit attention scores with explicit edge features, if available +def imp_exp_attn(implicit_attn, explicit_edge): + """ + implicit_attn: the output of K Q + explicit_edge: the explicit edge features + """ + def func(edges): + return {implicit_attn: (edges.data[implicit_attn] 
* edges.data[explicit_edge])} + return func + + +def exp(field): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))} + return func + +def adaptive_edge_PE(field, adaptive_weight): + def func(edges): + # initial shape was: adaptive_weight: [edges,1]; data: [edges, num_heads, 1] + # repeating adaptive_weight to have: [edges, num_heads, 1] + edges.data['tmp'] = edges.data[adaptive_weight].repeat(1, edges.data[field].shape[1]).unsqueeze(-1) + return {'score_soft': edges.data['tmp'] * edges.data[field]} + return func + + +""" + Single Attention Head +""" + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, adaptive_edge_PE, attention_for): + super().__init__() + + + self.out_dim = out_dim + self.num_heads = num_heads + self.gamma = gamma + self.full_graph=full_graph + self.attention_for = attention_for + self.adaptive_edge_PE = adaptive_edge_PE + + if self.attention_for == "h": # attention module for h has input h = [h,p], so 2*in_dim for Q,K,V + if use_bias: + self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + + else: + self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + + elif self.attention_for == "p": # attention module for p + if use_bias: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + else: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + def propagate_attention(self, g): + + + if self.full_graph: + real_ids = torch.nonzero(g.edata['real']).squeeze() + fake_ids = torch.nonzero(g.edata['real']==0).squeeze() + + else: + real_ids = g.edges(form='eid') + + g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids) + + if self.full_graph: + g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids) + + + # scale scores by sqrt(d) + g.apply_edges(scaling('score', 
np.sqrt(self.out_dim))) + + # Use available edge features to modify the scores for edges + g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids) + + if self.full_graph: + g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids) + + g.apply_edges(exp('score')) + + # Adaptive weighting with k_RW_eij + # Only applicable to full graph, For NOW + if self.adaptive_edge_PE and self.full_graph: + g.apply_edges(adaptive_edge_PE('score_soft', 'k_RW')) + del g.edata['tmp'] + + # Send weighted values to target nodes + eids = g.edges() + g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV')) + g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z')) + + + def forward(self, g, h, p, e): + if self.attention_for == "h": + h = torch.cat((h, p), -1) + elif self.attention_for == "p": + h = p + + Q_h = self.Q(h) + K_h = self.K(h) + E = self.E(e) + + if self.full_graph: + Q_2h = self.Q_2(h) + K_2h = self.K_2(h) + E_2 = self.E_2(e) + + V_h = self.V(h) + + + # Reshaping into [num_nodes, num_heads, feat_dim] to + # get projections for multi-head attention + g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim) + g.edata['E'] = E.view(-1, self.num_heads, self.out_dim) + + + if self.full_graph: + g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim) + g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim) + + g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim) + + self.propagate_attention(g) + + h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6)) + + del g.ndata['wV'] + del g.ndata['z'] + del g.ndata['Q_h'] + del g.ndata['K_h'] + del g.edata['E'] + + if self.full_graph: + del g.ndata['Q_2h'] + del g.ndata['K_2h'] + del g.edata['E_2'] + + return h_out + + +class GraphiT_GT_LSPE_Layer(nn.Module): + """ + Param: + """ + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0, + layer_norm=False, batch_norm=True, residual=True, adaptive_edge_PE=False, use_bias=False): + super().__init__() + + self.in_channels = in_dim + self.out_channels = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads, + full_graph, use_bias, adaptive_edge_PE, attention_for="h") + self.attention_p = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads, + full_graph, use_bias, adaptive_edge_PE, attention_for="p") + + self.O_h = nn.Linear(out_dim, out_dim) + self.O_p = nn.Linear(out_dim, out_dim) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm1_h = nn.BatchNorm1d(out_dim) + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2) + self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim) + + + def forward(self, g, h, p, e, snorm_n): + h_in1 = h # for first residual connection + p_in1 = p # for first residual connection + + # [START] For calculation of h ----------------------------------------------------------------- + + # multi-head attention out + h_attn_out = self.attention_h(g, h, p, e) + + #Concat multi-head outputs + h = h_attn_out.view(-1, self.out_channels) + + h = 
F.dropout(h, self.dropout, training=self.training) + + h = self.O_h(h) + + if self.residual: + h = h_in1 + h # residual connection + + # # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm1_h(h) + + if self.batch_norm: + h = self.batch_norm1_h(h) + + h_in2 = h # for second residual connection + + # FFN for h + h = self.FFN_h_layer1(h) + h = F.relu(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + h = h_in2 + h # residual connection + + # # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + # [END] For calculation of h ----------------------------------------------------------------- + + + # [START] For calculation of p ----------------------------------------------------------------- + + # multi-head attention out + p_attn_out = self.attention_p(g, None, p, e) + + #Concat multi-head outputs + p = p_attn_out.view(-1, self.out_channels) + + p = F.dropout(p, self.dropout, training=self.training) + + p = self.O_p(p) + + p = torch.tanh(p) + + if self.residual: + p = p_in1 + p # residual connection + + # [END] For calculation of p ----------------------------------------------------------------- + + return h, p + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual) \ No newline at end of file diff --git a/layers/mlp_readout_layer.py b/layers/mlp_readout_layer.py new file mode 100644 index 0000000..20a4463 --- /dev/null +++ b/layers/mlp_readout_layer.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +""" + MLP Layer used after graph vector representation +""" + +class MLPReadout(nn.Module): + + def __init__(self, input_dim, output_dim, L=2): #L=nb_hidden_layers + super().__init__() + list_FC_layers = [ nn.Linear( input_dim//2**l , input_dim//2**(l+1) , bias=True ) for l in range(L) ] + list_FC_layers.append(nn.Linear( input_dim//2**L , output_dim , bias=True )) + self.FC_layers = nn.ModuleList(list_FC_layers) + self.L = L + + def forward(self, x): + y = x + for l in range(self.L): + y = self.FC_layers[l](y) + y = F.relu(y) + y = self.FC_layers[self.L](y) + return y + + + +class MLPReadout2(nn.Module): + + def __init__(self, input_dim, output_dim, dropout_2=0.0, L=2): # L=nb_hidden_layers + super().__init__() + list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)] + list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True)) + self.FC_layers = nn.ModuleList(list_FC_layers) + self.L = L + self.dropout_2 = dropout_2 + + def forward(self, x): + y = x + for l in range(self.L): + y = F.dropout(y, self.dropout_2, training=self.training) + y = self.FC_layers[l](y) + y = F.relu(y) + y = self.FC_layers[self.L](y) + return y \ No newline at end of file diff --git a/layers/pna_layer.py b/layers/pna_layer.py new file mode 100644 index 0000000..1140e39 --- /dev/null +++ b/layers/pna_layer.py @@ -0,0 +1,270 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl.function as fn +import dgl + +from .pna_utils import AGGREGATORS, SCALERS, MLP, FCLayer + +""" + PNA: Principal Neighbourhood Aggregation + Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic + https://arxiv.org/abs/2004.05718 +""" + + 
+class PNATower(nn.Module): + def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d, + pretrans_layers, posttrans_layers, edge_features, edge_dim): + super().__init__() + self.dropout = dropout + self.graph_norm = graph_norm + self.batch_norm = batch_norm + self.edge_features = edge_features + + self.batchnorm_h = nn.BatchNorm1d(out_dim) + self.aggregators = aggregators + self.scalers = scalers + self.pretrans_h = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, + out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none') + + self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim, + out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none') + + self.avg_d = avg_d + + def pretrans_edges(self, edges): + if self.edge_features: + z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1) + else: + z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1) + + return {'e_for_h': self.pretrans_h(z2_for_h)} + + # Message func for h + def message_func_for_h(self, edges): + return {'e_for_h': edges.data['e_for_h']} + + # Reduce func for h + def reduce_func_for_h(self, nodes): + h = nodes.mailbox['e_for_h'] + D = h.shape[-2] + h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1) + h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) + return {'h': h} + + def forward(self, g, h, e, snorm_n): + g.ndata['h'] = h + + if self.edge_features: # add the edges information only if edge_features = True + g.edata['ef'] = e + + # pretransformation + g.apply_edges(self.pretrans_edges) + + # aggregation for h + g.update_all(self.message_func_for_h, self.reduce_func_for_h) + h = torch.cat([h, g.ndata['h']], dim=1) + + # posttransformation + h = self.posttrans_h(h) + + # graph and batch normalization + if self.graph_norm: + h = h * snorm_n + + if self.batch_norm: + h = self.batchnorm_h(h) + + h = F.dropout(h, self.dropout, training=self.training) + + return h + + +class PNALayer(nn.Module): + + def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, graph_norm, batch_norm, towers=1, + pretrans_layers=1, posttrans_layers=1, divide_input=True, residual=False, edge_features=False, + edge_dim=0): + """ + :param in_dim: size of the input per node + :param out_dim: size of the output per node + :param aggregators: set of aggregation function identifiers + :param scalers: set of scaling functions identifiers + :param avg_d: average degree of nodes in the training set, used by scalers to normalize + :param dropout: dropout used + :param graph_norm: whether to use graph normalisation + :param batch_norm: whether to use batch normalisation + :param towers: number of towers to use + :param pretrans_layers: number of layers in the transformation before the aggregation + :param posttrans_layers: number of layers in the transformation after the aggregation + :param divide_input: whether the input features should be split between towers or not + :param residual: whether to add a residual connection + :param edge_features: whether to use the edge features + :param edge_dim: size of the edge features + """ + super().__init__() + assert ((not divide_input) or in_dim % towers == 0), "if divide_input is set the number of towers has to divide in_dim" + assert (out_dim % towers == 0), "the number of towers has to divide the out_dim" + assert avg_d is not None + + # retrieve 
the aggregators and scalers functions + aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()] + scalers = [SCALERS[scale] for scale in scalers.split()] + + self.divide_input = divide_input + self.input_tower = in_dim // towers if divide_input else in_dim + self.output_tower = out_dim // towers + self.in_dim = in_dim + self.out_dim = out_dim + self.edge_features = edge_features + self.residual = residual + if in_dim != out_dim: + self.residual = False + + # convolution + self.towers = nn.ModuleList() + for _ in range(towers): + self.towers.append(PNATower(in_dim=self.input_tower, out_dim=self.output_tower, aggregators=aggregators, + scalers=scalers, avg_d=avg_d, pretrans_layers=pretrans_layers, + posttrans_layers=posttrans_layers, batch_norm=batch_norm, dropout=dropout, + graph_norm=graph_norm, edge_features=edge_features, edge_dim=edge_dim)) + # mixing network + self.mixing_network_h = FCLayer(out_dim, out_dim, activation='LeakyReLU') + + def forward(self, g, h, p, e, snorm_n): + h_in = h # for residual connection + + if self.divide_input: + tower_outs = [tower(g, + h[:, n_tower * self.input_tower: (n_tower + 1) * self.input_tower], + e, + snorm_n) for n_tower, tower in enumerate(self.towers)] + h_tower_outs = tower_outs + h_cat = torch.cat(h_tower_outs, dim=1) + else: + tower_outs = [tower(g, h, p, e, snorm_n) for tower in self.towers] + h_tower_outs = tower_outs + h_cat = torch.cat(h_tower_outs, dim=1) + + h_out = self.mixing_network_h(h_cat) + + if self.residual: + h_out = h_in + h_out # residual connection + + return h_out, None + + def __repr__(self): + return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim) + + + + + +# This layer file below has no towers +# and is similar to DGNLayerComplex used for best PNA score on MOLPCBA +# implemented here https://github.com/Saro00/DGN/blob/master/models/dgl/dgn_layer.py + +class PNANoTowersLayer(nn.Module): + def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d, + pretrans_layers, posttrans_layers, residual, edge_features, edge_dim=0, use_lapeig_loss=False): + super().__init__() + self.dropout = dropout + self.graph_norm = graph_norm + self.batch_norm = batch_norm + self.edge_features = edge_features + self.in_dim = in_dim + self.out_dim = out_dim + self.residual = residual + if in_dim != out_dim: + self.residual = False + + self.batchnorm_h = nn.BatchNorm1d(out_dim) + + # retrieve the aggregators and scalers functions + aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()] + scalers = [SCALERS[scale] for scale in scalers.split()] + + self.aggregators = aggregators + self.scalers = scalers + + if self.edge_features: + self.pretrans_h = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, + out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none') + + self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim, + out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none') + else: + self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers)) * in_dim, hidden_size=out_dim, + out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none') + + self.avg_d = avg_d + + def pretrans_edges(self, edges): + if self.edge_features: + z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1) + else: + z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1) + + 
return {'e_for_h': self.pretrans_h(z2_for_h)} + + # Message func for h + def message_func_for_h(self, edges): + return {'e_for_h': edges.data['e_for_h']} + + # Reduce func for h + def reduce_func_for_h(self, nodes): + if self.edge_features: + h = nodes.mailbox['e_for_h'] + else: + h = nodes.mailbox['m_h'] + D = h.shape[-2] + h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1) + if len(self.scalers) > 1: + h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) + return {'h': h} + + def forward(self, g, h, p, e, snorm_n): + + h = F.dropout(h, self.dropout, training=self.training) + + h_in = h # for residual connection + + g.ndata['h'] = h + + if self.edge_features: # add the edges information only if edge_features = True + g.edata['ef'] = e + + if self.edge_features: + # pretransformation + g.apply_edges(self.pretrans_edges) + + if self.edge_features: + # aggregation for h + g.update_all(self.message_func_for_h, self.reduce_func_for_h) + h = torch.cat([h, g.ndata['h']], dim=1) + else: + # aggregation for h + g.update_all(fn.copy_u('h', 'm_h'), self.reduce_func_for_h) + h = g.ndata['h'] + + # posttransformation + h = self.posttrans_h(h) + + # graph and batch normalization + if self.graph_norm and self.edge_features: + h = h * snorm_n + + if self.batch_norm: + h = self.batchnorm_h(h) + + h = F.relu(h) + + if self.residual: + h = h_in + h # residual connection + + return h, None + + def __repr__(self): + return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim) + diff --git a/layers/pna_lspe_layer.py b/layers/pna_lspe_layer.py new file mode 100644 index 0000000..4fd35ba --- /dev/null +++ b/layers/pna_lspe_layer.py @@ -0,0 +1,350 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl.function as fn +import dgl + +from .pna_utils import AGGREGATORS, SCALERS, MLP, FCLayer + +""" + PNALSPE: PNA with LSPE + + PNA: Principal Neighbourhood Aggregation + Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic + https://arxiv.org/abs/2004.05718 +""" + + +class PNATower(nn.Module): + def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d, + pretrans_layers, posttrans_layers, edge_features, edge_dim): + super().__init__() + self.dropout = dropout + self.graph_norm = graph_norm + self.batch_norm = batch_norm + self.edge_features = edge_features + + self.batchnorm_h = nn.BatchNorm1d(out_dim) + self.aggregators = aggregators + self.scalers = scalers + self.pretrans_h = MLP(in_size=2 * 2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, + out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none') + self.pretrans_p = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, + out_size=in_dim, layers=pretrans_layers, mid_activation='tanh', last_activation='none') + + self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 2) * in_dim, hidden_size=out_dim, + out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none') + self.posttrans_p = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim, + out_size=out_dim, layers=posttrans_layers, mid_activation='tanh', last_activation='none') + + self.avg_d = avg_d + + def pretrans_edges(self, edges): + if self.edge_features: + z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1) + z2_for_p = torch.cat([edges.src['p'], 
edges.dst['p'], edges.data['ef']], dim=1) + else: + z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1) + z2_for_p = torch.cat([edges.src['p'], edges.dst['p']], dim=1) + + return {'e_for_h': self.pretrans_h(z2_for_h), 'e_for_p': self.pretrans_p(z2_for_p)} + + # Message func for h + def message_func_for_h(self, edges): + return {'e_for_h': edges.data['e_for_h']} + + # Reduce func for h + def reduce_func_for_h(self, nodes): + h = nodes.mailbox['e_for_h'] + D = h.shape[-2] + h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1) + h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) + return {'h': h} + + # Message func for p + def message_func_for_p(self, edges): + return {'e_for_p': edges.data['e_for_p']} + + # Reduce func for p + def reduce_func_for_p(self, nodes): + p = nodes.mailbox['e_for_p'] + D = p.shape[-2] + p = torch.cat([aggregate(p) for aggregate in self.aggregators], dim=1) + p = torch.cat([scale(p, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) + return {'p': p} + + def forward(self, g, h, p, e, snorm_n): + g.ndata['h'] = h + g.ndata['p'] = p + + if self.edge_features: # add the edges information only if edge_features = True + g.edata['ef'] = e + + # pretransformation + g.apply_edges(self.pretrans_edges) + + # aggregation for h + g.update_all(self.message_func_for_h, self.reduce_func_for_h) + h = torch.cat([h, g.ndata['h']], dim=1) + + # aggregation for p + g.update_all(self.message_func_for_p, self.reduce_func_for_p) + p = torch.cat([p, g.ndata['p']], dim=1) + + # posttransformation + h = self.posttrans_h(h) + p = self.posttrans_p(p) + + # graph and batch normalization + if self.graph_norm: + h = h * snorm_n + + if self.batch_norm: + h = self.batchnorm_h(h) + + h = F.dropout(h, self.dropout, training=self.training) + p = F.dropout(p, self.dropout, training=self.training) + + return h, p + + +class PNALSPELayer(nn.Module): + + def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, graph_norm, batch_norm, towers=1, + pretrans_layers=1, posttrans_layers=1, divide_input=True, residual=False, edge_features=False, + edge_dim=0): + """ + :param in_dim: size of the input per node + :param out_dim: size of the output per node + :param aggregators: set of aggregation function identifiers + :param scalers: set of scaling functions identifiers + :param avg_d: average degree of nodes in the training set, used by scalers to normalize + :param dropout: dropout used + :param graph_norm: whether to use graph normalisation + :param batch_norm: whether to use batch normalisation + :param towers: number of towers to use + :param pretrans_layers: number of layers in the transformation before the aggregation + :param posttrans_layers: number of layers in the transformation after the aggregation + :param divide_input: whether the input features should be split between towers or not + :param residual: whether to add a residual connection + :param edge_features: whether to use the edge features + :param edge_dim: size of the edge features + """ + super().__init__() + assert ((not divide_input) or in_dim % towers == 0), "if divide_input is set the number of towers has to divide in_dim" + assert (out_dim % towers == 0), "the number of towers has to divide the out_dim" + assert avg_d is not None + + # retrieve the aggregators and scalers functions + aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()] + scalers = [SCALERS[scale] for scale in scalers.split()] + + self.divide_input = divide_input + 
self.input_tower = in_dim // towers if divide_input else in_dim + self.output_tower = out_dim // towers + self.in_dim = in_dim + self.out_dim = out_dim + self.edge_features = edge_features + self.residual = residual + if in_dim != out_dim: + self.residual = False + + # convolution + self.towers = nn.ModuleList() + for _ in range(towers): + self.towers.append(PNATower(in_dim=self.input_tower, out_dim=self.output_tower, aggregators=aggregators, + scalers=scalers, avg_d=avg_d, pretrans_layers=pretrans_layers, + posttrans_layers=posttrans_layers, batch_norm=batch_norm, dropout=dropout, + graph_norm=graph_norm, edge_features=edge_features, edge_dim=edge_dim)) + # mixing network + self.mixing_network_h = FCLayer(out_dim, out_dim, activation='LeakyReLU') + self.mixing_network_p = FCLayer(out_dim, out_dim, activation='tanh') + + def forward(self, g, h, p, e, snorm_n): + h_in = h # for residual connection + p_in = p # for residual connection + + # Concating p to h, as in PEGNN + h = torch.cat((h, p), -1) + + if self.divide_input: + tower_outs = [tower(g, + h[:, n_tower * 2 * self.input_tower: (n_tower + 1) * 2 * self.input_tower], + p[:, n_tower * self.input_tower: (n_tower + 1) * self.input_tower], + e, + snorm_n) for n_tower, tower in enumerate(self.towers)] + h_tower_outs, p_tower_outs = map(list,zip(*tower_outs)) + h_cat = torch.cat(h_tower_outs, dim=1) + p_cat = torch.cat(p_tower_outs, dim=1) + else: + tower_outs = [tower(g, h, p, e, snorm_n) for tower in self.towers] + h_tower_outs, p_tower_outs = map(list,zip(*tower_outs)) + h_cat = torch.cat(h_tower_outs, dim=1) + p_cat = torch.cat(p_tower_outs, dim=1) + + h_out = self.mixing_network_h(h_cat) + p_out = self.mixing_network_p(p_cat) + + + if self.residual: + h_out = h_in + h_out # residual connection + p_out = p_in + p_out # residual connection + + return h_out, p_out + + def __repr__(self): + return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim) + + + + + +# This layer file below has no towers +# and is similar to DGNLayerComplex used for best PNA score on MOLPCBA +# implemented here https://github.com/Saro00/DGN/blob/master/models/dgl/dgn_layer.py + +class PNANoTowersLSPELayer(nn.Module): + def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d, + pretrans_layers, posttrans_layers, residual, edge_features, edge_dim=0, use_lapeig_loss=False): + super().__init__() + self.dropout = dropout + self.graph_norm = graph_norm + self.batch_norm = batch_norm + self.edge_features = edge_features + self.in_dim = in_dim + self.out_dim = out_dim + self.residual = residual + if in_dim != out_dim: + self.residual = False + + self.use_lapeig_loss = use_lapeig_loss + + self.batchnorm_h = nn.BatchNorm1d(out_dim) + + # retrieve the aggregators and scalers functions + aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()] + scalers = [SCALERS[scale] for scale in scalers.split()] + + self.aggregators = aggregators + self.scalers = scalers + + if self.edge_features: + self.pretrans_h = MLP(in_size=2 * 2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, + out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none') + self.pretrans_p = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, + out_size=in_dim, layers=pretrans_layers, mid_activation='tanh', last_activation='none') + + self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers) + 2) * in_dim, hidden_size=out_dim, + 
out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+            self.posttrans_p = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim,
+                                   out_size=out_dim, layers=posttrans_layers, mid_activation='tanh', last_activation='none')
+        else:
+            self.posttrans_h = MLP(in_size=(len(aggregators) * len(scalers)) * 2 * in_dim, hidden_size=out_dim,
+                                   out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none')
+            self.posttrans_p = MLP(in_size=(len(aggregators) * len(scalers)) * in_dim, hidden_size=out_dim,
+                                   out_size=out_dim, layers=posttrans_layers, mid_activation='tanh', last_activation='none')
+
+        self.avg_d = avg_d
+
+    def pretrans_edges(self, edges):
+        if self.edge_features:
+            z2_for_h = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1)
+            z2_for_p = torch.cat([edges.src['p'], edges.dst['p'], edges.data['ef']], dim=1)
+        else:
+            z2_for_h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
+            z2_for_p = torch.cat([edges.src['p'], edges.dst['p']], dim=1)
+
+        return {'e_for_h': self.pretrans_h(z2_for_h), 'e_for_p': self.pretrans_p(z2_for_p)}
+
+    # Message func for h
+    def message_func_for_h(self, edges):
+        return {'e_for_h': edges.data['e_for_h']}
+
+    # Reduce func for h
+    def reduce_func_for_h(self, nodes):
+        if self.edge_features:
+            h = nodes.mailbox['e_for_h']
+        else:
+            h = nodes.mailbox['m_h']
+        D = h.shape[-2]
+        h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
+        if len(self.scalers) > 1:
+            h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+        return {'h': h}
+
+    # Message func for p
+    def message_func_for_p(self, edges):
+        return {'e_for_p': edges.data['e_for_p']}
+
+    # Reduce func for p
+    def reduce_func_for_p(self, nodes):
+        if self.edge_features:
+            p = nodes.mailbox['e_for_p']
+        else:
+            p = nodes.mailbox['m_p']
+        D = p.shape[-2]
+        p = torch.cat([aggregate(p) for aggregate in self.aggregators], dim=1)
+        if len(self.scalers) > 1:
+            p = torch.cat([scale(p, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
+        return {'p': p}
+
+    def forward(self, g, h, p, e, snorm_n):
+
+        h = F.dropout(h, self.dropout, training=self.training)
+        p = F.dropout(p, self.dropout, training=self.training)
+
+        h_in = h # for residual connection
+        p_in = p # for residual connection
+
+        # Concatenating p to h, as in PEGNN
+        h = torch.cat((h, p), -1)
+
+        g.ndata['h'] = h
+        g.ndata['p'] = p
+
+        if self.edge_features: # add the edges information only if edge_features = True
+            g.edata['ef'] = e
+
+        if self.edge_features:
+            # pretransformation
+            g.apply_edges(self.pretrans_edges)
+
+        if self.edge_features:
+            # aggregation for h
+            g.update_all(self.message_func_for_h, self.reduce_func_for_h)
+            h = torch.cat([h, g.ndata['h']], dim=1)
+
+            # aggregation for p
+            g.update_all(self.message_func_for_p, self.reduce_func_for_p)
+            p = torch.cat([p, g.ndata['p']], dim=1)
+        else:
+            # aggregation for h
+            g.update_all(fn.copy_u('h', 'm_h'), self.reduce_func_for_h)
+            h = g.ndata['h']
+
+            # aggregation for p
+            g.update_all(fn.copy_u('p', 'm_p'), self.reduce_func_for_p)
+            p = g.ndata['p']
+
+        # posttransformation
+        h = self.posttrans_h(h)
+        p = self.posttrans_p(p)
+
+        # graph and batch normalization
+        if self.graph_norm and self.edge_features:
+            h = h * snorm_n
+
+        if self.batch_norm:
+            h = self.batchnorm_h(h)
+
+        h = F.relu(h)
+        p = torch.tanh(p)
+
+        if self.residual:
+            h = h_in + h # residual connection
+            p = p_in + p # residual connection
+
+        return h, p
+
+    def __repr__(self):
+        return
'{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim) \ No newline at end of file diff --git a/layers/pna_utils.py b/layers/pna_utils.py new file mode 100644 index 0000000..61c9fd5 --- /dev/null +++ b/layers/pna_utils.py @@ -0,0 +1,407 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +# PNA Aggregators ------------------------------------------------------------------------------ + +EPS = 1e-5 + + +def aggregate_mean(h): + return torch.mean(h, dim=1) + + +def aggregate_max(h): + return torch.max(h, dim=1)[0] + + +def aggregate_min(h): + return torch.min(h, dim=1)[0] + + +def aggregate_std(h): + return torch.sqrt(aggregate_var(h) + EPS) + + +def aggregate_var(h): + h_mean_squares = torch.mean(h * h, dim=-2) + h_mean = torch.mean(h, dim=-2) + var = torch.relu(h_mean_squares - h_mean * h_mean) + return var + + +def aggregate_moment(h, n=3): + # for each node (E[(X-E[X])^n])^{1/n} + # EPS is added to the absolute value of expectation before taking the nth root for stability + h_mean = torch.mean(h, dim=1, keepdim=True) + h_n = torch.mean(torch.pow(h - h_mean, n)) + rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + EPS, 1. / n) + return rooted_h_n + + +def aggregate_moment_3(h): + return aggregate_moment(h, n=3) + + +def aggregate_moment_4(h): + return aggregate_moment(h, n=4) + + +def aggregate_moment_5(h): + return aggregate_moment(h, n=5) + + +def aggregate_sum(h): + return torch.sum(h, dim=1) + + +AGGREGATORS = {'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min, + 'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3, 'moment4': aggregate_moment_4, + 'moment5': aggregate_moment_5} + + + + +# PNA Scalers --------------------------------------------------------------------------------- + + +# each scaler is a function that takes as input X (B x N x Din), adj (B x N x N) and +# avg_d (dictionary containing averages over training set) and returns X_scaled (B x N x Din) as output + +def scale_identity(h, D=None, avg_d=None): + return h + + +def scale_amplification(h, D, avg_d): + # log(D + 1) / d * h where d is the average of the ``log(D + 1)`` in the training set + return h * (np.log(D + 1) / avg_d["log"]) + + +def scale_attenuation(h, D, avg_d): + # (log(D + 1))^-1 / d * X where d is the average of the ``log(D + 1))^-1`` in the training set + return h * (avg_d["log"] / np.log(D + 1)) + + +SCALERS = {'identity': scale_identity, 'amplification': scale_amplification, 'attenuation': scale_attenuation} + + + + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +SUPPORTED_ACTIVATION_MAP = {'ReLU', 'Sigmoid', 'Tanh', 'ELU', 'SELU', 'GLU', 'LeakyReLU', 'Softplus', 'None'} + + +def get_activation(activation): + """ returns the activation function represented by the input string """ + if activation and callable(activation): + # activation is already a function + return activation + # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation + activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()] + assert len(activation) == 1 and isinstance(activation[0], str), 'Unhandled activation function' + activation = activation[0] + if activation.lower() == 'none': + return None + return vars(torch.nn.modules.activation)[activation]() + + +class Set2Set(torch.nn.Module): + r""" + Set2Set global pooling operator from the `"Order Matters: Sequence to sequence for sets" + `_ paper. 
This pooling layer performs the following operation + + .. math:: + \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) + + \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) + + \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i + + \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, + + where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice + the dimensionality as the input. + + Arguments + --------- + input_dim: int + Size of each input sample. + hidden_dim: int, optional + the dim of set representation which corresponds to the input dim of the LSTM in Set2Set. + This is typically the sum of the input dim and the lstm output dim. If not provided, it will be set to :obj:`input_dim*2` + steps: int, optional + Number of iterations :math:`T`. If not provided, the number of nodes will be used. + num_layers : int, optional + Number of recurrent layers (e.g., :obj:`num_layers=2` would mean stacking two LSTMs together) + (Default, value = 1) + """ + + def __init__(self, nin, nhid=None, steps=None, num_layers=1, activation=None, device='cpu'): + super(Set2Set, self).__init__() + self.steps = steps + self.nin = nin + self.nhid = nin * 2 if nhid is None else nhid + if self.nhid <= self.nin: + raise ValueError('Set2Set hidden_dim should be larger than input_dim') + # the hidden is a concatenation of weighted sum of embedding and LSTM output + self.lstm_output_dim = self.nhid - self.nin + self.num_layers = num_layers + self.lstm = nn.LSTM(self.nhid, self.nin, num_layers=num_layers, batch_first=True).to(device) + self.softmax = nn.Softmax(dim=1) + + def forward(self, x): + r""" + Applies the pooling on input tensor x + + Arguments + ---------- + x: torch.FloatTensor + Input tensor of size (B, N, D) + + Returns + ------- + x: `torch.FloatTensor` + Tensor resulting from the set2set pooling operation. + """ + + batch_size = x.shape[0] + n = self.steps or x.shape[1] + + h = (x.new_zeros((self.num_layers, batch_size, self.nin)), + x.new_zeros((self.num_layers, batch_size, self.nin))) + + q_star = x.new_zeros(batch_size, 1, self.nhid) + + for i in range(n): + # q: batch_size x 1 x input_dim + q, h = self.lstm(q_star, h) + # e: batch_size x n x 1 + e = torch.matmul(x, torch.transpose(q, 1, 2)) + a = self.softmax(e) + r = torch.sum(a * x, dim=1, keepdim=True) + q_star = torch.cat([q, r], dim=-1) + + return torch.squeeze(q_star, dim=1) + + +class FCLayer(nn.Module): + r""" + A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module. + The order in which transformations are applied is: + + #. Dense Layer + #. Activation + #. Dropout (if applicable) + #. Batch Normalization (if applicable) + + Arguments + ---------- + in_size: int + Input dimension of the layer (the torch.nn.Linear) + out_size: int + Output dimension of the layer. + dropout: float, optional + The ratio of units to dropout. No dropout by default. + (Default value = 0.) + activation: str or callable, optional + Activation function to use. + (Default value = relu) + b_norm: bool, optional + Whether to use batch normalization + (Default value = False) + bias: bool, optional + Whether to enable bias in for the linear layer. + (Default value = True) + init_fn: callable, optional + Initialization function to use for the weight of the layer. Default is + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with :math:`k=\frac{1}{ \text{in_size}}` + (Default value = None) + + Attributes + ---------- + dropout: int + The ratio of units to dropout. 
+ b_norm: int + Whether to use batch normalization + linear: torch.nn.Linear + The linear layer + activation: the torch.nn.Module + The activation layer + init_fn: function + Initialization function used for the weight of the layer + in_size: int + Input dimension of the linear layer + out_size: int + Output dimension of the linear layer + """ + + def __init__(self, in_size, out_size, activation='relu', dropout=0., b_norm=False, bias=True, init_fn=None, + device='cpu'): + super(FCLayer, self).__init__() + + self.__params = locals() + del self.__params['__class__'] + del self.__params['self'] + self.in_size = in_size + self.out_size = out_size + self.bias = bias + self.linear = nn.Linear(in_size, out_size, bias=bias).to(device) + self.dropout = None + self.b_norm = None + if dropout: + self.dropout = nn.Dropout(p=dropout) + if b_norm: + self.b_norm = nn.BatchNorm1d(out_size).to(device) + self.activation = get_activation(activation) + self.init_fn = nn.init.xavier_uniform_ + + self.reset_parameters() + + def reset_parameters(self, init_fn=None): + init_fn = init_fn or self.init_fn + if init_fn is not None: + init_fn(self.linear.weight, 1 / self.in_size) + if self.bias: + self.linear.bias.data.zero_() + + def forward(self, x): + h = self.linear(x) + if self.activation is not None: + h = self.activation(h) + if self.dropout is not None: + h = self.dropout(h) + if self.b_norm is not None: + if h.shape[1] != self.out_size: + h = self.b_norm(h.transpose(1, 2)).transpose(1, 2) + else: + h = self.b_norm(h) + return h + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_size) + ' -> ' \ + + str(self.out_size) + ')' + + +class MLP(nn.Module): + """ + Simple multi-layer perceptron, built of a series of FCLayers + """ + + def __init__(self, in_size, hidden_size, out_size, layers, mid_activation='relu', last_activation='none', + dropout=0., mid_b_norm=False, last_b_norm=False, device='cpu'): + super(MLP, self).__init__() + + self.in_size = in_size + self.hidden_size = hidden_size + self.out_size = out_size + + self.fully_connected = nn.ModuleList() + if layers <= 1: + self.fully_connected.append(FCLayer(in_size, out_size, activation=last_activation, b_norm=last_b_norm, + device=device, dropout=dropout)) + else: + self.fully_connected.append(FCLayer(in_size, hidden_size, activation=mid_activation, b_norm=mid_b_norm, + device=device, dropout=dropout)) + for _ in range(layers - 2): + self.fully_connected.append(FCLayer(hidden_size, hidden_size, activation=mid_activation, + b_norm=mid_b_norm, device=device, dropout=dropout)) + self.fully_connected.append(FCLayer(hidden_size, out_size, activation=last_activation, b_norm=last_b_norm, + device=device, dropout=dropout)) + + def forward(self, x): + for fc in self.fully_connected: + x = fc(x) + return x + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_size) + ' -> ' \ + + str(self.out_size) + ')' + + +class GRU(nn.Module): + """ + Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself + """ + + def __init__(self, input_size, hidden_size, device): + super(GRU, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size).to(device) + + def forward(self, x, y): + """ + :param x: shape: (B, N, Din) where Din <= input_size (difference is padded) + :param y: shape: (B, N, Dh) where Dh <= hidden_size (difference is padded) + :return: shape: (B, N, Dh) + """ + assert 
(x.shape[-1] <= self.input_size and y.shape[-1] <= self.hidden_size) + + (B, N, _) = x.shape + x = x.reshape(1, B * N, -1).contiguous() + y = y.reshape(1, B * N, -1).contiguous() + + # padding if necessary + if x.shape[-1] < self.input_size: + x = F.pad(input=x, pad=[0, self.input_size - x.shape[-1]], mode='constant', value=0) + if y.shape[-1] < self.hidden_size: + y = F.pad(input=y, pad=[0, self.hidden_size - y.shape[-1]], mode='constant', value=0) + + x = self.gru(x, y)[1] + x = x.reshape(B, N, -1) + return x + + +class S2SReadout(nn.Module): + """ + Performs a Set2Set aggregation of all the graph nodes' features followed by a series of fully connected layers + """ + + def __init__(self, in_size, hidden_size, out_size, fc_layers=3, device='cpu', final_activation='relu'): + super(S2SReadout, self).__init__() + + # set2set aggregation + self.set2set = Set2Set(in_size, device=device) + + # fully connected layers + self.mlp = MLP(in_size=2 * in_size, hidden_size=hidden_size, out_size=out_size, layers=fc_layers, + mid_activation="relu", last_activation=final_activation, mid_b_norm=True, last_b_norm=False, + device=device) + + def forward(self, x): + x = self.set2set(x) + return self.mlp(x) + + + +class GRU(nn.Module): + """ + Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself + """ + + def __init__(self, input_size, hidden_size, device): + super(GRU, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size).to(device) + + def forward(self, x, y): + """ + :param x: shape: (B, N, Din) where Din <= input_size (difference is padded) + :param y: shape: (B, N, Dh) where Dh <= hidden_size (difference is padded) + :return: shape: (B, N, Dh) + """ + assert (x.shape[-1] <= self.input_size and y.shape[-1] <= self.hidden_size) + x = x.unsqueeze(0) + y = y.unsqueeze(0) + x = self.gru(x, y)[1] + x = x.squeeze() + return x diff --git a/layers/san_gt_layer.py b/layers/san_gt_layer.py new file mode 100644 index 0000000..e71f1b5 --- /dev/null +++ b/layers/san_gt_layer.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl +import dgl.function as fn +import numpy as np + +""" + SAN-GT + +""" + +""" + Util functions +""" +def src_dot_dst(src_field, dst_field, out_field): + def func(edges): + return {out_field: (edges.src[src_field] * edges.dst[dst_field])} + return func + + +def scaling(field, scale_constant): + def func(edges): + return {field: ((edges.data[field]) / scale_constant)} + return func + +# Improving implicit attention scores with explicit edge features, if available +def imp_exp_attn(implicit_attn, explicit_edge): + """ + implicit_attn: the output of K Q + explicit_edge: the explicit edge features + """ + def func(edges): + return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])} + return func + +def exp_real(field, L): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)} + return func + + +def exp_fake(field, L): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': L*torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)} + return func + +def exp(field): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))} + return func + + +""" + 
Single Attention Head +""" + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, attention_for): + super().__init__() + + + self.out_dim = out_dim + self.num_heads = num_heads + self.full_graph=full_graph + self.attention_for = attention_for + self.gamma = gamma + + if self.attention_for == "h": + if use_bias: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + else: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + + def propagate_attention(self, g): + + + if self.full_graph: + real_ids = torch.nonzero(g.edata['real']).squeeze() + fake_ids = torch.nonzero(g.edata['real']==0).squeeze() + + else: + real_ids = g.edges(form='eid') + + g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids) + + if self.full_graph: + g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids) + + + # scale scores by sqrt(d) + g.apply_edges(scaling('score', np.sqrt(self.out_dim))) + + # Use available edge features to modify the scores for edges + g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids) + + if self.full_graph: + g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids) + + + if self.full_graph: + # softmax and scaling by gamma + L = torch.clamp(self.gamma, min=0.0, max=1.0) # Gamma \in [0,1] + g.apply_edges(exp_real('score', L), edges=real_ids) + g.apply_edges(exp_fake('score', L), edges=fake_ids) + + else: + g.apply_edges(exp('score'), edges=real_ids) + + # Send weighted values to target nodes + eids = g.edges() + g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV')) + g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z')) + + + def forward(self, g, h, e): + + Q_h = self.Q(h) + K_h = self.K(h) + E = self.E(e) + + if self.full_graph: + Q_2h = self.Q_2(h) + K_2h = self.K_2(h) + E_2 = self.E_2(e) + + V_h = self.V(h) + + + # Reshaping into [num_nodes, num_heads, feat_dim] to + # get projections for multi-head attention + g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim) + g.edata['E'] = E.view(-1, self.num_heads, self.out_dim) + + + if self.full_graph: + g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim) + g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim) + + g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim) + + self.propagate_attention(g) + + h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6)) + + return h_out + + +class SAN_GT_Layer(nn.Module): + """ + Param: + """ + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, 
dropout=0.0, + layer_norm=False, batch_norm=True, residual=True, use_bias=False): + super().__init__() + + self.in_channels = in_dim + self.out_channels = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads, + full_graph, use_bias, attention_for="h") + + self.O_h = nn.Linear(out_dim, out_dim) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm1_h = nn.BatchNorm1d(out_dim) + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2) + self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim) + + + def forward(self, g, h, p, e, snorm_n): + h_in1 = h # for first residual connection + + # [START] For calculation of h ----------------------------------------------------------------- + + # multi-head attention out + h_attn_out = self.attention_h(g, h, e) + + #Concat multi-head outputs + h = h_attn_out.view(-1, self.out_channels) + + h = F.dropout(h, self.dropout, training=self.training) + + h = self.O_h(h) + + if self.residual: + h = h_in1 + h # residual connection + + # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm1_h(h) + + if self.batch_norm: + h = self.batch_norm1_h(h) + + h_in2 = h # for second residual connection + + # FFN for h + h = self.FFN_h_layer1(h) + h = F.relu(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + h = h_in2 + h # residual connection + + # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + # [END] For calculation of h ----------------------------------------------------------------- + + return h, None + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual) \ No newline at end of file diff --git a/layers/san_gt_lspe_layer.py b/layers/san_gt_lspe_layer.py new file mode 100644 index 0000000..e18361c --- /dev/null +++ b/layers/san_gt_lspe_layer.py @@ -0,0 +1,318 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl +import dgl.function as fn +import numpy as np + +""" + SAN-GT-LSPE: SAN-GT with LSPE + +""" + +""" + Util functions +""" +def src_dot_dst(src_field, dst_field, out_field): + def func(edges): + return {out_field: (edges.src[src_field] * edges.dst[dst_field])} + return func + + +def scaling(field, scale_constant): + def func(edges): + return {field: ((edges.data[field]) / scale_constant)} + return func + +# Improving implicit attention scores with explicit edge features, if available +def imp_exp_attn(implicit_attn, explicit_edge): + """ + implicit_attn: the output of K Q + explicit_edge: the explicit edge features + """ + def func(edges): + return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])} + return func + +def exp_real(field, L): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)} + return func + + +def exp_fake(field, L): + def func(edges): + # clamp for softmax 
numerical stability + return {'score_soft': L*torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))/(L+1)} + return func + +def exp(field): + def func(edges): + # clamp for softmax numerical stability + return {'score_soft': torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))} + return func + + +""" + Single Attention Head +""" + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, use_bias, attention_for): + super().__init__() + + + self.out_dim = out_dim + self.num_heads = num_heads + self.full_graph=full_graph + self.attention_for = attention_for + self.gamma = gamma + + if self.attention_for == "h": # attention module for h has input h = [h,p], so 2*in_dim for Q,K,V + if use_bias: + self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=True) + + else: + self.Q = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.K = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.K_2 = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + self.V = nn.Linear(in_dim*2, out_dim * num_heads, bias=False) + + elif self.attention_for == "p": # attention module for p + if use_bias: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True) + + else: + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + if self.full_graph: + self.Q_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.K_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + self.E_2 = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False) + + def propagate_attention(self, g): + + if self.full_graph: + real_ids = torch.nonzero(g.edata['real']).squeeze() + fake_ids = torch.nonzero(g.edata['real']==0).squeeze() + + else: + real_ids = g.edges(form='eid') + + g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'), edges=real_ids) + + if self.full_graph: + g.apply_edges(src_dot_dst('K_2h', 'Q_2h', 'score'), edges=fake_ids) + + + # scale scores by sqrt(d) + g.apply_edges(scaling('score', np.sqrt(self.out_dim))) + + # Use available edge features to modify the scores for edges + g.apply_edges(imp_exp_attn('score', 'E'), edges=real_ids) + + if self.full_graph: + g.apply_edges(imp_exp_attn('score', 'E_2'), edges=fake_ids) + + + if self.full_graph: + # softmax and scaling by gamma + L = torch.clamp(self.gamma, min=0.0, max=1.0) # Gamma \in [0,1] + 
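+            # With L = clamp(gamma, 0, 1): real (original) edges receive attention weight
+            # proportional to 1/(1+gamma) and the added full-graph edges to gamma/(1+gamma),
+            # so gamma = 0 recovers sparse attention over the input graph only, while
+            # gamma = 1 weights original and added edges equally.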
g.apply_edges(exp_real('score', L), edges=real_ids) + g.apply_edges(exp_fake('score', L), edges=fake_ids) + + else: + g.apply_edges(exp('score'), edges=real_ids) + + # Send weighted values to target nodes + eids = g.edges() + g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score_soft', 'V_h'), fn.sum('V_h', 'wV')) + g.send_and_recv(eids, fn.copy_edge('score_soft', 'score_soft'), fn.sum('score_soft', 'z')) + + + def forward(self, g, h, p, e): + if self.attention_for == "h": + h = torch.cat((h, p), -1) + elif self.attention_for == "p": + h = p + + Q_h = self.Q(h) + K_h = self.K(h) + E = self.E(e) + + if self.full_graph: + Q_2h = self.Q_2(h) + K_2h = self.K_2(h) + E_2 = self.E_2(e) + + V_h = self.V(h) + + + # Reshaping into [num_nodes, num_heads, feat_dim] to + # get projections for multi-head attention + g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim) + g.edata['E'] = E.view(-1, self.num_heads, self.out_dim) + + + if self.full_graph: + g.ndata['Q_2h'] = Q_2h.view(-1, self.num_heads, self.out_dim) + g.ndata['K_2h'] = K_2h.view(-1, self.num_heads, self.out_dim) + g.edata['E_2'] = E_2.view(-1, self.num_heads, self.out_dim) + + g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim) + + self.propagate_attention(g) + + h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6)) + + return h_out + + +class SAN_GT_LSPE_Layer(nn.Module): + """ + Param: + """ + def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph, dropout=0.0, + layer_norm=False, batch_norm=True, residual=True, use_bias=False): + super().__init__() + + self.in_channels = in_dim + self.out_channels = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + self.attention_h = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads, + full_graph, use_bias, attention_for="h") + self.attention_p = MultiHeadAttentionLayer(gamma, in_dim, out_dim//num_heads, num_heads, + full_graph, use_bias, attention_for="p") + + self.O_h = nn.Linear(out_dim, out_dim) + self.O_p = nn.Linear(out_dim, out_dim) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm1_h = nn.BatchNorm1d(out_dim) + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2) + self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim) + + + def forward(self, g, h, p, e, snorm_n): + h_in1 = h # for first residual connection + p_in1 = p # for first residual connection + + # [START] For calculation of h ----------------------------------------------------------------- + + # multi-head attention out + h_attn_out = self.attention_h(g, h, p, e) + + #Concat multi-head outputs + h = h_attn_out.view(-1, self.out_channels) + + h = F.dropout(h, self.dropout, training=self.training) + + h = self.O_h(h) + + if self.residual: + h = h_in1 + h # residual connection + + # GN from benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm1_h(h) + + if self.batch_norm: + h = self.batch_norm1_h(h) + + h_in2 = h # for second residual connection + + # FFN for h + h = self.FFN_h_layer1(h) + h = F.relu(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + h = h_in2 + h # residual connection + + # GN from 
benchmarking-gnns-v1 + # h = h * snorm_n + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + # [END] For calculation of h ----------------------------------------------------------------- + + + # [START] For calculation of p ----------------------------------------------------------------- + + # multi-head attention out + p_attn_out = self.attention_p(g, None, p, e) + + #Concat multi-head outputs + p = p_attn_out.view(-1, self.out_channels) + + p = F.dropout(p, self.dropout, training=self.training) + + p = self.O_p(p) + + p = torch.tanh(p) + + if self.residual: + p = p_in1 + p # residual connection + + # [END] For calculation of p ----------------------------------------------------------------- + + return h, p + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual) \ No newline at end of file diff --git a/main_OGBMOL_graph_classification.py b/main_OGBMOL_graph_classification.py new file mode 100644 index 0000000..e12eedc --- /dev/null +++ b/main_OGBMOL_graph_classification.py @@ -0,0 +1,504 @@ + + + + + +""" + IMPORTING LIBS +""" +import dgl + +import numpy as np +import os +import socket +import time +import random +import glob +import argparse, json + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch.optim as optim +from torch.utils.data import DataLoader + +from tensorboardX import SummaryWriter +from tqdm import tqdm + +import matplotlib +import matplotlib.pyplot as plt + +class DotDict(dict): + def __init__(self, **kwds): + self.update(kwds) + self.__dict__ = self + + + + + + +""" + IMPORTING CUSTOM MODULES/METHODS +""" + +from nets.OGBMOL_graph_classification.load_net import gnn_model # import GNNs +from data.data import LoadData # import dataset + + + + +""" + GPU Setup +""" +def gpu_setup(use_gpu, gpu_id): + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + + if torch.cuda.is_available() and use_gpu: + print('cuda available with GPU:',torch.cuda.get_device_name(0)) + device = torch.device("cuda") + else: + print('cuda not available') + device = torch.device("cpu") + return device + + + + + + + +""" + VIEWING MODEL CONFIG AND PARAMS +""" +def view_model_param(MODEL_NAME, net_params): + model = gnn_model(MODEL_NAME, net_params) + total_param = 0 + print("MODEL DETAILS:\n") + # print(model) + for param in model.parameters(): + # print(param.data.size()) + total_param += np.prod(list(param.data.size())) + print('MODEL/Total parameters:', MODEL_NAME, total_param) + return total_param + + +def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs): + t0 = time.time() + per_epoch_time = [] + + DATASET_NAME = dataset.name + + if net_params['pe_init'] == 'lap_pe': + tt = time.time() + print("[!] -LapPE: Initializing graph positional encoding with Laplacian PE.") + dataset._add_lap_positional_encodings(net_params['pos_enc_dim']) + print("[!] Time taken: ", time.time()-tt) + elif net_params['pe_init'] == 'rand_walk': + tt = time.time() + print("[!] -LSPE: Initializing graph positional encoding with rand walk features.") + dataset._init_positional_encodings(net_params['pos_enc_dim'], net_params['pe_init']) + print("[!] Time taken: ", time.time()-tt) + + tt = time.time() + print("[!] 
-LSPE (For viz later): Adding lapeigvecs to key 'eigvec' for every graph.") + dataset._add_eig_vecs(net_params['pos_enc_dim']) + print("[!] Time taken: ", time.time()-tt) + + if MODEL_NAME in ['SAN', 'GraphiT']: + if net_params['full_graph']: + st = time.time() + print("[!] Adding full graph connectivity..") + dataset._make_full_graph() if MODEL_NAME == 'SAN' else dataset._make_full_graph((net_params['p_steps'], net_params['gamma'])) + print('Time taken to add full graph connectivity: ',time.time()-st) + + trainset, valset, testset = dataset.train, dataset.val, dataset.test + + evaluator = dataset.evaluator + + root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir = dirs + device = net_params['device'] + + # Write the network and optimization hyper-parameters in folder config/ + with open(write_config_file + '.txt', 'w') as f: + f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""" .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param'])) + + log_dir = os.path.join(root_log_dir, "RUN_" + str(0)) + writer = SummaryWriter(log_dir=log_dir) + + # setting seeds + random.seed(params['seed']) + np.random.seed(params['seed']) + torch.manual_seed(params['seed']) + if device.type == 'cuda': + torch.cuda.manual_seed(params['seed']) + torch.cuda.manual_seed_all(params['seed']) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + print("Training Graphs: ", len(trainset)) + print("Validation Graphs: ", len(valset)) + print("Test Graphs: ", len(testset)) + + model = gnn_model(MODEL_NAME, net_params) + model = model.to(device) + + optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay']) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', + factor=params['lr_reduce_factor'], + patience=params['lr_schedule_patience'], + verbose=True) + + epoch_train_losses, epoch_val_losses = [], [] + epoch_train_accs, epoch_val_accs, epoch_test_accs = [], [], [] + + # import train functions for all GNNs + from train.train_OGBMOL_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network + + train_loader = DataLoader(trainset, num_workers=4, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate, pin_memory=True) + val_loader = DataLoader(valset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate, pin_memory=True) + test_loader = DataLoader(testset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate, pin_memory=True) + + # At any point you can hit Ctrl + C to break out of training early. 
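+    # The loop below trains one epoch at a time, evaluates on the val and test splits,
+    # logs losses/metrics to TensorBoard, prunes all but the most recent epoch checkpoints,
+    # steps ReduceLROnPlateau on the validation loss, and stops early once the learning rate
+    # falls below min_lr or the max_time budget (in hours) is exceeded.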
+ try: + with tqdm(range(params['epochs'])) as t: + for epoch in t: + + t.set_description('Epoch %d' % epoch) + + start = time.time() + + epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch, evaluator) + + epoch_val_loss, epoch_val_acc, __ = evaluate_network(model, device, val_loader, epoch, evaluator) + _, epoch_test_acc, __ = evaluate_network(model, device, test_loader, epoch, evaluator) + del __ + + epoch_train_losses.append(epoch_train_loss) + epoch_val_losses.append(epoch_val_loss) + epoch_train_accs.append(epoch_train_acc) + epoch_val_accs.append(epoch_val_acc) + epoch_test_accs.append(epoch_test_acc) + + writer.add_scalar('train/_loss', epoch_train_loss, epoch) + writer.add_scalar('val/_loss', epoch_val_loss, epoch) + writer.add_scalar('train/_avg_prec', epoch_train_acc, epoch) + writer.add_scalar('val/_avg_prec', epoch_val_acc, epoch) + writer.add_scalar('test/_avg_prec', epoch_test_acc, epoch) + writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) + + if dataset.name == "ogbg-moltox21": + t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'], + train_loss=epoch_train_loss, val_loss=epoch_val_loss, + train_AUC=epoch_train_acc, val_AUC=epoch_val_acc, + test_AUC=epoch_test_acc) + elif dataset.name == "ogbg-molpcba": + t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'], + train_loss=epoch_train_loss, val_loss=epoch_val_loss, + train_AP=epoch_train_acc, val_AP=epoch_val_acc, + test_AP=epoch_test_acc) + + per_epoch_time.append(time.time()-start) + + # Saving checkpoint + ckpt_dir = os.path.join(root_ckpt_dir, "RUN_") + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch))) + + files = glob.glob(ckpt_dir + '/*.pkl') + for file in files: + epoch_nb = file.split('_')[-1] + epoch_nb = int(epoch_nb.split('.')[0]) + if epoch_nb < epoch-1: + os.remove(file) + + scheduler.step(epoch_val_loss) + + if optimizer.param_groups[0]['lr'] < params['min_lr']: + print("\n!! 
LR EQUAL TO MIN LR SET.") + break + + # Stop training after params['max_time'] hours + if time.time()-t0 > params['max_time']*3600: + print('-' * 89) + print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time'])) + break + + except KeyboardInterrupt: + print('-' * 89) + print('Exiting from training early because of KeyboardInterrupt') + + # ___, __, g_outs_train = evaluate_network(model, device, train_loader, epoch, evaluator) + ___, __, g_outs_test = evaluate_network(model, device, test_loader, epoch, evaluator) + del ___ + del __ + + # OGB: Test scores at best val epoch + epoch_best = epoch_val_accs.index(max(epoch_val_accs)) + + test_acc = epoch_test_accs[epoch_best] + train_acc = epoch_train_accs[epoch_best] + val_acc = epoch_val_accs[epoch_best] + + if dataset.name == "ogbg-moltox21": + print("Test AUC: {:.4f}".format(test_acc)) + print("Train AUC: {:.4f}".format(train_acc)) + print("Val AUC: {:.4f}".format(val_acc)) + elif dataset.name == "ogbg-molpcba": + print("Test Avg Precision: {:.4f}".format(test_acc)) + print("Train Avg Precision: {:.4f}".format(train_acc)) + print("Convergence Time (Epochs): {:.4f}".format(epoch)) + print("TOTAL TIME TAKEN: {:.4f}s".format(time.time()-t0)) + print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time))) + + if net_params['pe_init'] == 'rand_walk' and g_outs_test is not None: + # Visualize actual and predicted/learned eigenvecs + from utils.plot_util import plot_graph_eigvec + if not os.path.exists(viz_dir): + os.makedirs(viz_dir) + + sample_graph_ids = [153,103,123] + + for f_idx, graph_id in enumerate(sample_graph_ids): + + # Test graphs + g_dgl = g_outs_test[graph_id] + + f = plt.figure(f_idx, figsize=(12,6)) + + plt1 = f.add_subplot(121) + plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True) + + plt2 = f.add_subplot(122) + plot_graph_eigvec(plt2, graph_id, g_dgl, feature_key='p', predicted_eigvecs=True) + + f.savefig(viz_dir+'/test'+str(graph_id)+'.jpg') + + writer.close() + + """ + Write the results in out_dir/results folder + """ + if dataset.name == "ogbg-moltox21": + with open(write_file_name + '.txt', 'w') as f: + f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n + FINAL RESULTS\nTEST AUC: {:.4f}\nTRAIN AUC: {:.4f}\nVAL AUC: {:.4f}\n\n + Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\ + .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], + test_acc, train_acc, val_acc, epoch, (time.time()-t0)/3600, np.mean(per_epoch_time))) + elif dataset.name == "ogbg-molpcba": + with open(write_file_name + '.txt', 'w') as f: + f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n + FINAL RESULTS\nTEST AVG PRECISION: {:.4f}\nTRAIN AVG PRECISION: {:.4f}\nVAL AVG PRECISION: {:.4f}\n\n + Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\ + .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], + test_acc, train_acc, val_acc, epoch, (time.time()-t0)/3600, np.mean(per_epoch_time))) + + + + +def main(): + """ + USER CONTROLS + """ + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details") + parser.add_argument('--gpu_id', help="Please give a value for gpu id") + parser.add_argument('--model', help="Please give a value for 
model name") + parser.add_argument('--dataset', help="Please give a value for dataset name") + parser.add_argument('--out_dir', help="Please give a value for out_dir") + parser.add_argument('--seed', help="Please give a value for seed") + parser.add_argument('--epochs', help="Please give a value for epochs") + parser.add_argument('--batch_size', help="Please give a value for batch_size") + parser.add_argument('--init_lr', help="Please give a value for init_lr") + parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor") + parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience") + parser.add_argument('--min_lr', help="Please give a value for min_lr") + parser.add_argument('--weight_decay', help="Please give a value for weight_decay") + parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval") + parser.add_argument('--L', help="Please give a value for L") + parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim") + parser.add_argument('--out_dim', help="Please give a value for out_dim") + parser.add_argument('--residual', help="Please give a value for residual") + parser.add_argument('--edge_feat', help="Please give a value for edge_feat") + parser.add_argument('--readout', help="Please give a value for readout") + parser.add_argument('--kernel', help="Please give a value for kernel") + parser.add_argument('--n_heads', help="Please give a value for n_heads") + parser.add_argument('--gated', help="Please give a value for gated") + parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout") + parser.add_argument('--dropout', help="Please give a value for dropout") + parser.add_argument('--layer_norm', help="Please give a value for layer_norm") + parser.add_argument('--batch_norm', help="Please give a value for batch_norm") + parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator") + parser.add_argument('--data_mode', help="Please give a value for data_mode") + parser.add_argument('--num_pool', help="Please give a value for num_pool") + parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block") + parser.add_argument('--embedding_dim', help="Please give a value for embedding_dim") + parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio") + parser.add_argument('--linkpred', help="Please give a value for linkpred") + parser.add_argument('--cat', help="Please give a value for cat") + parser.add_argument('--self_loop', help="Please give a value for self_loop") + parser.add_argument('--max_time', help="Please give a value for max_time") + parser.add_argument('--pos_enc_dim', help="Please give a value for pos_enc_dim") + parser.add_argument('--alpha_loss', help="Please give a value for alpha_loss") + parser.add_argument('--lambda_loss', help="Please give a value for lambda_loss") + parser.add_argument('--pe_init', help="Please give a value for pe_init") + args = parser.parse_args() + with open(args.config) as f: + config = json.load(f) + # device + if args.gpu_id is not None: + config['gpu']['id'] = int(args.gpu_id) + config['gpu']['use'] = True + device = gpu_setup(config['gpu']['use'], config['gpu']['id']) + # model, dataset, out_dir + if args.model is not None: + MODEL_NAME = args.model + else: + MODEL_NAME = config['model'] + if args.dataset is not None: + DATASET_NAME = args.dataset + else: + DATASET_NAME = config['dataset'] + dataset = 
LoadData(DATASET_NAME) + if args.out_dir is not None: + out_dir = args.out_dir + else: + out_dir = config['out_dir'] + # parameters + params = config['params'] + if args.seed is not None: + params['seed'] = int(args.seed) + if args.epochs is not None: + params['epochs'] = int(args.epochs) + if args.batch_size is not None: + params['batch_size'] = int(args.batch_size) + if args.init_lr is not None: + params['init_lr'] = float(args.init_lr) + if args.lr_reduce_factor is not None: + params['lr_reduce_factor'] = float(args.lr_reduce_factor) + if args.lr_schedule_patience is not None: + params['lr_schedule_patience'] = int(args.lr_schedule_patience) + if args.min_lr is not None: + params['min_lr'] = float(args.min_lr) + if args.weight_decay is not None: + params['weight_decay'] = float(args.weight_decay) + if args.print_epoch_interval is not None: + params['print_epoch_interval'] = int(args.print_epoch_interval) + if args.max_time is not None: + params['max_time'] = float(args.max_time) + # network parameters + net_params = config['net_params'] + net_params['device'] = device + net_params['gpu_id'] = config['gpu']['id'] + net_params['batch_size'] = params['batch_size'] + if args.L is not None: + net_params['L'] = int(args.L) + if args.hidden_dim is not None: + net_params['hidden_dim'] = int(args.hidden_dim) + if args.out_dim is not None: + net_params['out_dim'] = int(args.out_dim) + if args.residual is not None: + net_params['residual'] = True if args.residual=='True' else False + if args.edge_feat is not None: + net_params['edge_feat'] = True if args.edge_feat=='True' else False + if args.readout is not None: + net_params['readout'] = args.readout + if args.kernel is not None: + net_params['kernel'] = int(args.kernel) + if args.n_heads is not None: + net_params['n_heads'] = int(args.n_heads) + if args.gated is not None: + net_params['gated'] = True if args.gated=='True' else False + if args.in_feat_dropout is not None: + net_params['in_feat_dropout'] = float(args.in_feat_dropout) + if args.dropout is not None: + net_params['dropout'] = float(args.dropout) + if args.layer_norm is not None: + net_params['layer_norm'] = True if args.layer_norm=='True' else False + if args.batch_norm is not None: + net_params['batch_norm'] = True if args.batch_norm=='True' else False + if args.sage_aggregator is not None: + net_params['sage_aggregator'] = args.sage_aggregator + if args.data_mode is not None: + net_params['data_mode'] = args.data_mode + if args.num_pool is not None: + net_params['num_pool'] = int(args.num_pool) + if args.gnn_per_block is not None: + net_params['gnn_per_block'] = int(args.gnn_per_block) + if args.embedding_dim is not None: + net_params['embedding_dim'] = int(args.embedding_dim) + if args.pool_ratio is not None: + net_params['pool_ratio'] = float(args.pool_ratio) + if args.linkpred is not None: + net_params['linkpred'] = True if args.linkpred=='True' else False + if args.cat is not None: + net_params['cat'] = True if args.cat=='True' else False + if args.self_loop is not None: + net_params['self_loop'] = True if args.self_loop=='True' else False + if args.pos_enc_dim is not None: + net_params['pos_enc_dim'] = int(args.pos_enc_dim) + if args.alpha_loss is not None: + net_params['alpha_loss'] = float(args.alpha_loss) + if args.lambda_loss is not None: + net_params['lambda_loss'] = float(args.lambda_loss) + if args.pe_init is not None: + net_params['pe_init'] = args.pe_init + + + + # OGBMOL* + num_classes = dataset.dataset.num_tasks # provided by OGB dataset class + 
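+    # The number of OGB prediction tasks is reused as n_classes for the readout; for PNA,
+    # node-degree statistics over the training graphs (mean degree and means of the exp/log
+    # transforms) are precomputed below and passed as avg_d to normalise the PNA scalers.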
net_params['n_classes'] = num_classes + + if MODEL_NAME == 'PNA': + D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g, label in + dataset.train]) + net_params['avg_d'] = dict(lin=torch.mean(D), + exp=torch.mean(torch.exp(torch.div(1, D)) - 1), + log=torch.mean(torch.log(D + 1))) + + root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + viz_dir = out_dir + 'viz/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir + + if not os.path.exists(out_dir + 'results'): + os.makedirs(out_dir + 'results') + + if not os.path.exists(out_dir + 'configs'): + os.makedirs(out_dir + 'configs') + + net_params['total_param'] = view_model_param(MODEL_NAME, net_params) + train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs) + + + + + + + + +main() + + + + + + + + + + + + + + + diff --git a/main_ZINC_graph_regression.py b/main_ZINC_graph_regression.py new file mode 100644 index 0000000..91d028e --- /dev/null +++ b/main_ZINC_graph_regression.py @@ -0,0 +1,453 @@ + + + + + +""" + IMPORTING LIBS +""" +import dgl + +import numpy as np +import os +import socket +import time +import random +import glob +import argparse, json +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch.optim as optim +from torch.utils.data import DataLoader + +from tensorboardX import SummaryWriter +from tqdm import tqdm + +import matplotlib +import matplotlib.pyplot as plt + + +class DotDict(dict): + def __init__(self, **kwds): + self.update(kwds) + self.__dict__ = self + + + + + + +""" + IMPORTING CUSTOM MODULES/METHODS +""" +from nets.ZINC_graph_regression.load_net import gnn_model # import all GNNS +from data.data import LoadData # import dataset + + + + +""" + GPU Setup +""" +def gpu_setup(use_gpu, gpu_id): + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + + if torch.cuda.is_available() and use_gpu: + print('cuda available with GPU:',torch.cuda.get_device_name(0)) + device = torch.device("cuda") + else: + print('cuda not available') + device = torch.device("cpu") + return device + + + + + + + + +""" + VIEWING MODEL CONFIG AND PARAMS +""" +def view_model_param(MODEL_NAME, net_params): + model = gnn_model(MODEL_NAME, net_params) + total_param = 0 + print("MODEL DETAILS:\n") + #print(model) + for param in model.parameters(): + # print(param.data.size()) + total_param += np.prod(list(param.data.size())) + print('MODEL/Total parameters:', MODEL_NAME, total_param) + return total_param + + +""" + TRAINING CODE +""" + +def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs): + t0 = time.time() + per_epoch_time = [] + + DATASET_NAME = dataset.name + + if net_params['pe_init'] == 'lap_pe': + tt = time.time() + print("[!] 
-LapPE: Initializing graph positional encoding with Laplacian PE.") + dataset._add_lap_positional_encodings(net_params['pos_enc_dim']) + print("[!] Time taken: ", time.time()-tt) + elif net_params['pe_init'] == 'rand_walk': + tt = time.time() + print("[!] -LSPE: Initializing graph positional encoding with rand walk features.") + dataset._init_positional_encodings(net_params['pos_enc_dim'], net_params['pe_init']) + print("[!] Time taken: ", time.time()-tt) + + tt = time.time() + print("[!] -LSPE (For viz later): Adding lapeigvecs to key 'eigvec' for every graph.") + dataset._add_eig_vecs(net_params['pos_enc_dim']) + print("[!] Time taken: ", time.time()-tt) + + if MODEL_NAME in ['SAN', 'GraphiT']: + if net_params['full_graph']: + st = time.time() + print("[!] Adding full graph connectivity..") + dataset._make_full_graph() if MODEL_NAME == 'SAN' else dataset._make_full_graph((net_params['p_steps'], net_params['gamma'])) + print('Time taken to add full graph connectivity: ',time.time()-st) + + trainset, valset, testset = dataset.train, dataset.val, dataset.test + + root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir = dirs + device = net_params['device'] + + # Write the network and optimization hyper-parameters in folder config/ + with open(write_config_file + '.txt', 'w') as f: + f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""" .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param'])) + + log_dir = os.path.join(root_log_dir, "RUN_" + str(0)) + writer = SummaryWriter(log_dir=log_dir) + + # setting seeds + random.seed(params['seed']) + np.random.seed(params['seed']) + torch.manual_seed(params['seed']) + if device.type == 'cuda': + torch.cuda.manual_seed(params['seed']) + torch.cuda.manual_seed_all(params['seed']) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + print("Training Graphs: ", len(trainset)) + print("Validation Graphs: ", len(valset)) + print("Test Graphs: ", len(testset)) + + model = gnn_model(MODEL_NAME, net_params) + model = model.to(device) + + optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay']) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', + factor=params['lr_reduce_factor'], + patience=params['lr_schedule_patience'], + verbose=True) + + epoch_train_losses, epoch_val_losses = [], [] + epoch_train_MAEs, epoch_val_MAEs = [], [] + + # import train functions for all GNNs + from train.train_ZINC_graph_regression import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network + + train_loader = DataLoader(trainset, num_workers=4, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate) + val_loader = DataLoader(valset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate) + test_loader = DataLoader(testset, num_workers=4, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate) + + # At any point you can hit Ctrl + C to break out of training early. 
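+    # Note on the 'rand_walk' initialisation used above (a sketch based on the paper's
+    # description, not on data.py itself): _init_positional_encodings presumably assigns each
+    # node its probabilities of returning to itself after 1..pos_enc_dim random-walk steps,
+    # i.e. the diagonal entries of RW^k with RW = A D^{-1}; these initial p vectors are then
+    # updated by the LSPE layers during training.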
+ try: + with tqdm(range(params['epochs'])) as t: + for epoch in t: + + t.set_description('Epoch %d' % epoch) + + start = time.time() + + epoch_train_loss, epoch_train_mae, optimizer = train_epoch(model, optimizer, device, train_loader, epoch) + + epoch_val_loss, epoch_val_mae, __ = evaluate_network(model, device, val_loader, epoch) + epoch_test_loss, epoch_test_mae, __ = evaluate_network(model, device, test_loader, epoch) + del __ + + epoch_train_losses.append(epoch_train_loss) + epoch_val_losses.append(epoch_val_loss) + epoch_train_MAEs.append(epoch_train_mae) + epoch_val_MAEs.append(epoch_val_mae) + + writer.add_scalar('train/_loss', epoch_train_loss, epoch) + writer.add_scalar('val/_loss', epoch_val_loss, epoch) + writer.add_scalar('train/_mae', epoch_train_mae, epoch) + writer.add_scalar('val/_mae', epoch_val_mae, epoch) + writer.add_scalar('test/_mae', epoch_test_mae, epoch) + writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) + + + t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'], + train_loss=epoch_train_loss, val_loss=epoch_val_loss, + train_MAE=epoch_train_mae, val_MAE=epoch_val_mae, + test_MAE=epoch_test_mae) + + + per_epoch_time.append(time.time()-start) + + # Saving checkpoint + ckpt_dir = os.path.join(root_ckpt_dir, "RUN_") + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch))) + + files = glob.glob(ckpt_dir + '/*.pkl') + for file in files: + epoch_nb = file.split('_')[-1] + epoch_nb = int(epoch_nb.split('.')[0]) + if epoch_nb < epoch-1: + os.remove(file) + + scheduler.step(epoch_val_loss) + + if optimizer.param_groups[0]['lr'] < params['min_lr']: + print("\n!! LR EQUAL TO MIN LR SET.") + break + + # Stop training after params['max_time'] hours + if time.time()-t0 > params['max_time']*3600: + print('-' * 89) + print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time'])) + break + + except KeyboardInterrupt: + print('-' * 89) + print('Exiting from training early because of KeyboardInterrupt') + + test_loss_lapeig, test_mae, g_outs_test = evaluate_network(model, device, test_loader, epoch) + train_loss_lapeig, train_mae, g_outs_train = evaluate_network(model, device, train_loader, epoch) + + print("Test MAE: {:.4f}".format(test_mae)) + print("Train MAE: {:.4f}".format(train_mae)) + print("Convergence Time (Epochs): {:.4f}".format(epoch)) + print("TOTAL TIME TAKEN: {:.4f}s".format(time.time()-t0)) + print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time))) + + + if net_params['pe_init'] == 'rand_walk': + # Visualize actual and predicted/learned eigenvecs + from utils.plot_util import plot_graph_eigvec + if not os.path.exists(viz_dir): + os.makedirs(viz_dir) + + sample_graph_ids = [15,25,45] + + for f_idx, graph_id in enumerate(sample_graph_ids): + + # Test graphs + g_dgl = g_outs_test[graph_id] + + f = plt.figure(f_idx, figsize=(12,6)) + + plt1 = f.add_subplot(121) + plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True) + + plt2 = f.add_subplot(122) + plot_graph_eigvec(plt2, graph_id, g_dgl, feature_key='p', predicted_eigvecs=True) + + f.savefig(viz_dir+'/test'+str(graph_id)+'.jpg') + + # Train graphs + g_dgl = g_outs_train[graph_id] + + f = plt.figure(f_idx, figsize=(12,6)) + + plt1 = f.add_subplot(121) + plot_graph_eigvec(plt1, graph_id, g_dgl, feature_key='eigvec', actual_eigvecs=True) + + plt2 = f.add_subplot(122) + plot_graph_eigvec(plt2, graph_id, g_dgl, 
feature_key='p', predicted_eigvecs=True) + + f.savefig(viz_dir+'/train'+str(graph_id)+'.jpg') + + writer.close() + + """ + Write the results in out_dir/results folder + """ + with open(write_file_name + '.txt', 'w') as f: + f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n + FINAL RESULTS\nTEST MAE: {:.4f}\nTRAIN MAE: {:.4f}\n\n + Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\ + .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], + test_mae, train_mae, epoch, (time.time()-t0)/3600, np.mean(per_epoch_time))) + + + + + +def main(): + """ + USER CONTROLS + """ + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details") + parser.add_argument('--gpu_id', help="Please give a value for gpu id") + parser.add_argument('--model', help="Please give a value for model name") + parser.add_argument('--dataset', help="Please give a value for dataset name") + parser.add_argument('--out_dir', help="Please give a value for out_dir") + parser.add_argument('--seed', help="Please give a value for seed") + parser.add_argument('--epochs', help="Please give a value for epochs") + parser.add_argument('--batch_size', help="Please give a value for batch_size") + parser.add_argument('--init_lr', help="Please give a value for init_lr") + parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor") + parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience") + parser.add_argument('--min_lr', help="Please give a value for min_lr") + parser.add_argument('--weight_decay', help="Please give a value for weight_decay") + parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval") + parser.add_argument('--L', help="Please give a value for L") + parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim") + parser.add_argument('--out_dim', help="Please give a value for out_dim") + parser.add_argument('--residual', help="Please give a value for residual") + parser.add_argument('--edge_feat', help="Please give a value for edge_feat") + parser.add_argument('--readout', help="Please give a value for readout") + parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout") + parser.add_argument('--dropout', help="Please give a value for dropout") + parser.add_argument('--layer_norm', help="Please give a value for layer_norm") + parser.add_argument('--batch_norm', help="Please give a value for batch_norm") + parser.add_argument('--max_time', help="Please give a value for max_time") + parser.add_argument('--pos_enc_dim', help="Please give a value for pos_enc_dim") + parser.add_argument('--pos_enc', help="Please give a value for pos_enc") + parser.add_argument('--alpha_loss', help="Please give a value for alpha_loss") + parser.add_argument('--lambda_loss', help="Please give a value for lambda_loss") + parser.add_argument('--pe_init', help="Please give a value for pe_init") + args = parser.parse_args() + with open(args.config) as f: + config = json.load(f) + + # device + if args.gpu_id is not None: + config['gpu']['id'] = int(args.gpu_id) + config['gpu']['use'] = True + device = gpu_setup(config['gpu']['use'], config['gpu']['id']) + # model, dataset, out_dir + if args.model is not None: + MODEL_NAME = args.model + else: + MODEL_NAME = config['model'] 
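+    # Command-line flags, when given, override the corresponding entries of the JSON config
+    # below; otherwise the values loaded from --config are used as-is.
+    # Illustrative invocation (the config path here is an assumed example, not part of this file):
+    #   python main_ZINC_graph_regression.py --config configs/GatedGCN_ZINC_LSPE.json --gpu_id 0 --seed 41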
+ if args.dataset is not None: + DATASET_NAME = args.dataset + else: + DATASET_NAME = config['dataset'] + dataset = LoadData(DATASET_NAME) + if args.out_dir is not None: + out_dir = args.out_dir + else: + out_dir = config['out_dir'] + # parameters + params = config['params'] + if args.seed is not None: + params['seed'] = int(args.seed) + if args.epochs is not None: + params['epochs'] = int(args.epochs) + if args.batch_size is not None: + params['batch_size'] = int(args.batch_size) + if args.init_lr is not None: + params['init_lr'] = float(args.init_lr) + if args.lr_reduce_factor is not None: + params['lr_reduce_factor'] = float(args.lr_reduce_factor) + if args.lr_schedule_patience is not None: + params['lr_schedule_patience'] = int(args.lr_schedule_patience) + if args.min_lr is not None: + params['min_lr'] = float(args.min_lr) + if args.weight_decay is not None: + params['weight_decay'] = float(args.weight_decay) + if args.print_epoch_interval is not None: + params['print_epoch_interval'] = int(args.print_epoch_interval) + if args.max_time is not None: + params['max_time'] = float(args.max_time) + # network parameters + net_params = config['net_params'] + net_params['device'] = device + net_params['gpu_id'] = config['gpu']['id'] + net_params['batch_size'] = params['batch_size'] + if args.L is not None: + net_params['L'] = int(args.L) + if args.hidden_dim is not None: + net_params['hidden_dim'] = int(args.hidden_dim) + if args.out_dim is not None: + net_params['out_dim'] = int(args.out_dim) + if args.residual is not None: + net_params['residual'] = True if args.residual=='True' else False + if args.edge_feat is not None: + net_params['edge_feat'] = True if args.edge_feat=='True' else False + if args.readout is not None: + net_params['readout'] = args.readout + if args.in_feat_dropout is not None: + net_params['in_feat_dropout'] = float(args.in_feat_dropout) + if args.dropout is not None: + net_params['dropout'] = float(args.dropout) + if args.layer_norm is not None: + net_params['layer_norm'] = True if args.layer_norm=='True' else False + if args.batch_norm is not None: + net_params['batch_norm'] = True if args.batch_norm=='True' else False + if args.pos_enc is not None: + net_params['pos_enc'] = True if args.pos_enc=='True' else False + if args.pos_enc_dim is not None: + net_params['pos_enc_dim'] = int(args.pos_enc_dim) + if args.alpha_loss is not None: + net_params['alpha_loss'] = float(args.alpha_loss) + if args.lambda_loss is not None: + net_params['lambda_loss'] = float(args.lambda_loss) + if args.pe_init is not None: + net_params['pe_init'] = args.pe_init + + + # ZINC + net_params['num_atom_type'] = dataset.num_atom_type + net_params['num_bond_type'] = dataset.num_bond_type + + if MODEL_NAME == 'PNA': + D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g in + dataset.train.graph_lists]) + net_params['avg_d'] = dict(lin=torch.mean(D), + exp=torch.mean(torch.exp(torch.div(1, D)) - 1), + log=torch.mean(torch.log(D + 1))) + + root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + write_config_file = out_dir + 'configs/config_' + 
MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + viz_dir = out_dir + 'viz/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file, viz_dir + + if not os.path.exists(out_dir + 'results'): + os.makedirs(out_dir + 'results') + + if not os.path.exists(out_dir + 'configs'): + os.makedirs(out_dir + 'configs') + + net_params['total_param'] = view_model_param(MODEL_NAME, net_params) + train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs) + + + + + + + + +main() + + + + + diff --git a/nets/OGBMOL_graph_classification/gatedgcn_net.py b/nets/OGBMOL_graph_classification/gatedgcn_net.py new file mode 100644 index 0000000..3bb2c9c --- /dev/null +++ b/nets/OGBMOL_graph_classification/gatedgcn_net.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl + +from scipy import sparse as sp +from scipy.sparse.linalg import norm + +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder + +""" + GatedGCN and GatedGCN-LSPE + +""" + +from layers.gatedgcn_layer import GatedGCNLayer +from layers.gatedgcn_lspe_layer import GatedGCNLSPELayer +from layers.mlp_readout_layer import MLPReadout + + +class GatedGCNNet(nn.Module): + def __init__(self, net_params): + super().__init__() + hidden_dim = net_params['hidden_dim'] + out_dim = net_params['out_dim'] + n_classes = net_params['n_classes'] + dropout = net_params['dropout'] + n_layers = net_params['L'] + self.readout = net_params['readout'] + self.batch_norm = net_params['batch_norm'] + self.residual = net_params['residual'] + self.device = net_params['device'] + self.pe_init = net_params['pe_init'] + self.n_classes = net_params['n_classes'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk', 'lap_pe']: + self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim) + + self.atom_encoder = AtomEncoder(hidden_dim) + self.bond_encoder = BondEncoder(hidden_dim) + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList([ GatedGCNLSPELayer(hidden_dim, hidden_dim, dropout, self.batch_norm, self.residual) + for _ in range(n_layers-1) ]) + self.layers.append(GatedGCNLSPELayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual)) + else: + # NoPE or LapPE + self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout, self.batch_norm, self.residual) + for _ in range(n_layers-1) ]) + self.layers.append(GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual)) + + self.MLP_layer = MLPReadout(out_dim, n_classes) + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim) + + self.g = None # For util; To be accessed in loss() function + + def forward(self, g, h, p, e, snorm_n): + + h = self.atom_encoder(h) + e = self.bond_encoder(e) + + if self.pe_init in ['rand_walk', 'lap_pe']: + p = self.embedding_p(p) + + if self.pe_init == 'lap_pe': + h = h + p + p = None + + for conv in self.layers: + h, p, e = conv(g, h, p, e, snorm_n) + + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + p = self.p_out(p) + g.ndata['p'] = p + + if self.use_lapeig_loss: + # Implementing p_g = p_g - torch.mean(p_g, 
dim=0) + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms+1e-6) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + if self.pe_init == 'rand_walk': + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + if self.n_classes == 128: + return_g = None # not passing PCBA graphs due to memory + else: + return_g = g + + return self.MLP_layer(hg), return_g + + def loss(self, pred, labels): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels) + + if self.use_lapeig_loss: + raise NotImplementedError + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/nets/OGBMOL_graph_classification/graphit_net.py b/nets/OGBMOL_graph_classification/graphit_net.py new file mode 100644 index 0000000..ec94cd5 --- /dev/null +++ b/nets/OGBMOL_graph_classification/graphit_net.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl + +from scipy import sparse as sp +from scipy.sparse.linalg import norm + +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder + +""" + GraphiT-GT and GraphiT-GT-LSPE + +""" + +from layers.graphit_gt_layer import GraphiT_GT_Layer +from layers.graphit_gt_lspe_layer import GraphiT_GT_LSPE_Layer +from layers.mlp_readout_layer import MLPReadout + +class GraphiTNet(nn.Module): + def __init__(self, net_params): + super().__init__() + + full_graph = net_params['full_graph'] + gamma = net_params['gamma'] + self.adaptive_edge_PE = net_params['adaptive_edge_PE'] + + GT_layers = net_params['L'] + GT_hidden_dim = net_params['hidden_dim'] + GT_out_dim = net_params['out_dim'] + GT_n_heads = net_params['n_heads'] + + self.residual = net_params['residual'] + self.readout = net_params['readout'] + in_feat_dropout = net_params['in_feat_dropout'] + dropout = net_params['dropout'] + + self.readout = net_params['readout'] + self.layer_norm = net_params['layer_norm'] + self.batch_norm = net_params['batch_norm'] + + n_classes = net_params['n_classes'] + self.device = net_params['device'] + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + self.pe_init = net_params['pe_init'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk']: + self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim) + + self.embedding_h = AtomEncoder(GT_hidden_dim) + self.embedding_e = BondEncoder(GT_hidden_dim) + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList([ GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ]) + 
self.layers.append(GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE)) + else: + # NoPE + self.layers = nn.ModuleList([ GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ]) + self.layers.append(GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE)) + + self.MLP_layer = MLPReadout(GT_out_dim, n_classes) + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim) + + self.g = None # For util; To be accessed in loss() function + + def forward(self, g, h, p, e, snorm_n): + + h = self.embedding_h(h) + e = self.embedding_e(e) + + if self.pe_init in ['rand_walk']: + p = self.embedding_p(p) + + for conv in self.layers: + h, p = conv(g, h, p, e, snorm_n) + + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + p = self.p_out(p) + g.ndata['p'] = p + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms+1e-6) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + return self.MLP_layer(hg), g + + def loss(self, pred, labels): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels) + + if self.use_lapeig_loss: + raise NotImplementedError + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/nets/OGBMOL_graph_classification/load_net.py b/nets/OGBMOL_graph_classification/load_net.py new file mode 100644 index 0000000..9c8c3b5 --- /dev/null +++ b/nets/OGBMOL_graph_classification/load_net.py @@ -0,0 +1,31 @@ +""" + Utility file to select GraphNN model as + selected by the user +""" + +from nets.OGBMOL_graph_classification.gatedgcn_net import GatedGCNNet +from nets.OGBMOL_graph_classification.pna_net import PNANet +from nets.OGBMOL_graph_classification.san_net import SANNet +from nets.OGBMOL_graph_classification.graphit_net import GraphiTNet + +def GatedGCN(net_params): + return GatedGCNNet(net_params) + +def PNA(net_params): + return PNANet(net_params) + +def SAN(net_params): + return SANNet(net_params) + +def GraphiT(net_params): + return GraphiTNet(net_params) + +def gnn_model(MODEL_NAME, net_params): + models = { + 'GatedGCN': GatedGCN, + 'PNA': PNA, + 'SAN': SAN, + 'GraphiT': GraphiT + } + + return models[MODEL_NAME](net_params) \ No newline at end of file diff --git a/nets/OGBMOL_graph_classification/pna_net.py b/nets/OGBMOL_graph_classification/pna_net.py new file mode 
100644 index 0000000..00bf82c --- /dev/null +++ b/nets/OGBMOL_graph_classification/pna_net.py @@ -0,0 +1,199 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl + +from scipy import sparse as sp +from scipy.sparse.linalg import norm + +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder + +""" + PNA-LSPE + +""" + +from layers.pna_layer import PNANoTowersLayer as PNALayer +from layers.pna_lspe_layer import PNANoTowersLSPELayer as PNALSPELayer +from layers.mlp_readout_layer import MLPReadout2 as MLPReadout + + +class PNANet(nn.Module): + def __init__(self, net_params): + super().__init__() + + hidden_dim = net_params['hidden_dim'] + out_dim = net_params['out_dim'] + n_classes = net_params['n_classes'] + dropout = net_params['dropout'] + self.dropout_2 = net_params['dropout_2'] + + n_layers = net_params['L'] + self.readout = net_params['readout'] + self.graph_norm = net_params['graph_norm'] + self.batch_norm = net_params['batch_norm'] + self.aggregators = net_params['aggregators'] + self.scalers = net_params['scalers'] + self.avg_d = net_params['avg_d'] + self.residual = net_params['residual'] + self.edge_feat = net_params['edge_feat'] + edge_dim = net_params['edge_dim'] + pretrans_layers = net_params['pretrans_layers'] + posttrans_layers = net_params['posttrans_layers'] + self.gru_enable = net_params['gru'] + device = net_params['device'] + self.device = device + self.pe_init = net_params['pe_init'] + self.n_classes = net_params['n_classes'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk']: + self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim) + + self.embedding_h = AtomEncoder(emb_dim=hidden_dim) + + if self.edge_feat: + self.embedding_e = BondEncoder(emb_dim=edge_dim) + + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList( + [PNALSPELayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, graph_norm=self.graph_norm, + batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, residual=self.residual, + edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss) + for _ in range(n_layers - 1)]) + self.layers.append(PNALSPELayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, graph_norm=self.graph_norm, + batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, + residual=self.residual, edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss)) + + else: + # NoPE + self.layers = nn.ModuleList( + [PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, graph_norm=self.graph_norm, + batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, residual=self.residual, + edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss) + for _ in range(n_layers - 1)]) + self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, graph_norm=self.graph_norm, + batch_norm=self.batch_norm, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, + pretrans_layers=pretrans_layers, 
posttrans_layers=posttrans_layers, + residual=self.residual, edge_features=self.edge_feat, edge_dim=edge_dim, use_lapeig_loss=self.use_lapeig_loss)) + + if self.gru_enable: + self.gru = GRU(hidden_dim, hidden_dim, device) + + self.MLP_layer = MLPReadout(out_dim, n_classes, self.dropout_2) + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim) + + self.g = None # For util; To be accessed in loss() function + + def forward(self, g, h, p, e, snorm_n): + + h = self.embedding_h(h) + + if self.pe_init in ['rand_walk']: + p = self.embedding_p(p) + + if self.edge_feat: + e = self.embedding_e(e) + + for i, conv in enumerate(self.layers): + h_t, p_t = conv(g, h, p, e, snorm_n) + if self.gru_enable and i != len(self.layers) - 1: + h_t = self.gru(h, h_t) + h, p = h_t, p_t + + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + p = F.dropout(p, self.dropout_2, training=self.training) + p = self.p_out(p) + g.ndata['p'] = p + + if self.use_lapeig_loss: + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms+1e-6) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + if self.pe_init == 'rand_walk': + # Concat h and p + hp = torch.cat((g.ndata['h'],g.ndata['p']),dim=-1) + hp = F.dropout(hp, self.dropout_2, training=self.training) + hp = self.Whp(hp) + g.ndata['h'] = hp + + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + if self.n_classes == 128: + return_g = None # not passing PCBA graphs due to memory + else: + return_g = g + + return self.MLP_layer(hg), return_g + + def loss(self, pred, labels): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels) + + if self.use_lapeig_loss: + # Loss B: Laplacian Eigenvector Loss -------------------------------------------- + g = self.g + n = g.number_of_nodes() + + # Laplacian + A = g.adjacency_matrix(scipy_fmt="csr") + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + L = sp.eye(n) - N * A * N + + p = g.ndata['p'] + pT = torch.transpose(p, 1, 0) + loss_b_1 = torch.trace(torch.mm(torch.mm(pT, torch.Tensor(L.todense()).to(self.device)), p)) + + # Correct batch-graph wise loss_b_2 implementation; using a block diagonal matrix + bg = dgl.unbatch(g) + batch_size = len(bg) + P = sp.block_diag([bg[i].ndata['p'].detach().cpu() for i in range(batch_size)]) + PTP_In = P.T * P - sp.eye(P.shape[1]) + loss_b_2 = torch.tensor(norm(PTP_In, 'fro')**2).float().to(self.device) + + loss_b = ( loss_b_1 + self.lambda_loss * loss_b_2 ) / ( self.pos_enc_dim * batch_size * n) + + del bg, P, PTP_In, loss_b_1, loss_b_2 + + loss = loss_a + self.alpha_loss * loss_b + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/nets/OGBMOL_graph_classification/san_net.py b/nets/OGBMOL_graph_classification/san_net.py new file mode 100644 index 
0000000..62f9791 --- /dev/null +++ b/nets/OGBMOL_graph_classification/san_net.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl + +from scipy import sparse as sp +from scipy.sparse.linalg import norm + +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder + +""" + SAN-GT and SAN-GT-LSPE + +""" + +from layers.san_gt_layer import SAN_GT_Layer +from layers.san_gt_lspe_layer import SAN_GT_LSPE_Layer +from layers.mlp_readout_layer import MLPReadout + + +class SANNet(nn.Module): + def __init__(self, net_params): + super().__init__() + + full_graph = net_params['full_graph'] + init_gamma = net_params['init_gamma'] + + # learn gamma + self.gamma = nn.Parameter(torch.FloatTensor([init_gamma])) + + GT_layers = net_params['L'] + GT_hidden_dim = net_params['hidden_dim'] + GT_out_dim = net_params['out_dim'] + GT_n_heads = net_params['n_heads'] + + self.residual = net_params['residual'] + self.readout = net_params['readout'] + in_feat_dropout = net_params['in_feat_dropout'] + dropout = net_params['dropout'] + + self.readout = net_params['readout'] + self.layer_norm = net_params['layer_norm'] + self.batch_norm = net_params['batch_norm'] + + n_classes = net_params['n_classes'] + self.device = net_params['device'] + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + self.pe_init = net_params['pe_init'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk']: + self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim) + + self.embedding_h = AtomEncoder(GT_hidden_dim) + self.embedding_e = BondEncoder(GT_hidden_dim) + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList([ SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ]) + self.layers.append(SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual)) + else: + # NoPE + self.layers = nn.ModuleList([ SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ]) + self.layers.append(SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual)) + + self.MLP_layer = MLPReadout(GT_out_dim, n_classes) + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim) + + self.g = None # For util; To be accessed in loss() function + + def forward(self, g, h, p, e, snorm_n): + + h = self.embedding_h(h) + e = self.embedding_e(e) + + if self.pe_init in ['rand_walk']: + p = self.embedding_p(p) + + for conv in self.layers: + h, p = conv(g, h, p, e, snorm_n) + + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + p = self.p_out(p) + g.ndata['p'] = p + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms+1e-6) + batch_wise_p_l2_norms = 
norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + return self.MLP_layer(hg), g + + def loss(self, pred, labels): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = torch.nn.BCEWithLogitsLoss()(pred, labels) + + if self.use_lapeig_loss: + raise NotImplementedError + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/nets/ZINC_graph_regression/gatedgcn_net.py b/nets/ZINC_graph_regression/gatedgcn_net.py new file mode 100644 index 0000000..d3150a1 --- /dev/null +++ b/nets/ZINC_graph_regression/gatedgcn_net.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl + +from scipy import sparse as sp +from scipy.sparse.linalg import norm + +""" + GatedGCN and GatedGCN-LSPE + +""" +from layers.gatedgcn_layer import GatedGCNLayer +from layers.gatedgcn_lspe_layer import GatedGCNLSPELayer +from layers.mlp_readout_layer import MLPReadout + +class GatedGCNNet(nn.Module): + def __init__(self, net_params): + super().__init__() + num_atom_type = net_params['num_atom_type'] + num_bond_type = net_params['num_bond_type'] + hidden_dim = net_params['hidden_dim'] + out_dim = net_params['out_dim'] + in_feat_dropout = net_params['in_feat_dropout'] + dropout = net_params['dropout'] + self.n_layers = net_params['L'] + self.readout = net_params['readout'] + self.batch_norm = net_params['batch_norm'] + self.residual = net_params['residual'] + self.edge_feat = net_params['edge_feat'] + self.device = net_params['device'] + self.pe_init = net_params['pe_init'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk', 'lap_pe']: + self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim) + + self.embedding_h = nn.Embedding(num_atom_type, hidden_dim) + + if self.edge_feat: + self.embedding_e = nn.Embedding(num_bond_type, hidden_dim) + else: + self.embedding_e = nn.Linear(1, hidden_dim) + + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList([ GatedGCNLSPELayer(hidden_dim, hidden_dim, dropout, + self.batch_norm, residual=self.residual) for _ in range(self.n_layers-1) ]) + self.layers.append(GatedGCNLSPELayer(hidden_dim, out_dim, dropout, self.batch_norm, residual=self.residual)) + else: + # NoPE or LapPE + self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout, + self.batch_norm, residual=self.residual, graph_norm=False) for _ in range(self.n_layers-1) ]) + self.layers.append(GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, residual=self.residual, graph_norm=False)) + + self.MLP_layer = MLPReadout(out_dim, 1) # 1 out dim since regression problem + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim) + + self.g = None # For util; To be accessed in loss() function + + + def 
forward(self, g, h, p, e, snorm_n): + + # input embedding + h = self.embedding_h(h) + h = self.in_feat_dropout(h) + + if self.pe_init in ['rand_walk', 'lap_pe']: + p = self.embedding_p(p) + + if self.pe_init == 'lap_pe': + h = h + p + p = None + + if not self.edge_feat: # edge feature set to 1 + e = torch.ones(e.size(0),1).to(self.device) + e = self.embedding_e(e) + + + # convnets + for conv in self.layers: + h, p, e = conv(g, h, p, e, snorm_n) + + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + p = self.p_out(p) + g.ndata['p'] = p + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + # readout + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + return self.MLP_layer(hg), g + + def loss(self, scores, targets): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = nn.L1Loss()(scores, targets) + + if self.use_lapeig_loss: + # Loss B: Laplacian Eigenvector Loss -------------------------------------------- + g = self.g + n = g.number_of_nodes() + + # Laplacian + A = g.adjacency_matrix(scipy_fmt="csr") + N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) + L = sp.eye(n) - N * A * N + + p = g.ndata['p'] + pT = torch.transpose(p, 1, 0) + loss_b_1 = torch.trace(torch.mm(torch.mm(pT, torch.Tensor(L.todense()).to(self.device)), p)) + + # Correct batch-graph wise loss_b_2 implementation; using a block diagonal matrix + bg = dgl.unbatch(g) + batch_size = len(bg) + P = sp.block_diag([bg[i].ndata['p'].detach().cpu() for i in range(batch_size)]) + PTP_In = P.T * P - sp.eye(P.shape[1]) + loss_b_2 = torch.tensor(norm(PTP_In, 'fro')**2).float().to(self.device) + + loss_b = ( loss_b_1 + self.lambda_loss * loss_b_2 ) / ( self.pos_enc_dim * batch_size * n) + + del bg, P, PTP_In, loss_b_1, loss_b_2 + + loss = loss_a + self.alpha_loss * loss_b + else: + loss = loss_a + + return loss + + + + + \ No newline at end of file diff --git a/nets/ZINC_graph_regression/graphit_net.py b/nets/ZINC_graph_regression/graphit_net.py new file mode 100644 index 0000000..48b28f1 --- /dev/null +++ b/nets/ZINC_graph_regression/graphit_net.py @@ -0,0 +1,145 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl +import numpy as np + +from scipy import sparse as sp + +""" + GraphiT-GT and GraphiT-GT-LSPE + +""" +from layers.graphit_gt_layer import GraphiT_GT_Layer +from layers.graphit_gt_lspe_layer import GraphiT_GT_LSPE_Layer +from layers.mlp_readout_layer import MLPReadout + +class GraphiTNet(nn.Module): + def __init__(self, net_params): + super().__init__() + + num_atom_type = net_params['num_atom_type'] + num_bond_type = net_params['num_bond_type'] + + full_graph = net_params['full_graph'] + gamma = net_params['gamma'] + 
self.adaptive_edge_PE = net_params['adaptive_edge_PE'] + + GT_layers = net_params['L'] + GT_hidden_dim = net_params['hidden_dim'] + GT_out_dim = net_params['out_dim'] + GT_n_heads = net_params['n_heads'] + + self.residual = net_params['residual'] + self.readout = net_params['readout'] + in_feat_dropout = net_params['in_feat_dropout'] + dropout = net_params['dropout'] + + self.readout = net_params['readout'] + self.layer_norm = net_params['layer_norm'] + self.batch_norm = net_params['batch_norm'] + + self.device = net_params['device'] + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + self.pe_init = net_params['pe_init'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk']: + self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim) + + self.embedding_h = nn.Embedding(num_atom_type, GT_hidden_dim) + self.embedding_e = nn.Embedding(num_bond_type, GT_hidden_dim) + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList([ GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, dropout, + self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ]) + self.layers.append(GraphiT_GT_LSPE_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, dropout, + self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE)) + else: + # NoPE + self.layers = nn.ModuleList([ GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, dropout, + self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE) for _ in range(GT_layers-1) ]) + self.layers.append(GraphiT_GT_Layer(gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, dropout, + self.layer_norm, self.batch_norm, self.residual, self.adaptive_edge_PE)) + + self.MLP_layer = MLPReadout(GT_out_dim, 1) # 1 out dim since regression problem + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim) + + self.g = None # For util; To be accessed in loss() function + + + def forward(self, g, h, p, e, snorm_n): + + # input embedding + h = self.embedding_h(h) + e = self.embedding_e(e) + + h = self.in_feat_dropout(h) + + if self.pe_init in ['rand_walk']: + p = self.embedding_p(p) + + # GNN + for conv in self.layers: + h, p = conv(g, h, p, e, snorm_n) + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + p = self.p_out(p) + g.ndata['p'] = p + + if self.use_lapeig_loss and self.pe_init == 'rand_walk': + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms+1e-6) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + if self.pe_init == 'rand_walk': + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + # readout + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default 
readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + return self.MLP_layer(hg), g + + def loss(self, scores, targets): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = nn.L1Loss()(scores, targets) + + if self.use_lapeig_loss: + raise NotImplementedError + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/nets/ZINC_graph_regression/load_net.py b/nets/ZINC_graph_regression/load_net.py new file mode 100644 index 0000000..03dd401 --- /dev/null +++ b/nets/ZINC_graph_regression/load_net.py @@ -0,0 +1,31 @@ +""" + Utility file to select GraphNN model as + selected by the user +""" + +from nets.ZINC_graph_regression.gatedgcn_net import GatedGCNNet +from nets.ZINC_graph_regression.pna_net import PNANet +from nets.ZINC_graph_regression.san_net import SANNet +from nets.ZINC_graph_regression.graphit_net import GraphiTNet + +def GatedGCN(net_params): + return GatedGCNNet(net_params) + +def PNA(net_params): + return PNANet(net_params) + +def SAN(net_params): + return SANNet(net_params) + +def GraphiT(net_params): + return GraphiTNet(net_params) + +def gnn_model(MODEL_NAME, net_params): + models = { + 'GatedGCN': GatedGCN, + 'PNA': PNA, + 'SAN': SAN, + 'GraphiT': GraphiT + } + + return models[MODEL_NAME](net_params) \ No newline at end of file diff --git a/nets/ZINC_graph_regression/pna_net.py b/nets/ZINC_graph_regression/pna_net.py new file mode 100644 index 0000000..8ca7f2b --- /dev/null +++ b/nets/ZINC_graph_regression/pna_net.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import dgl +from scipy import sparse as sp +from scipy.sparse.linalg import norm + + +""" + PNA and PNA-LSPE + +""" + +from layers.pna_layer import PNALayer +from layers.pna_lspe_layer import PNALSPELayer +from layers.pna_utils import GRU +from layers.mlp_readout_layer import MLPReadout + +class PNANet(nn.Module): + def __init__(self, net_params): + super().__init__() + num_atom_type = net_params['num_atom_type'] + num_bond_type = net_params['num_bond_type'] + hidden_dim = net_params['hidden_dim'] + out_dim = net_params['out_dim'] + in_feat_dropout = net_params['in_feat_dropout'] + dropout = net_params['dropout'] + n_layers = net_params['L'] + self.readout = net_params['readout'] + self.graph_norm = net_params['graph_norm'] + self.batch_norm = net_params['batch_norm'] + self.residual = net_params['residual'] + self.aggregators = net_params['aggregators'] + self.scalers = net_params['scalers'] + self.avg_d = net_params['avg_d'] + self.towers = net_params['towers'] + self.divide_input_first = net_params['divide_input_first'] + self.divide_input_last = net_params['divide_input_last'] + self.edge_feat = net_params['edge_feat'] + edge_dim = net_params['edge_dim'] + pretrans_layers = net_params['pretrans_layers'] + posttrans_layers = net_params['posttrans_layers'] + self.gru_enable = net_params['gru'] + device = net_params['device'] + self.device = device + self.pe_init = net_params['pe_init'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk']: + self.embedding_p = nn.Linear(self.pos_enc_dim, hidden_dim) + + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + + self.embedding_h = nn.Embedding(num_atom_type, hidden_dim) + + if self.edge_feat: + self.embedding_e = nn.Embedding(num_bond_type, edge_dim) + + + if self.pe_init == 'rand_walk': 
+ # LSPE + self.layers = nn.ModuleList([PNALSPELayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, + graph_norm=self.graph_norm, batch_norm=self.batch_norm, + residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, + avg_d=self.avg_d, towers=self.towers, edge_features=self.edge_feat, + edge_dim=edge_dim, divide_input=self.divide_input_first, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers) for _ + in range(n_layers - 1)]) + self.layers.append(PNALSPELayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, + graph_norm=self.graph_norm, batch_norm=self.batch_norm, + residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, + avg_d=self.avg_d, towers=self.towers, divide_input=self.divide_input_last, + edge_features=self.edge_feat, edge_dim=edge_dim, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers)) + else: + # NoPE + self.layers = nn.ModuleList([PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, + graph_norm=self.graph_norm, batch_norm=self.batch_norm, + residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, + avg_d=self.avg_d, towers=self.towers, edge_features=self.edge_feat, + edge_dim=edge_dim, divide_input=self.divide_input_first, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers) for _ + in range(n_layers - 1)]) + self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, + graph_norm=self.graph_norm, batch_norm=self.batch_norm, + residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, + avg_d=self.avg_d, towers=self.towers, divide_input=self.divide_input_last, + edge_features=self.edge_feat, edge_dim=edge_dim, + pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers)) + + if self.gru_enable: + self.gru = GRU(hidden_dim, hidden_dim, device) + + self.MLP_layer = MLPReadout(out_dim, 1) # 1 out dim since regression problem + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(out_dim+self.pos_enc_dim, out_dim) + + self.g = None # For util; To be accessed in loss() function + + def forward(self, g, h, p, e, snorm_n): + h = self.embedding_h(h) + h = self.in_feat_dropout(h) + + if self.pe_init in ['rand_walk']: + p = self.embedding_p(p) + + if self.edge_feat: + e = self.embedding_e(e) + + for i, conv in enumerate(self.layers): + h_t, p_t = conv(g, h, p, e, snorm_n) + if self.gru_enable and i != len(self.layers) - 1: + h_t = self.gru(h, h_t) + h, p = h_t, p_t + + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + p = self.p_out(p) + g.ndata['p'] = p + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + 
return self.MLP_layer(hg), g + + def loss(self, scores, targets): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = nn.L1Loss()(scores, targets) + + if self.use_lapeig_loss: + raise NotImplementedError + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/nets/ZINC_graph_regression/san_net.py b/nets/ZINC_graph_regression/san_net.py new file mode 100644 index 0000000..0951e49 --- /dev/null +++ b/nets/ZINC_graph_regression/san_net.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import dgl +import numpy as np + +from scipy import sparse as sp + +""" + SAN-GT and SAN-GT-LSPE + +""" +from layers.san_gt_layer import SAN_GT_Layer +from layers.san_gt_lspe_layer import SAN_GT_LSPE_Layer +from layers.mlp_readout_layer import MLPReadout + +class SANNet(nn.Module): + def __init__(self, net_params): + super().__init__() + + num_atom_type = net_params['num_atom_type'] + num_bond_type = net_params['num_bond_type'] + + full_graph = net_params['full_graph'] + init_gamma = net_params['init_gamma'] + + # learn gamma + self.gamma = nn.Parameter(torch.FloatTensor([init_gamma])) + + GT_layers = net_params['L'] + GT_hidden_dim = net_params['hidden_dim'] + GT_out_dim = net_params['out_dim'] + GT_n_heads = net_params['n_heads'] + + self.residual = net_params['residual'] + self.readout = net_params['readout'] + in_feat_dropout = net_params['in_feat_dropout'] + dropout = net_params['dropout'] + + self.readout = net_params['readout'] + self.layer_norm = net_params['layer_norm'] + self.batch_norm = net_params['batch_norm'] + + self.device = net_params['device'] + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + self.pe_init = net_params['pe_init'] + + self.use_lapeig_loss = net_params['use_lapeig_loss'] + self.lambda_loss = net_params['lambda_loss'] + self.alpha_loss = net_params['alpha_loss'] + + self.pos_enc_dim = net_params['pos_enc_dim'] + + if self.pe_init in ['rand_walk']: + self.embedding_p = nn.Linear(self.pos_enc_dim, GT_hidden_dim) + + self.embedding_h = nn.Embedding(num_atom_type, GT_hidden_dim) + self.embedding_e = nn.Embedding(num_bond_type, GT_hidden_dim) + + if self.pe_init == 'rand_walk': + # LSPE + self.layers = nn.ModuleList([ SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ]) + self.layers.append(SAN_GT_LSPE_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual)) + else: + # NoPE + self.layers = nn.ModuleList([ SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_hidden_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual) for _ in range(GT_layers-1) ]) + self.layers.append(SAN_GT_Layer(self.gamma, GT_hidden_dim, GT_out_dim, GT_n_heads, full_graph, + dropout, self.layer_norm, self.batch_norm, self.residual)) + + self.MLP_layer = MLPReadout(GT_out_dim, 1) # 1 out dim since regression problem + + if self.pe_init == 'rand_walk': + self.p_out = nn.Linear(GT_out_dim, self.pos_enc_dim) + self.Whp = nn.Linear(GT_out_dim+self.pos_enc_dim, GT_out_dim) + + self.g = None # For util; To be accessed in loss() function + + + def forward(self, g, h, p, e, snorm_n): + + # input embedding + h = self.embedding_h(h) + e = self.embedding_e(e) + + h = self.in_feat_dropout(h) + + if self.pe_init in ['rand_walk']: + p = self.embedding_p(p) + + # GNN + for conv in self.layers: + h, 
p = conv(g, h, p, e, snorm_n) + g.ndata['h'] = h + + if self.pe_init == 'rand_walk': + p = self.p_out(p) + g.ndata['p'] = p + + if self.use_lapeig_loss and self.pe_init == 'rand_walk': + # Implementing p_g = p_g - torch.mean(p_g, dim=0) + means = dgl.mean_nodes(g, 'p') + batch_wise_p_means = means.repeat_interleave(g.batch_num_nodes(), 0) + p = p - batch_wise_p_means + + # Implementing p_g = p_g / torch.norm(p_g, p=2, dim=0) + g.ndata['p'] = p + g.ndata['p2'] = g.ndata['p']**2 + norms = dgl.sum_nodes(g, 'p2') + norms = torch.sqrt(norms+1e-6) + batch_wise_p_l2_norms = norms.repeat_interleave(g.batch_num_nodes(), 0) + p = p / batch_wise_p_l2_norms + g.ndata['p'] = p + + if self.pe_init == 'rand_walk': + # Concat h and p + hp = self.Whp(torch.cat((g.ndata['h'],g.ndata['p']),dim=-1)) + g.ndata['h'] = hp + + # readout + if self.readout == "sum": + hg = dgl.sum_nodes(g, 'h') + elif self.readout == "max": + hg = dgl.max_nodes(g, 'h') + elif self.readout == "mean": + hg = dgl.mean_nodes(g, 'h') + else: + hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes + + self.g = g # For util; To be accessed in loss() function + + return self.MLP_layer(hg), g + + def loss(self, scores, targets): + + # Loss A: Task loss ------------------------------------------------------------- + loss_a = nn.L1Loss()(scores, targets) + + if self.use_lapeig_loss: + raise NotImplementedError + else: + loss = loss_a + + return loss \ No newline at end of file diff --git a/scripts/OGBMOL/script_MOLPCBA_all.sh b/scripts/OGBMOL/script_MOLPCBA_all.sh new file mode 100644 index 0000000..d9231f2 --- /dev/null +++ b/scripts/OGBMOL/script_MOLPCBA_all.sh @@ -0,0 +1,65 @@ +#!/bin/bash + + +############ +# Usage +############ + +# bash script_MOLPCBA_all.sh + + +#################################### +# MOLPCBA - 4 SEED RUNS OF EACH EXPTS +#################################### + +seed0=41 +seed1=95 +seed2=12 +seed3=35 +code=main_OGBMOL_graph_classification.py +dataset=OGBG-MOLPCBA +tmux new -s gnn_lspe_PCBA -d +tmux send-keys "source ~/.bashrc" C-m +tmux send-keys "source activate gnn_lspe" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLPCBA_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLPCBA_LapPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLPCBA_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 
'configs/PNA_MOLPCBA_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLPCBA_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLPCBA_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLPCBA_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLPCBA_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLPCBA_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLPCBA_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLPCBA_LSPE.json' & +wait" C-m +tmux send-keys "tmux kill-session -t gnn_lspe_PCBA" C-m + + + + + + + + + + + diff --git a/scripts/OGBMOL/script_MOLTOX21_all.sh b/scripts/OGBMOL/script_MOLTOX21_all.sh new file mode 100644 index 0000000..4aa698e --- /dev/null +++ b/scripts/OGBMOL/script_MOLTOX21_all.sh @@ -0,0 +1,95 @@ +#!/bin/bash + + +############ +# Usage +############ + +# bash script_MOLTOX21_all.sh + + +#################################### +# MOLTOX21 - 4 SEED RUNS OF EACH EXPTS +#################################### + +seed0=41 +seed1=95 +seed2=12 +seed3=35 +code=main_OGBMOL_graph_classification.py +dataset=OGBG-MOLTOX21 +tmux new -s gnn_lspe_TOX21 -d +tmux send-keys "source ~/.bashrc" C-m +tmux send-keys "source activate gnn_lspe" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLTOX21_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLTOX21_LapPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_MOLTOX21_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLTOX21_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLTOX21_LSPE.json' & +python $code 
--dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLTOX21_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_MOLTOX21_LSPE_withLapEigLoss.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_MOLTOX21_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_MOLTOX21_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_MOLTOX21_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_MOLTOX21_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_MOLTOX21_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_MOLTOX21_LSPE.json' & +wait" C-m +tmux send-keys "tmux kill-session -t gnn_lspe_TOX21" C-m + + + + + + + + + + + diff --git a/scripts/StatisticalResults/generate_statistics_OGBMOL.ipynb b/scripts/StatisticalResults/generate_statistics_OGBMOL.ipynb new file mode 100644 index 0000000..714efbf --- /dev/null +++ b/scripts/StatisticalResults/generate_statistics_OGBMOL.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f_dir = \"../../out/PNA_MOLTOX21_NoPE/results/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "file = []\n", + "dataset = []\n", + "model = []\n", + "layer = []\n", + "params = []\n", + "acc_test = []\n", + "acc_train = []\n", + "convergence = []\n", + "total_time = []\n", + "epoch_time = []\n", + "\n", + "for filename in os.listdir(f_dir):\n", + "\n", + " if filename[-4:] 
== \".txt\":\n", + " file.append( filename )\n", + " \n", + " with open(os.path.join(f_dir, filename), \"r\") as f:\n", + " lines = f.readlines()\n", + "\n", + " for line in lines:\n", + " # print('h1c',line)\n", + "\n", + " if line[:9] == \"Dataset: \":\n", + " dataset.append( line[9:-2] )\n", + "\n", + " if line[:7] == \"Model: \":\n", + " model.append( line[7:-1] )\n", + "\n", + " if line[:17] == \"net_params={'L': \":\n", + " layer.append( line[17:18] )\n", + " \n", + " if line[:56] == \"net_params={'full_graph': True, 'init_gamma': 0.1, 'L': \":\n", + " layer.append( line[56:58] )\n", + " \n", + " if line[:37] == \"net_params={'full_graph': True, 'L': \":\n", + " layer.append( line[37:39] )\n", + " \n", + " if line[:18] == \"Total Parameters: \":\n", + " params.append( line[18:-1] )\n", + "\n", + " if line[:10] == \"TEST AUC: \":\n", + " acc_test.append( float(line[10:-1]) )\n", + " \n", + " if line[:11] == \"TRAIN AUC: \":\n", + " acc_train.append( float(line[11:-1]) )\n", + " \n", + " if line[:35] == \" Convergence Time (Epochs): \":\n", + " convergence.append( float(line[35:-1]) )\n", + "\n", + " if line[:18] == \"Total Time Taken: \":\n", + " total_time.append( float(line[18:-4]) )\n", + "\n", + " if line[:24] == 'Average Time Per Epoch: ':\n", + " epoch_time.append( float(line[24:-2]) )\n", + " \n", + " \n", + " \n", + " \n", + "# print('file',file)\n", + "# print('dataset',dataset)\n", + "# print('model',model)\n", + "# print('layer',layer)\n", + "# print('params',params)\n", + "# print('acc_test',acc_test)\n", + "# print('acc_train',acc_train)\n", + "# print('convergence',convergence)\n", + "# print('total_time',total_time)\n", + "# print('epoch_time',epoch_time)\n", + "\n", + "\n", + "\n", + "\n", + "list_datasets = ['ogbg-moltox21']\n", + "#print('list_datasets',list_datasets)\n", + "\n", + "list_gnns = ['GatedGCN', 'PNA', 'SAN', 'GraphiT']\n", + "#print('list_gnns',list_gnns)\n", + "\n", + "\n", + " \n", + "for data in list_datasets:\n", + " #print(data)\n", + "\n", + " for gnn in list_gnns:\n", + " #print('gnn:',gnn)\n", + "\n", + " acc_test_one_gnn = []\n", + " acc_train_one_gnn = []\n", + " convergence_one_gnn = []\n", + " total_time_one_gnn = []\n", + " epoch_time_one_gnn = []\n", + " nb_seeds = 0\n", + "\n", + " for i in range(len(file)):\n", + " #print(params[i])\n", + " \n", + " if data==dataset[i] and gnn==model[i]:\n", + " params_one_gnn = params[i]\n", + " acc_test_one_gnn.append(acc_test[i])\n", + " acc_train_one_gnn.append(acc_train[i])\n", + " convergence_one_gnn.append(convergence[i])\n", + " total_time_one_gnn.append(total_time[i])\n", + " epoch_time_one_gnn.append(epoch_time[i])\n", + " L = layer[i]\n", + " nb_seeds = nb_seeds + 1\n", + "\n", + " #print(params_one_gnn)\n", + " if len(acc_test_one_gnn)>0:\n", + " print(acc_test_one_gnn)\n", + " latex_str = f\"{data} & {nb_seeds} & {gnn} & {L} & {params_one_gnn} & {np.mean(acc_test_one_gnn):.3f}$\\pm${np.std(acc_test_one_gnn):.3f} & {np.mean(acc_train_one_gnn):.3f}$\\pm${np.std(acc_train_one_gnn):.3f} & {np.mean(convergence_one_gnn):.2f} & {np.mean(epoch_time_one_gnn):.2f}s/{np.mean(total_time_one_gnn):.2f}hr\"\n", + " print(\"\\nDataset & #Seeds & Model & L & Param & Acc_test & Acc_train & Speed & Epoch/Time\\n{}\".format(latex_str,nb_seeds))\n", + "\n", + " \n", + "\n", + "print(\"\\n\")\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": 
"python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/scripts/StatisticalResults/generate_statistics_ZINC.ipynb b/scripts/StatisticalResults/generate_statistics_ZINC.ipynb new file mode 100644 index 0000000..01ff269 --- /dev/null +++ b/scripts/StatisticalResults/generate_statistics_ZINC.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f_dir = \"../../out/GraphiT_ZINC_LSPE_noLapEigLoss/results/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "file = []\n", + "dataset = []\n", + "model = []\n", + "layer = []\n", + "params = []\n", + "acc_test = []\n", + "acc_train = []\n", + "convergence = []\n", + "total_time = []\n", + "epoch_time = []\n", + "\n", + "for filename in os.listdir(f_dir):\n", + "\n", + " if filename[-4:] == \".txt\":\n", + " file.append( filename )\n", + " \n", + " with open(os.path.join(f_dir, filename), \"r\") as f:\n", + " lines = f.readlines()\n", + "\n", + " for line in lines:\n", + " #print('h1c',line)\n", + "\n", + " if line[:9] == \"Dataset: \":\n", + " dataset.append( line[9:-2] )\n", + "\n", + " if line[:7] == \"Model: \":\n", + " model.append( line[7:-1] )\n", + "\n", + " if line[:17] == \"net_params={'L': \":\n", + " layer.append( line[17:18] )\n", + " \n", + " if line[:56] == \"net_params={'full_graph': True, 'init_gamma': 0.1, 'L': \":\n", + " layer.append( line[56:58])\n", + " \n", + " if line[:37] == \"net_params={'full_graph': True, 'L': \":\n", + " layer.append( line[37:39])\n", + " \n", + " if line[:18] == \"Total Parameters: \":\n", + " params.append( line[18:-1] )\n", + "\n", + " if line[:10] == \"TEST MAE: \":\n", + " acc_test.append( float(line[10:-1]) )\n", + " \n", + " if line[:11] == \"TRAIN MAE: \":\n", + " acc_train.append( float(line[11:-1]) )\n", + " \n", + " if line[4:31] == \"Convergence Time (Epochs): \":\n", + " convergence.append( float(line[31:-1]) )\n", + "\n", + " if line[:18] == \"Total Time Taken: \":\n", + " total_time.append( float(line[18:-4]) )\n", + "\n", + " if line[:24] == 'Average Time Per Epoch: ':\n", + " epoch_time.append( float(line[24:-2]) )\n", + " \n", + " \n", + " \n", + " \n", + "# print('file',file)\n", + "# print('dataset',dataset)\n", + "# print('model',model)\n", + "# print('layer',layer)\n", + "# print('params',params)\n", + "# print('acc_test',acc_test)\n", + "# print('acc_train',acc_train)\n", + "# print('convergence',convergence)\n", + "# print('total_time',total_time)\n", + "# print('epoch_time',epoch_time)\n", + "\n", + "\n", + "\n", + "\n", + "list_datasets = ['ZINC']\n", + "#print('list_datasets',list_datasets)\n", + "\n", + "list_gnns = ['GatedGCN', 'PNA', 'SAN', 'GraphiT']\n", + "#print('list_gnns',list_gnns)\n", + "\n", + "\n", + "for data in list_datasets:\n", + " #print(data)\n", + "\n", + " for gnn in list_gnns:\n", + " #print('gnn:',gnn)\n", + "\n", + " acc_test_one_gnn = []\n", + " acc_train_one_gnn = []\n", + " convergence_one_gnn = []\n", + " total_time_one_gnn = []\n", + " epoch_time_one_gnn = 
[]\n", + " nb_seeds = 0\n", + "\n", + " for i in range(len(file)):\n", + " #print(params[i])\n", + " if data==dataset[i] and gnn==model[i]:\n", + " params_one_gnn = params[i]\n", + " acc_test_one_gnn.append(acc_test[i])\n", + " acc_train_one_gnn.append(acc_train[i])\n", + " convergence_one_gnn.append(convergence[i])\n", + " total_time_one_gnn.append(total_time[i])\n", + " epoch_time_one_gnn.append(epoch_time[i])\n", + " L = layer[i]\n", + " nb_seeds = nb_seeds + 1\n", + "\n", + " #print(params_one_gnn)\n", + " if len(acc_test_one_gnn)>0:\n", + " print(acc_test_one_gnn)\n", + " latex_str = f\"{data} & {nb_seeds} & {gnn} & {L} & {params_one_gnn} & {np.mean(acc_test_one_gnn):.3f}$\\pm${np.std(acc_test_one_gnn):.3f} & {np.mean(acc_train_one_gnn):.3f}$\\pm${np.std(acc_train_one_gnn):.3f} & {np.mean(convergence_one_gnn):.2f} & {np.mean(epoch_time_one_gnn):.2f}s/{np.mean(total_time_one_gnn):.2f}hr\"\n", + " print(\"\\nDataset & #Seeds & Model & L & Param & Acc_test & Acc_train & Speed & Epoch/Time\\n{}\".format(latex_str,nb_seeds))\n", + "\n", + " \n", + "\n", + "print(\"\\n\")\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/scripts/TensorBoard/script_tensorboard.sh b/scripts/TensorBoard/script_tensorboard.sh new file mode 100644 index 0000000..81ddd33 --- /dev/null +++ b/scripts/TensorBoard/script_tensorboard.sh @@ -0,0 +1,21 @@ +#!/bin/bash + + +# bash script_tensorboard.sh + + + + + +tmux new -s tensorboard -d +tmux send-keys "source activate gnn_lspe" C-m +tmux send-keys "tensorboard --logdir out/ --port 6006" C-m + + + + + + + + + diff --git a/scripts/ZINC/script_ZINC_all.sh b/scripts/ZINC/script_ZINC_all.sh new file mode 100644 index 0000000..30cb7bc --- /dev/null +++ b/scripts/ZINC/script_ZINC_all.sh @@ -0,0 +1,97 @@ +#!/bin/bash + + +############ +# Usage +############ + +# bash script_ZINC_all.sh + + +#################################### +# ZINC - 4 SEED RUNS OF EACH EXPTS +#################################### + +seed0=41 +seed1=95 +seed2=12 +seed3=35 +code=main_ZINC_graph_regression.py +dataset=ZINC +tmux new -s gnn_lspe_ZINC -d +tmux send-keys "source ~/.bashrc" C-m +tmux send-keys "source activate gnn_lspe" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_LapPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_LapPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_LapPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_LapPE.json' & +wait" 
C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GatedGCN_ZINC_LSPE_withLapEigLoss.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_ZINC_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/PNA_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/PNA_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/PNA_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/PNA_ZINC_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_ZINC_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/SAN_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/SAN_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/SAN_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/SAN_ZINC_LSPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_ZINC_NoPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_ZINC_NoPE.json' & +wait" C-m +tmux send-keys " +python $code --dataset $dataset --gpu_id 0 --seed $seed0 --config 'configs/GraphiT_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 1 --seed $seed1 --config 'configs/GraphiT_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 2 --seed $seed2 --config 'configs/GraphiT_ZINC_LSPE.json' & +python $code --dataset $dataset --gpu_id 3 --seed $seed3 --config 'configs/GraphiT_ZINC_LSPE.json' & +wait" C-m +tmux send-keys "tmux kill-session -t 
gnn_lspe_ZINC" C-m + + + + + + + + + + + + + diff --git a/train/metrics.py b/train/metrics.py new file mode 100644 index 0000000..b584da0 --- /dev/null +++ b/train/metrics.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sklearn.metrics import confusion_matrix +from sklearn.metrics import f1_score +import numpy as np + + +def MAE(scores, targets): + MAE = F.l1_loss(scores, targets) + MAE = MAE.detach().item() + return MAE + + +def accuracy_TU(scores, targets): + scores = scores.detach().argmax(dim=1) + acc = (scores==targets).float().sum().item() + return acc + + +def accuracy_MNIST_CIFAR(scores, targets): + scores = scores.detach().argmax(dim=1) + acc = (scores==targets).float().sum().item() + return acc + +def accuracy_CITATION_GRAPH(scores, targets): + scores = scores.detach().argmax(dim=1) + acc = (scores==targets).float().sum().item() + acc = acc / len(targets) + return acc + + +def accuracy_SBM(scores, targets): + S = targets.cpu().numpy() + C = np.argmax( torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy() , axis=1 ) + CM = confusion_matrix(S,C).astype(np.float32) + nb_classes = CM.shape[0] + targets = targets.cpu().detach().numpy() + nb_non_empty_classes = 0 + pr_classes = np.zeros(nb_classes) + for r in range(nb_classes): + cluster = np.where(targets==r)[0] + if cluster.shape[0] != 0: + pr_classes[r] = CM[r,r]/ float(cluster.shape[0]) + if CM[r,r]>0: + nb_non_empty_classes += 1 + else: + pr_classes[r] = 0.0 + acc = 100.* np.sum(pr_classes)/ float(nb_classes) + return acc + + +def binary_f1_score(scores, targets): + """Computes the F1 score using scikit-learn for binary class labels. + + Returns the F1 score for the positive class, i.e. labelled '1'. + """ + y_true = targets.cpu().numpy() + y_pred = scores.argmax(dim=1).cpu().numpy() + return f1_score(y_true, y_pred, average='binary') + + +def accuracy_VOC(scores, targets): + scores = scores.detach().argmax(dim=1).cpu() + targets = targets.cpu().detach().numpy() + acc = f1_score(scores, targets, average='weighted') + return acc diff --git a/train/train_OGBMOL_graph_classification.py b/train/train_OGBMOL_graph_classification.py new file mode 100644 index 0000000..ba74bef --- /dev/null +++ b/train/train_OGBMOL_graph_classification.py @@ -0,0 +1,131 @@ +""" + Utility functions for training one epoch + and evaluating one epoch +""" +import numpy as np +import torch +import torch.nn as nn +import math +from tqdm import tqdm + +import dgl + + +def train_epoch_sparse(model, optimizer, device, data_loader, epoch, evaluator): + model.train() + + epoch_loss = 0 + nb_data = 0 + + y_true = [] + y_pred = [] + + for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(data_loader): + optimizer.zero_grad() + + batch_graphs = batch_graphs.to(device) + batch_x = batch_graphs.ndata['feat'].to(device) + batch_e = batch_graphs.edata['feat'].to(device) + batch_labels = batch_labels.to(device) + batch_snorm_n = batch_snorm_n.to(device) + + try: + batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device) + except KeyError: + batch_pos_enc = None + + if model.pe_init == 'lap_pe': + sign_flip = torch.rand(batch_pos_enc.size(1)).to(device) + sign_flip[sign_flip>=0.5] = 1.0; sign_flip[sign_flip<0.5] = -1.0 + batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0) + + batch_pred, __ = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n) + del __ + + # ignore nan labels (unlabeled) when computing training loss + is_labeled = batch_labels == batch_labels + loss = 
model.loss(batch_pred.to(torch.float32)[is_labeled], batch_labels.to(torch.float32)[is_labeled]) + + loss.backward() + optimizer.step() + + y_true.append(batch_labels.view(batch_pred.shape).detach().cpu()) + y_pred.append(batch_pred.detach().cpu()) + + epoch_loss += loss.detach().item() + nb_data += batch_labels.size(0) + + epoch_loss /= (iter + 1) + + y_true = torch.cat(y_true, dim = 0).numpy() + y_pred = torch.cat(y_pred, dim = 0).numpy() + + # compute performance metric using OGB evaluator + input_dict = {"y_true": y_true, "y_pred": y_pred} + perf = evaluator.eval(input_dict) + + if batch_labels.size(1) == 128: # MOLPCBA + return_perf = perf['ap'] + elif batch_labels.size(1) == 12: # MOLTOX21 + return_perf = perf['rocauc'] + + return epoch_loss, return_perf, optimizer + +def evaluate_network_sparse(model, device, data_loader, epoch, evaluator): + model.eval() + + epoch_loss = 0 + nb_data = 0 + + y_true = [] + y_pred = [] + + out_graphs_for_lapeig_viz = [] + + with torch.no_grad(): + for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(data_loader): + batch_graphs = batch_graphs.to(device) + batch_x = batch_graphs.ndata['feat'].to(device) + batch_e = batch_graphs.edata['feat'].to(device) + batch_labels = batch_labels.to(device) + batch_snorm_n = batch_snorm_n.to(device) + + + try: + batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device) + except KeyError: + batch_pos_enc = None + + batch_pred, batch_g = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n) + + + # ignore nan labels (unlabeled) when computing loss + is_labeled = batch_labels == batch_labels + loss = model.loss(batch_pred.to(torch.float32)[is_labeled], batch_labels.to(torch.float32)[is_labeled]) + + y_true.append(batch_labels.view(batch_pred.shape).detach().cpu()) + y_pred.append(batch_pred.detach().cpu()) + + epoch_loss += loss.detach().item() + nb_data += batch_labels.size(0) + + if batch_g is not None: + out_graphs_for_lapeig_viz += dgl.unbatch(batch_g) + else: + out_graphs_for_lapeig_viz = None + + epoch_loss /= (iter + 1) + + y_true = torch.cat(y_true, dim = 0).numpy() + y_pred = torch.cat(y_pred, dim = 0).numpy() + + # compute performance metric using OGB evaluator + input_dict = {"y_true": y_true, "y_pred": y_pred} + perf = evaluator.eval(input_dict) + + if batch_labels.size(1) == 128: # MOLPCBA + return_perf = perf['ap'] + elif batch_labels.size(1) == 12: # MOLTOX21 + return_perf = perf['rocauc'] + + return epoch_loss, return_perf, out_graphs_for_lapeig_viz \ No newline at end of file diff --git a/train/train_ZINC_graph_regression.py b/train/train_ZINC_graph_regression.py new file mode 100644 index 0000000..40c522c --- /dev/null +++ b/train/train_ZINC_graph_regression.py @@ -0,0 +1,84 @@ +""" + Utility functions for training one epoch + and evaluating one epoch +""" +import torch +import torch.nn as nn +import math + +import dgl + +from train.metrics import MAE + + + +def train_epoch_sparse(model, optimizer, device, data_loader, epoch): + model.train() + epoch_loss = 0 + epoch_train_mae = 0 + nb_data = 0 + gpu_mem = 0 + for iter, (batch_graphs, batch_targets, batch_snorm_n) in enumerate(data_loader): + batch_graphs = batch_graphs.to(device) + batch_x = batch_graphs.ndata['feat'].to(device) # num x feat + batch_e = batch_graphs.edata['feat'].to(device) + batch_targets = batch_targets.to(device) + batch_snorm_n = batch_snorm_n.to(device) + optimizer.zero_grad() + + try: + batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device) + except KeyError: + batch_pos_enc = None + + if 
model.pe_init == 'lap_pe': + sign_flip = torch.rand(batch_pos_enc.size(1)).to(device) + sign_flip[sign_flip>=0.5] = 1.0; sign_flip[sign_flip<0.5] = -1.0 + batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0) + + batch_scores, __ = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n) + del __ + + loss = model.loss(batch_scores, batch_targets) + loss.backward() + optimizer.step() + epoch_loss += loss.detach().item() + epoch_train_mae += MAE(batch_scores, batch_targets) + nb_data += batch_targets.size(0) + epoch_loss /= (iter + 1) + epoch_train_mae /= (iter + 1) + + return epoch_loss, epoch_train_mae, optimizer + +def evaluate_network_sparse(model, device, data_loader, epoch): + model.eval() + epoch_test_loss = 0 + epoch_test_mae = 0 + nb_data = 0 + out_graphs_for_lapeig_viz = [] + with torch.no_grad(): + for iter, (batch_graphs, batch_targets, batch_snorm_n) in enumerate(data_loader): + batch_graphs = batch_graphs.to(device) + batch_x = batch_graphs.ndata['feat'].to(device) + batch_e = batch_graphs.edata['feat'].to(device) + batch_targets = batch_targets.to(device) + batch_snorm_n = batch_snorm_n.to(device) + + try: + batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device) + except KeyError: + batch_pos_enc = None + + batch_scores, batch_g = model.forward(batch_graphs, batch_x, batch_pos_enc, batch_e, batch_snorm_n) + + loss = model.loss(batch_scores, batch_targets) + epoch_test_loss += loss.detach().item() + epoch_test_mae += MAE(batch_scores, batch_targets) + nb_data += batch_targets.size(0) + + out_graphs_for_lapeig_viz += dgl.unbatch(batch_g) + epoch_test_loss /= (iter + 1) + epoch_test_mae /= (iter + 1) + + return epoch_test_loss, epoch_test_mae, out_graphs_for_lapeig_viz + diff --git a/utils/cleaner_main.py b/utils/cleaner_main.py new file mode 100644 index 0000000..af383c8 --- /dev/null +++ b/utils/cleaner_main.py @@ -0,0 +1,102 @@ + +# Clean the main.py file after conversion from notebook. +# Any notebook code is removed from the main.py file. 
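+# Usage sketch (illustrative; the notebook name below is an assumption, not part of this repo):
+#   cleaner_main('main_ZINC_graph_regression')
+# converts main_ZINC_graph_regression.ipynb (expected in the working directory) to
+# main_ZINC_graph_regression.py via `jupyter nbconvert`, then rewrites the generated .py
+# file in place with the notebook-only blocks removed.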
+ + +import subprocess + + +def cleaner_main(filename): + + # file names + file_notebook = filename + '.ipynb' + file_python = filename + '.py' + + + # convert notebook to python file + print('Convert ' + file_notebook + ' to ' + file_python) + subprocess.check_output('jupyter nbconvert --to script ' + str(file_notebook) , shell=True) + + print('Clean ' + file_python) + + # open file + with open(file_python, "r") as f_in: + lines_in = f_in.readlines() + + # remove cell indices + lines_in = [ line for i,line in enumerate(lines_in) if '# In[' not in line ] + + # remove comments + lines_in = [ line for i,line in enumerate(lines_in) if line[0]!='#' ] + + # remove "in_ipynb()" function + idx_start_fnc = next((i for i, x in enumerate(lines_in) if 'def in_ipynb' in x), None) + if idx_start_fnc!=None: + idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:4] not in ['\n',' ']), None) + lines_in = [ line for i,line in enumerate(lines_in) if i not in range(idx_start_fnc,idx_end_fnc+1) ] + list_elements_to_remove = ['in_ipynb()', 'print(notebook_mode)'] + for elem in list_elements_to_remove: + lines_in = [ line for i,line in enumerate(lines_in) if elem not in line ] + + # unindent "if notebook_mode==False" block + idx_start_fnc = next((i for i, x in enumerate(lines_in) if 'if notebook_mode==False' in x), None) + if idx_start_fnc!=None: + idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:8] not in ['\n',' ']), None) + for i in range(idx_start_fnc,idx_end_fnc+1): + lines_in[i] = lines_in[i][4:] + lines_in.pop(idx_start_fnc) + list_elements_to_remove = ['# notebook mode', '# terminal mode'] + for elem in list_elements_to_remove: + lines_in = [ line for i,line in enumerate(lines_in) if elem not in line ] + + # remove remaining "if notebook_mode==True" blocks - single indent + run = True + while run: + idx_start_fnc = next((i for i, x in enumerate(lines_in) if x[:16]=='if notebook_mode'), None) + if idx_start_fnc!=None: + idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:4] not in ['\n',' ']), None) + lines_in = [ line for i,line in enumerate(lines_in) if i not in range(idx_start_fnc,idx_end_fnc+1) ] + else: + run = False + + # remove "if notebook_mode==True" block - double indents + idx_start_fnc = next((i for i, x in enumerate(lines_in) if x[:20]==' if notebook_mode'), None) + if idx_start_fnc!=None: + idx_end_fnc = idx_start_fnc + next((i for i, x in enumerate(lines_in[idx_start_fnc+1:]) if x[:8] not in ['\n',' ']), None) + lines_in = [ line for i,line in enumerate(lines_in) if i not in range(idx_start_fnc,idx_end_fnc+1) ] + + # prepare main() for terminal mode + idx = next((i for i, x in enumerate(lines_in) if 'def main' in x), None) + if idx!=None: lines_in[idx] = 'def main():' + idx = next((i for i, x in enumerate(lines_in) if x[:5]=='else:'), None) + if idx!=None: lines_in.pop(idx) + idx = next((i for i, x in enumerate(lines_in) if x[:10]==' main()'), None) + if idx!=None: lines_in[idx] = 'main()' + + # remove notebook variables + idx = next((i for i, x in enumerate(lines_in) if 'use_gpu = True' in x), None) + if idx!=None: lines_in.pop(idx) + idx = next((i for i, x in enumerate(lines_in) if 'gpu_id = -1' in x), None) + if idx!=None: lines_in.pop(idx) + idx = next((i for i, x in enumerate(lines_in) if 'device = None' in x), None) + if idx!=None: lines_in.pop(idx) + run = True + while run: + idx = next((i for i, x in enumerate(lines_in) if x[:10]=='MODEL_NAME'), None) + if 
idx!=None: + lines_in.pop(idx) + else: + run = False + + # save clean file + lines_out = str() + for line in lines_in: lines_out += line + with open(file_python, 'w') as f_out: + f_out.write(lines_out) + + print('Done. ') + + + + + diff --git a/utils/plot_util.py b/utils/plot_util.py new file mode 100644 index 0000000..0c20d65 --- /dev/null +++ b/utils/plot_util.py @@ -0,0 +1,46 @@ +""" + Util function to plot graph with eigenvectors + x-axis: first dim + y-axis: second dim +""" + +import networkx as nx + +def plot_graph_eigvec(plt, g_id, g_dgl, feature_key, actual_eigvecs=False, predicted_eigvecs=False): + + if actual_eigvecs: + plt.set_xlabel('first eigenvec') + plt.set_ylabel('second eigenvec') + else: + plt.set_xlabel('first predicted pe') + plt.set_ylabel('second predicted pe') + + g_dgl = g_dgl.cpu() + g_dgl.ndata['feats'] = g_dgl.ndata[feature_key][:,:2] + g_nx = g_dgl.to_networkx(node_attrs=['feats']) + + labels = {} + for idx, node in enumerate(g_nx.nodes()): + labels[node] = str(idx) + + num_nodes = g_dgl.num_nodes() + num_edges = g_dgl.num_edges() + + edge_list = [] + srcs, dsts = g_dgl.edges() + for edge_i in range(num_edges): + edge_list.append((srcs[edge_i].item(), dsts[edge_i].item())) + + # fig, ax = plt.subplots() + # first 2-dim of eigenvecs are x,y coordinates, and the 3rd dim of eigenvec is plotted as node intensity + # intensities = g_dgl.ndata['feats'][:,2] + nx.draw_networkx_nodes(g_nx, g_dgl.ndata['feats'][:,:2].numpy(), node_color='r', node_size=180, label=list(range(g_dgl.number_of_nodes()))) + nx.draw_networkx_edges(g_nx, g_dgl.ndata['feats'][:,:2].numpy(), edge_list, alpha=0.3) + nx.draw_networkx_labels(g_nx, g_dgl.ndata['feats'][:,:2].numpy(), labels, font_size=16) + plt.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True) + + title = "Graph ID: " + str(g_id) + + title += " | Actual eigvecs" if actual_eigvecs else " | Predicted PEs" + plt.title.set_text(title) + \ No newline at end of file diff --git a/utils/visualize_RWPE_studies.ipynb b/utils/visualize_RWPE_studies.ipynb new file mode 100644 index 0000000..63e30a6 --- /dev/null +++ b/utils/visualize_RWPE_studies.ipynb @@ -0,0 +1,197 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b941618a", + "metadata": {}, + "source": [ + "## ZINC Visualization with Init PE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58a892b7", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('../') # go to root folder of the project\n", + "print(os.getcwd())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ab99754", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import networkx as nx\n", + "from itertools import count\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from itertools import count" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2adb4608", + "metadata": {}, + "outputs": [], + "source": [ + "from data.data import LoadData\n", + "zinc_d = LoadData('ZINC')\n", + "\n", + "pos_enc_dim = 24\n", + "zinc_d._init_positional_encodings(pos_enc_dim, 'rand_walk')\n", + "zinc_d._add_eig_vecs(36)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b12e503", + "metadata": {}, + "outputs": [], + "source": [ + "num_nodes = []\n", + "num_unique_RWPEs = []\n", + "num_unique_LapPEs = []\n", + "\n", + "for g_ in zinc_d.val:\n", + " num_nodes.append(g_[0].number_of_nodes())\n", + " num_unique_RWPEs.append(len(torch.unique(g_[0].ndata['pos_enc'], 
dim=0)))\n", + "for g_ in zinc_d.val:\n", + " num_unique_LapPEs.append(len(torch.unique(g_[0].ndata['eigvec'], dim=0)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edcf8c63", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_init_PE_comparison(num_nodes, num_unique_PEs, PE='RWPE'):\n", + " fig = plt.figure(dpi=100, figsize=(7, 6))\n", + " ax = plt.axes()\n", + " plt.xlabel(\"Number of nodes\", fontsize=10)\n", + " if PE == 'RWPE':\n", + " plt.title(\"ZINC val (1K): Comparison of no. of nodes v/s no. of unique RWPEs\", fontsize=15)\n", + " plt.ylabel(\"Number of unique RWPEs\", fontsize=10)\n", + " elif PE == 'LapPE':\n", + " plt.title(\"ZINC val (1K): Comparison of no. of nodes v/s no. of unique LapPEs\", fontsize=15)\n", + " plt.ylabel(\"Number of unique LapPEs\", fontsize=10)\n", + " x = np.array(num_nodes)\n", + " y1 = np.array(num_unique_PEs)\n", + " #plt.xticks(x)\n", + " plt.hist2d(x, y1, (50, 50), cmap=plt.cm.Reds)# plt.cm.jet)\n", + " plt.colorbar()\n", + " plt.xlim([10, 35])\n", + " plt.ylim([10, 35])\n", + " #plt.scatter(x, y1, marker=\"o\", color=\"green\", linewidth=0.5)\n", + " x = np.linspace(9,40,100)\n", + " y = x\n", + " plt.plot(x, y, '-r', )\n", + "\n", + " #fig.savefig('out_ZINC_PE_viz/ZINC_valset_'+PE+'.pdf', bbox_inches='tight') \n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f1de5dc", + "metadata": {}, + "outputs": [], + "source": [ + "plot_init_PE_comparison(num_nodes, num_unique_RWPEs, 'RWPE')\n", + "plot_init_PE_comparison(num_nodes, num_unique_LapPEs, 'LapPE')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "607a71d7", + "metadata": {}, + "outputs": [], + "source": [ + "graph_ids = [212,672] # not equal\n", + "graph_ids = [91,967] # equal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "960aa77f", + "metadata": {}, + "outputs": [], + "source": [ + "for idx in graph_ids:\n", + " fig = plt.figure(dpi=100, figsize=(8, 6))\n", + " g_zinc_trial = zinc_d.val[idx][0]\n", + " g = g_zinc_trial.to_networkx(node_attrs=['pos_enc'])\n", + " #groups = set(nx.get_node_attributes(g,'pos_enc').values())\n", + " groups = torch.unique(g_zinc_trial.ndata['pos_enc'],dim=0)\n", + " mapping = dict(zip(groups,count()))\n", + " nodes = g.nodes()\n", + " colors = []\n", + " for n in nodes:\n", + " for key in mapping.keys():\n", + " if torch.equal(key, g.nodes[n]['pos_enc']):\n", + " color = mapping[key]\n", + " colors.append(color)\n", + " \n", + " pos = nx.spring_layout(g)\n", + " ec = nx.draw_networkx_edges(g, pos, alpha=0.2)\n", + " nc = nx.draw_networkx_nodes(g, pos, nodelist=nodes, node_color=colors, \n", + " node_size=100, cmap=plt.cm.jet)\n", + " plt.colorbar(nc)\n", + " #plt.xlabel(\"ZINC Val id: \" +str(idx), fontsize=16)\n", + " plt.title(\"nodes: \"+ str(num_nodes[idx]) + \" | unique RWPEs: \"+str(num_unique_RWPEs[idx]), fontsize=16)\n", + " #fig.savefig('out_ZINC_PE_viz/ZINC_valset_graph_'+str(idx)+'.pdf', bbox_inches='tight') \n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59af93f9", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e552577", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}