Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lihao7212148 committed Oct 23, 2023
1 parent 9440105 commit 02b9fc8
Showing 1 changed file with 46 additions and 35 deletions.
81 changes: 46 additions & 35 deletions tests/test_ops/test_three_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
], [
2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
]],
[[
8.1773e-01, 9.5440e-01, 2.4532e+00,
8.1773e-01, 8.1773e-01, 1.1359e+00
],
[
8.4689e-01, 1.9176e+00, 1.4715e+00,
8.4689e-01, 8.4689e-01, 1.3079e+00
],
[
6.9473e-01, 2.7440e-01, 2.0842e+00,
6.9473e-01, 6.9473e-01, 7.8619e-01
],
[
7.6789e-01, 1.5063e+00, 1.6209e+00,
7.6789e-01, 7.6789e-01, 1.1562e+00
],
[
3.8760e-01, 1.0300e-02, 8.3569e-09,
3.8760e-01, 3.8760e-01, 1.9723e-01
]]],
dtype=dtype,
device=device)
[[
8.1773e-01, 9.5440e-01, 2.4532e+00,
8.1773e-01, 8.1773e-01, 1.1359e+00
],
[
8.4689e-01, 1.9176e+00, 1.4715e+00,
8.4689e-01, 8.4689e-01, 1.3079e+00
],
[
6.9473e-01, 2.7440e-01, 2.0842e+00,
6.9473e-01, 6.9473e-01, 7.8619e-01
],
[
7.6789e-01, 1.5063e+00, 1.6209e+00,
7.6789e-01, 7.6789e-01, 1.1562e+00
],
[
3.8760e-01, 1.0300e-02, 8.3569e-09,
3.8760e-01, 3.8760e-01, 1.9723e-01
]]],
dtype=dtype,
device=device)

assert torch.allclose(output, expected_output, 1e-3, 1e-4)

Expand All @@ -106,16 +106,17 @@ def three_interpolate_forward_gloden(features, idx, weight):
if dtype == np.float16:
features = features.astype(np.float32)
weight = weight.astype(np.float32)
output = np.zeros((bs, cs, ns), dtype = np.float)

output = np.zeros((bs, cs, ns), dtype=np.float)
for b in range(bs):
for c in range(cs):
for n in range(ns):
output[b][c][n] = features[b][c][idx[b][n][0]] * weight[b][n][0] \
+ features[b][c][idx[b][n][1]] * weight[b][n][1] \
+ features[b][c][idx[b][n][2]] * weight[b][n][2]
+ features[b][c][idx[b][n][1]] * weight[b][n][1] \
+ features[b][c][idx[b][n][2]] * weight[b][n][2]
return output


def three_interpolate_backward_gloden(grad_output, idx, weight, features):
bs, cs, ns = grad_output.shape
ms = features.shape[2]
Expand All @@ -124,7 +125,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features):
if dtype == np.float16:
features = features.astype(np.float32)
weight = weight.astype(np.float32)

grad_point = np.zeros((bs, cs, ms), dtype=np.float)
for b in range(bs):
for c in range(cs):
Expand All @@ -137,6 +138,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features):
grad_output[b][c][n] * weight[b][n][2]
return grad_point


def torch_type_trans(dtype):
if dtype == torch.half:
return np.float16
Expand All @@ -145,27 +147,36 @@ def torch_type_trans(dtype):
else:
return np.float64


@pytest.mark.parametrize('dtype', [
torch.half,
torch.float
])
@pytest.mark.parametrize('device', [
(2,5,6,6),
(10,10,10,10),
(20,21,13,4),
(2,10,2,18),
(10,602,910,200),
(600,100,300,101)
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('shape', [
(2, 5, 6, 6),
(10, 10, 10, 10),
(20, 21, 13, 4),
(2, 10, 2, 18),
(10, 602, 910, 200),
(600, 100, 300, 101)
])
def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
bs = shape[0]
cs = shape[1]
ms = shape[2]
ns = shape[3]

features = np.random.uniform(-10.0, 10.0, (bs, cs, ms).astype(torch_type_trans(dtype)))
features = np.random.uniform(-10.0, 10.0,
(bs, cs, ms).astype(torch_type_trans(dtype)))
idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32)
weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)).astype(torch_type_trans(dtype))
weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)
).astype(torch_type_trans(dtype))

features_npu = torch.tensor(features, dtype=dtype).to(device)
idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
Expand Down

0 comments on commit 02b9fc8

Please sign in to comment.