[converter] fix tensor check constraint (#331)
peterjc123 authored Jun 19, 2024
1 parent 0f9e3d5 commit 8a7ff2e
Showing 3 changed files with 40 additions and 18 deletions.
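The pattern behind this fix: an exact type() comparison rejects torch.Tensor subclasses such as nn.Parameter, so a parameter operand was misclassified as a Python scalar, while isinstance() accepts any tensor subclass. A minimal standalone illustration (not part of the commit) of the behaviour difference that the new test_add_param case exercises:

import torch
import torch.nn as nn

# nn.Parameter is a subclass of torch.Tensor, so the exact-type check fails
# while the isinstance check passes.
param = nn.Parameter(torch.tensor(2.0))

print(type(param) == torch.Tensor)      # False -- the old check misclassifies the parameter
print(isinstance(param, torch.Tensor))  # True  -- the new check keeps it as a tensor operand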
22 changes: 22 additions & 0 deletions tests/converter_op_test.py
@@ -2286,6 +2286,28 @@ def forward(self, x):
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)

+    def test_add_param(self):
+        dummy_input = torch.randn(9, 17, dtype=torch.float32)
+
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.param = nn.Parameter(torch.tensor(2.0))
+
+            def forward(self, x):
+                return x + self.param
+
+        model = Model()
+        model.eval()
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
     def test_addmm(self):
         dummy_input = torch.randn(9, 17, dtype=torch.float32)
         mat = torch.randn(17, 22, dtype=torch.float32)
34 changes: 17 additions & 17 deletions tinynn/converter/operators/torch/aten.py
@@ -1212,7 +1212,7 @@ def parse(self, node, attrs, args, graph_converter):

         if type(other) in (int, float, bool):
             self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
-        elif type(other) != torch.Tensor:
+        elif not isinstance(other, torch.Tensor):
             assert False, "other should have type int, float, tensor in aten::sub(input, other)"

         self.elementwise_binary(tfl.SubOperator, graph_converter, True)
@@ -1230,7 +1230,7 @@ def parse(self, node, attrs, args, graph_converter):

         if type(other) in (int, float, bool):
             self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
-        elif type(other) != torch.Tensor:
+        elif not isinstance(other, torch.Tensor):
             assert False, "other should have type int, float, tensor in aten::rsub(input, other)"

         # Swap the first two input tensors and their names
@@ -1268,7 +1268,7 @@ def parse(self, node, attrs, args, graph_converter):

         if type(other) in (int, float):
             self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
-        elif type(other) != torch.Tensor:
+        elif not isinstance(other, torch.Tensor):
             assert False, "other should have type int, float, tensor in aten::mul(input, other)"

         self.elementwise_binary(tfl.MulOperator, graph_converter, True)
@@ -1338,7 +1338,7 @@ def parse(self, node, attrs, args, graph_converter):
         other = self.input_tensors[1]
         if type(other) in (int, float):
             self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
-        elif type(other) != torch.Tensor:
+        elif not isinstance(other, torch.Tensor):
             assert False, "other should have type int, float, tensor in aten::div(input, other)"

         self.elementwise_binary(tfl.DivOperator, graph_converter, True)
@@ -1362,7 +1362,7 @@ def parse(self, node, attrs, args, graph_converter):
             torch.int32,
         ), "Input should be tensors of type torch.float32 or torch.int32"

-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = torch.tensor([self.input_tensors[1]], dtype=self.input_tensors[0].dtype)

         self.elementwise_binary(tfl.PowOperator, graph_converter, True)
@@ -1738,7 +1738,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = torch.tensor([self.input_tensors[1]], dtype=self.input_tensors[0].dtype)
         elif self.input_tensors[1].dtype != self.input_tensors[0].dtype:
             other = self.find_or_create_input(1, graph_converter)
@@ -2028,7 +2028,7 @@ def parse(self, node, attrs, args, graph_converter):

         if type(other) in (int, float, bool):
             self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
-        elif type(other) != torch.Tensor:
+        elif not isinstance(other, torch.Tensor):
             assert False, "other should have type int, float, tensor in aten::add(input, other)"

         self.elementwise_binary(tfl.AddOperator, graph_converter, True)
@@ -2269,7 +2269,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.NotEqualOperator, graph_converter, True)
@@ -3636,7 +3636,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.EqualOperator, graph_converter, True)
@@ -3910,7 +3910,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.MaximumOperator, graph_converter, True)
@@ -3921,7 +3921,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.MinimumOperator, graph_converter, True)
@@ -3932,7 +3932,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.GreaterOperator, graph_converter, True)
@@ -3943,7 +3943,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.LessOperator, graph_converter, True)
@@ -3954,7 +3954,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.GreaterEqualOperator, graph_converter, True)
@@ -3965,7 +3965,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])

         self.elementwise_binary(tfl.LessEqualOperator, graph_converter, True)
@@ -3976,7 +3976,7 @@ def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)

         self.run(node)
-        if type(self.input_tensors[1]) != torch.Tensor:
+        if not isinstance(self.input_tensors[1], torch.Tensor):
             self.input_tensors[1] = torch.tensor([self.input_tensors[1]], dtype=self.input_tensors[0].dtype)

         self.elementwise_binary(tfl.FloorModOperator, graph_converter, True)
@@ -3989,7 +3989,7 @@ def parse(self, node, attrs, args, graph_converter):
         self.run(node)
         assert 'self' in args and 'other' in args, "aten::where(condition) is not supported"

-        if type(self.input_tensors[2]) != torch.Tensor:
+        if not isinstance(self.input_tensors[2], torch.Tensor):
             self.input_tensors[2] = torch.tensor([self.input_tensors[2]])

         ATenMaskedFillOperator.parse_common(self, graph_converter, input_idx=2, mask_idx=0, other_idx=1)
2 changes: 1 addition & 1 deletion tinynn/converter/operators/torch/base.py
@@ -664,7 +664,7 @@ def quantize_scalar_tensor(self, tensor: torch.Tensor):

     def torch_tensor_from_scalar(self, ref_tensor: torch.Tensor, src_tensor: torch.Tensor):
         tgt_tensor = src_tensor
-        if type(src_tensor) != torch.Tensor:
+        if not isinstance(src_tensor, torch.Tensor):
             if ref_tensor.is_quantized:
                 tgt_tensor = torch.quantize_per_tensor(
                     torch.tensor([src_tensor], dtype=torch.float32),
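For orientation, here is a minimal standalone sketch of what a scalar-to-tensor helper along the lines of torch_tensor_from_scalar does; the name scalar_to_tensor_like and the non-quantized branch are assumptions for illustration, not the tinynn implementation:

import torch

def scalar_to_tensor_like(ref: torch.Tensor, value):
    # Hypothetical sketch: wrap a Python scalar as a 1-element tensor that is
    # compatible with a reference tensor.
    if isinstance(value, torch.Tensor):
        return value
    if ref.is_quantized:
        # Reuse the reference tensor's per-tensor quantization parameters.
        return torch.quantize_per_tensor(
            torch.tensor([value], dtype=torch.float32), ref.q_scale(), ref.q_zero_point(), ref.dtype
        )
    return torch.tensor([value], dtype=ref.dtype)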
