diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index 1c30307cb98..fb94fbc7ad4 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device): ], [ 2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01 ]], - [[ - 8.1773e-01, 9.5440e-01, 2.4532e+00, - 8.1773e-01, 8.1773e-01, 1.1359e+00 - ], - [ - 8.4689e-01, 1.9176e+00, 1.4715e+00, - 8.4689e-01, 8.4689e-01, 1.3079e+00 - ], - [ - 6.9473e-01, 2.7440e-01, 2.0842e+00, - 6.9473e-01, 6.9473e-01, 7.8619e-01 - ], - [ - 7.6789e-01, 1.5063e+00, 1.6209e+00, - 7.6789e-01, 7.6789e-01, 1.1562e+00 - ], - [ - 3.8760e-01, 1.0300e-02, 8.3569e-09, - 3.8760e-01, 3.8760e-01, 1.9723e-01 - ]]], - dtype=dtype, - device=device) + [[ + 8.1773e-01, 9.5440e-01, 2.4532e+00, + 8.1773e-01, 8.1773e-01, 1.1359e+00 + ], + [ + 8.4689e-01, 1.9176e+00, 1.4715e+00, + 8.4689e-01, 8.4689e-01, 1.3079e+00 + ], + [ + 6.9473e-01, 2.7440e-01, 2.0842e+00, + 6.9473e-01, 6.9473e-01, 7.8619e-01 + ], + [ + 7.6789e-01, 1.5063e+00, 1.6209e+00, + 7.6789e-01, 7.6789e-01, 1.1562e+00 + ], + [ + 3.8760e-01, 1.0300e-02, 8.3569e-09, + 3.8760e-01, 3.8760e-01, 1.9723e-01 + ]]], + dtype=dtype, + device=device) assert torch.allclose(output, expected_output, 1e-3, 1e-4) @@ -148,24 +148,16 @@ def torch_type_trans(dtype): return np.float64 -@pytest.mark.parametrize('dtype', [ - torch.half, - torch.float -]) +@pytest.mark.parametrize('dtype', [torch.half, torch.float]) @pytest.mark.parametrize('device', [ pytest.param( 'npu', marks=pytest.mark.skipif( not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) -@pytest.mark.parametrize('shape', [ - (2, 5, 6, 6), - (10, 10, 10, 10), - (20, 21, 13, 4), - (2, 10, 2, 18), - (10, 602, 910, 200), - (600, 100, 300, 101) -]) +@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10), + (20, 21, 13, 4), (2, 10, 2, 18), + (10, 602, 910, 200), (600, 100, 300, 101)]) def test_three_interpolate_npu_dynamic_shape(dtype, device, shape): bs = shape[0] cs = shape[1] @@ -175,8 +167,9 @@ def test_three_interpolate_npu_dynamic_shape(dtype, device, shape): features = np.random.uniform(-10.0, 10.0, (bs, cs, ms).astype(torch_type_trans(dtype))) idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32) - weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3) - ).astype(torch_type_trans(dtype)) + weight = np.random.uniform(-10.0, + 10.0 (bs, ns, + 3)).astype(torch_type_trans(dtype)) features_npu = torch.tensor(features, dtype=dtype).to(device) idx_npu = torch.tensor(idx, dtype=torch.int32).to(device) @@ -184,4 +177,4 @@ def test_three_interpolate_npu_dynamic_shape(dtype, device, shape): expected_output = three_interpolate_forward_gloden(features, idx, weight) output = three_interpolate(features_npu, idx_npu, weight_npu) - assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4) + assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4) \ No newline at end of file