Skip to content

Commit

Permalink
redunce number of samples in dummy dataset of test to fix memory issu…
Browse files Browse the repository at this point in the history
…es on conda-forge
  • Loading branch information
stefdoerr committed Jan 25, 2025
1 parent 38a445c commit 15afbcd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_datamodule_create(tmpdir):
dl2 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
assert dl1 is not dl2


def test_dataloader_get(tmpdir):
args = load_example_args("graph-network")
args["train_size"] = 800
Expand All @@ -50,6 +51,7 @@ def test_dataloader_get(tmpdir):
# Assert that the dataloader is not empty
assert len(data.train_dataloader()) > 0


@mark.parametrize("energy,forces", [(True, True), (True, False), (False, True)])
@mark.parametrize("has_atomref", [True, False])
def test_datamodule_standardize(energy, forces, has_atomref, tmpdir):
Expand All @@ -60,7 +62,9 @@ def test_datamodule_standardize(energy, forces, has_atomref, tmpdir):
args["test_size"] = 10
args["log_dir"] = tmpdir

dataset = DummyDataset(energy=energy, forces=forces, has_atomref=has_atomref)
dataset = DummyDataset(
num_samples=100, energy=energy, forces=forces, has_atomref=has_atomref
)
data = DataModule(args, dataset=dataset)
data.prepare_data()
data.setup("fit")
Expand Down

0 comments on commit 15afbcd

Please sign in to comment.