Skip to content

Commit

Permalink
adding lr settings to allowed settings + enabling checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
tulga-rdn committed Jan 30, 2025
1 parent e62386c commit a395f81
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion default_hypers/default_hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ FITTING_SCHEME:

LONG_RANGE_SETTINGS: # torchPME settings
FULL_NEIGHBOUR_LIST: False
PREFACTOR: 14.
PREFACTOR: 1.0
CHARGE_CHANNELS: 1
ATOMIC_SMEARING: 1.
LR_WAVELENGTH: 0.5
Expand Down
5 changes: 4 additions & 1 deletion src/estimate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def main():
FITTING_SCHEME = hypers.FITTING_SCHEME

ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS

LONG_RANGE_SETTINGS = hypers.LONG_RANGE_SETTINGS

# set_reproducibility(FITTING_SCHEME.RANDOM_SEED, FITTING_SCHEME.CUDA_DETERMINISTIC)

Expand Down Expand Up @@ -122,7 +124,8 @@ def main():

if hypers.UTILITY_FLAGS.CALCULATION_TYPE == "mlip":
model = PETMLIPWrapper(
model, hypers.MLIP_SETTINGS.USE_ENERGIES, hypers.MLIP_SETTINGS.USE_FORCES
model, hypers.MLIP_SETTINGS.USE_ENERGIES, hypers.MLIP_SETTINGS.USE_FORCES,
LONG_RANGE_SETTINGS, device
)

if FITTING_SCHEME.MULTI_GPU and torch.cuda.is_available():
Expand Down
7 changes: 6 additions & 1 deletion src/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def fit_pet(
save_hypers(hypers, f"{output_dir}/{NAME_OF_CALCULATION}/hypers_used.yaml")

print("Convering structures to PyG graphs...", flush=True)
print(len(train_structures))

train_graphs = get_pyg_graphs(
train_structures,
Expand Down Expand Up @@ -557,6 +558,9 @@ def main():
parser.add_argument(
"name_of_calculation", help="Name of this calculation", type=str
)
parser.add_argument(
"--checkpoint", help="checkpoint path", type=str, default=None
)
parser.add_argument("--gpu_id", help="ID of the GPU to use", type=int, default=0)
args = parser.parse_args()

Expand All @@ -572,7 +576,7 @@ def main():

name_of_calculation = args.name_of_calculation

output_dir = "results"
output_dir = "/scratch/izar/sodjarga/results"

hypers_dict = hypers_to_dict(hypers)
fit_pet(
Expand All @@ -582,6 +586,7 @@ def main():
name_of_calculation,
device,
output_dir,
args.checkpoint
)


Expand Down

0 comments on commit a395f81

Please sign in to comment.