Skip to content

Commit

Permalink
squeeze_batch_dim test (attempt 2)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdupuis3 committed Nov 19, 2021
1 parent 0e8f716 commit 749ac26
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,15 @@ def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
xbsize = 20
bg = BatchGenerator(
sample_ds_3d,
input_dims={'y': bsize, 'x': xbsize},
input_dims={'time': 1, 'y': bsize, 'x': xbsize},
squeeze_batch_dim=False,
)
for ds_batch in bg:
assert ds_batch['x'].shape == [1, bsize, xbsize]

bg2 = BatchGenerator(
sample_ds_3d,
input_dims={'y': bsize, 'x': xbsize},
input_dims={'time': 1, 'y': bsize, 'x': xbsize},
squeeze_batch_dim=True,
)
for ds_batch in bg:
Expand Down

0 comments on commit 749ac26

Please sign in to comment.