diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index d3a2756..5227e77 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -167,15 +167,15 @@ def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize): xbsize = 20 bg = BatchGenerator( sample_ds_3d, - input_dims={'y': bsize, 'x': xbsize}, + input_dims={'time': 1, 'y': bsize, 'x': xbsize}, squeeze_batch_dim=False, ) for ds_batch in bg: assert ds_batch['x'].shape == [1, bsize, xbsize] - + bg2 = BatchGenerator( sample_ds_3d, - input_dims={'y': bsize, 'x': xbsize}, + input_dims={'time': 1, 'y': bsize, 'x': xbsize}, squeeze_batch_dim=True, ) for ds_batch in bg: