apply fixes for ruff
peterjc123 committed Oct 24, 2024
1 parent fdd7093 commit e9d21db
Showing 21 changed files with 124 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,6 +11,6 @@ repos:
- id: black
exclude: ^tinynn/converter/schemas
- repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.150
+ rev: v0.7.0
hooks:
- id: ruff
1 change: 1 addition & 0 deletions pyproject.toml
@@ -86,6 +86,7 @@ target-version = "py39"
"examples/*.py" = ["E402"]
"__init__.py" = ["F401", "F403"]
"tests/import_test.py" = ["F401"]
"tutorials/quantization/basic.ipynb" = ["F811", "F401"]

[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
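
The two codes silenced for the tutorial notebook are F401 (unused import) and F811 (redefinition of an unused name), both routine in notebooks whose cells re-import modules and iteratively redefine helpers. A minimal sketch of code that would trip them (the `calibrate` helper is a made-up example, not from the tutorial):

```python
import torch        # F401 if the notebook never references it later
import torch        # F811: re-imported so the cell runs standalone

def calibrate(model):   # first draft of a helper in an early cell
    return model

def calibrate(model):   # F811: redefined as the tutorial refines it
    model.eval()
    return model
```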
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,5 +1,4 @@
- numpy>=1.18.5,<2; python_version < '3.10'
- numpy>=1.18.5; python_version >= '3.10'
+ numpy>=1.18.5
PyYAML>=5.3.1
ruamel.yaml>=0.16.12
igraph>=0.9
4 changes: 2 additions & 2 deletions tests/converter_op_test.py
@@ -565,7 +565,7 @@ def test_reduce_ops_single_dim(self):

def model(x):
res = func(x, dim=1)
- return res if type(res) == torch.Tensor else res[0]
+ return res if type(res) is torch.Tensor else res[0]

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
@@ -599,7 +599,7 @@ def test_reduce_ops_single_dim_keepdim(self):

def model(x):
res = func(x, dim=1, keepdim=True)
- return res if type(res) == torch.Tensor else res[0]
+ return res if type(res) is torch.Tensor else res[0]

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
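
The bulk of this commit applies ruff's E721, which flags comparing types with `==`; since classes are effectively singletons, identity (`is`) is the idiomatic exact-type test and, unlike `isinstance`, rejects subclasses. A short illustration of the pattern used in these tests (not code from the repo):

```python
import torch

def first_tensor(res):
    # Exact-type check: torch.max(x, dim=1) returns a (values, indices)
    # named tuple, while x.sum(dim=1) returns a plain tensor.
    return res if type(res) is torch.Tensor else res[0]

x = torch.ones(2, 3)
first_tensor(torch.max(x, dim=1))  # named tuple -> returns res[0] (values)
first_tensor(x.sum(dim=1))         # plain tensor -> returned unchanged
```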
6 changes: 3 additions & 3 deletions tinynn/converter/operators/optimize.py
@@ -4661,21 +4661,21 @@ def elinimate_sequences(
first_node = seq[0]
last_node = seq[-1]

- if type(skip_pred) == bool:
+ if type(skip_pred) is bool:
skip = skip_pred
elif skip_pred is not None:
skip = skip_pred(seq)

if skip:
continue

- if type(remove_first_pred) == bool:
+ if type(remove_first_pred) is bool:
remove_first = remove_first_pred
custom_data = None
elif remove_first_pred is not None:
remove_first, custom_data = remove_first_pred(seq)

- if type(remove_last_pred) == bool:
+ if type(remove_last_pred) is bool:
remove_last = remove_last_pred
custom_data_last = None
elif remove_last_pred is not None:
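
The predicates here accept either a literal bool or a callable evaluated on the sequence; `type(x) is bool` keeps that flag-vs-predicate dispatch explicit while satisfying E721. A condensed sketch of the pattern (not the full function's signature):

```python
from typing import Callable, Optional, Union

Pred = Optional[Union[bool, Callable[[list], bool]]]

def should_skip(seq: list, skip_pred: Pred = None) -> bool:
    # A bare bool is taken verbatim; any other non-None value is treated
    # as a predicate and called with the sequence.
    if type(skip_pred) is bool:
        return skip_pred
    if skip_pred is not None:
        return skip_pred(seq)
    return False

should_skip([1, 2, 3], True)                  # -> True
should_skip([1, 2, 3], lambda s: len(s) > 4)  # -> False
```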
10 changes: 5 additions & 5 deletions tinynn/converter/operators/tflite/base.py
@@ -186,7 +186,7 @@ def __init__(
self.index = 0
self.is_variable = is_variable

- if type(tensor) == FakeQuantTensor:
+ if type(tensor) is FakeQuantTensor:
self.quantization = QuantizationParameters(tensor.scale, tensor.zero_point, tensor.dim)
tensor = tensor.tensor

@@ -195,7 +195,7 @@ def __init__(

if type(tensor).__module__ == 'numpy':
self.tensor = tensor
- elif type(tensor) == torch.Tensor:
+ elif type(tensor) is torch.Tensor:
assert tensor.is_contiguous, "Tensor should be contiguous"
if tensor.dtype == torch.quint8:
self.tensor = torch.int_repr(tensor.detach()).numpy()
@@ -253,7 +253,7 @@ def __init__(
self.quantization = QuantizationParameters(scales, zero_points, dim)
else:
self.tensor = tensor.detach().numpy()
- elif type(tensor) == torch.Size:
+ elif type(tensor) is torch.Size:
self.tensor = np.asarray(tensor, dtype='int32')
elif type(tensor) in (tuple, list):
self.tensor = np.asarray(tensor, dtype=dtype)
@@ -390,7 +390,7 @@ def build(self, builder: flatbuffers.Builder) -> Offset:
def create_offset_vector(builder: flatbuffers.Builder, prop: typing.Callable, vec: typing.Iterable):
if type(vec) not in (tuple, list):
assert False, "type of vec unexpected, expected: list or tuple"
- elif type(vec) == tuple:
+ elif type(vec) is tuple:
vec = list(vec)

prop_name = prop.__name__
@@ -426,7 +426,7 @@ def create_numpy_array(builder: flatbuffers.Builder, prop: typing.Callable, vec:


def create_string(builder: flatbuffers.Builder, prop: typing.Callable, val: str):
- if type(val) != str:
+ if type(val) is not str:
assert False, "type of val unexpected, expected: str"

prop_name = prop.__name__
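
One reason exact-type identity, rather than `isinstance`, fits these tensor-dispatch branches: `torch.Size` subclasses `tuple` (and `bool` subclasses `int`), so a subclass-aware check could fold a value into the wrong branch. A quick demonstration:

```python
import torch

s = torch.Size([2, 3])

print(isinstance(s, tuple))    # True  -- would also match a tuple/list branch
print(type(s) is torch.Size)   # True  -- matches only the dedicated branch
print(type(s) is tuple)        # False -- identity keeps branches disjoint
```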
18 changes: 9 additions & 9 deletions tinynn/converter/operators/torch/aten.py
@@ -1561,7 +1561,7 @@ def parse(self, node, attrs, args, graph_converter):
self.run(node)

dim = self.input_tensors[1]
- assert type(dim) == int
+ assert type(dim) is int

if dim < 0:
dim += self.input_tensors[0][0].ndim + 1
@@ -1619,7 +1619,7 @@ def parse(self, node, attrs, args, graph_converter):
self.run(node)

dim = self.input_tensors[1]
- assert type(dim) == int
+ assert type(dim) is int

if dim < 0:
dim += self.input_tensors[0][0].ndim
@@ -2067,8 +2067,8 @@ def parse(self, node, attrs, args, graph_converter):
input_tensor = self.find_or_create_input(0, graph_converter)
dim, index = self.input_tensors[1:]

- assert type(dim) == int
- assert type(index) == int
+ assert type(dim) is int
+ assert type(index) is int

if dim < 0:
dim += input_tensor.tensor.ndim
@@ -2166,11 +2166,11 @@ def parse(self, node, attrs, args, graph_converter):
self.parse_common(node, attrs, args, graph_converter)

def parse_common(self, node, attrs, args, graph_converter):
- if type(self) == ATenClampOperator:
+ if type(self) is ATenClampOperator:
min_value, max_value = self.input_tensors[1:]
- elif type(self) == ATenClampMinOperator:
+ elif type(self) is ATenClampMinOperator:
min_value, max_value = self.input_tensors[1], None
- elif type(self) == ATenClampMaxOperator:
+ elif type(self) is ATenClampMaxOperator:
min_value, max_value = None, self.input_tensors[1]

has_min = min_value is not None
@@ -3808,7 +3808,7 @@ def parse(self, node, attrs, args, graph_converter):
def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, out_idx=0):
for i in (input_idx, other_idx):
t = self.input_tensors[i]
- if type(t) == torch.Tensor:
+ if type(t) is torch.Tensor:
if t.dtype == torch.float64:
self.input_tensors[i] = t.to(dtype=torch.float32)
elif t.dtype == torch.int64:
@@ -3826,7 +3826,7 @@ def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, out_idx=0):
input_tensor, mask_tensor = [self.find_or_create_input(i, graph_converter) for i in (input_idx, mask_idx)]

ops = []
- if type(other) == torch.Tensor:
+ if type(other) is torch.Tensor:
other_t = self.find_or_create_input(other_idx, graph_converter)
if out.dtype != other.dtype:
casted = other.clone().to(dtype=out.dtype)
6 changes: 3 additions & 3 deletions tinynn/converter/operators/torch/base.py
@@ -190,7 +190,7 @@ def to_tfl_tensors(
tfl_tensors = []
if has_buffers is None:
has_buffers = [None] * len(tensors)
- elif type(has_buffers) == bool:
+ elif type(has_buffers) is bool:
has_buffers = [has_buffers] * len(tensors)
assert len(names) == len(tensors) == len(has_buffers)
for n, t, b in zip(names, tensors, has_buffers):
@@ -491,7 +491,7 @@ def handle_padding(self, pad_h, pad_w, pad_op_index, ops, ceil_mode=False):
input_size = [input_tensor.shape[2], input_tensor.shape[3]]

if not all((i + 2 * p - k) % s == 0 for i, p, k, s in zip(input_size, padding, kernel_size, stride)):
- assert type(ops[1]) == tfl.MaxPool2dOperator, 'ceil_mode=True for AvgPool not supported'
+ assert type(ops[1]) is tfl.MaxPool2dOperator, 'ceil_mode=True for AvgPool not supported'
fill_nan = True
ceil_pad = get_pool_ceil_padding(input_tensor, kernel_size, stride, padding)
ceil_pad = list(np.add(ceil_pad, padding))
@@ -503,7 +503,7 @@ def handle_padding(self, pad_h, pad_w, pad_op_index, ops, ceil_mode=False):
pad_input = ops[pad_op_index - 1].outputs[0]

inputs = [pad_input, pad_tensor]
- if type(ops[1]) == tfl.MaxPool2dOperator:
+ if type(ops[1]) is tfl.MaxPool2dOperator:
constant_tensor = self.get_minimum_constant(pad_input)
inputs.append(constant_tensor)
pad_array = np.pad(pad_input.tensor, pad, constant_values=constant_tensor.tensor[0])
2 changes: 1 addition & 1 deletion tinynn/converter/operators/torch/quantized.py
@@ -153,7 +153,7 @@ def parse(self, node, attrs, args, graph_converter):
self.run(node)

dim = self.input_tensors[1]
- assert type(dim) == int
+ assert type(dim) is int

if dim < 0:
dim += self.input_tensors[0][0].ndim
2 changes: 1 addition & 1 deletion tinynn/graph/configs/gen_creation_funcs_yml.py
@@ -17,7 +17,7 @@
if k in block_list:
continue
c = getattr(torch, k)
- if inspect.isclass(c) and k.endswith('Tensor') and c.__bases__[0] == object:
+ if inspect.isclass(c) and k.endswith('Tensor') and c.__bases__[0] is object:
print(k)
final_dict['torch'].append(k)
elif inspect.isbuiltin(c):
10 changes: 5 additions & 5 deletions tinynn/graph/modifier.py
@@ -1314,7 +1314,7 @@ def apply_mask(self, modifiers):
args_parsed = self.node.module.args_parsed_origin

if len(args_parsed) > 1:
- if type(args_parsed[1]) == list:
+ if type(args_parsed[1]) is list:
ch = [int(i) for i in args_parsed[1]]
ch_new = []

@@ -2556,7 +2556,7 @@ def register_mask(self, modifiers, importance, sparsity):

def create_channel_modifier(n):
for key in CHANNEL_MODIFIERS.keys():
- if type(key) == str:
+ if type(key) is str:
if n.kind() == key:
return CHANNEL_MODIFIERS[key](n)
elif isinstance(n.module, key):
@@ -2611,7 +2611,7 @@ def calc_prune_idx_by_bn_variance(
ignored_bn = set()

for leaf in self.leaf:
- if type(leaf.module()) != nn.BatchNorm2d:
+ if type(leaf.module()) is not nn.BatchNorm2d:
continue

while True:
@@ -2621,15 +2621,15 @@ def calc_prune_idx_by_bn_variance(
break

if leaf in self.leaf:
- if type(leaf.module()) != nn.BatchNorm2d:
+ if type(leaf.module()) is not nn.BatchNorm2d:
continue

ignored_bn.add(leaf)

for leaf in self.leaf:
if leaf in ignored_bn:
continue
- if type(leaf.module()) != nn.BatchNorm2d:
+ if type(leaf.module()) is not nn.BatchNorm2d:
continue

is_real_leaf = True
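
`CHANNEL_MODIFIERS` mixes string keys (TorchScript node kinds) with module-class keys, so the lookup first tests whether the key itself is a string before falling back to `isinstance` on the node's module. A hypothetical registry showing the shape of that dispatch (`CatModifier` and `ConvModifier` are invented names):

```python
import torch.nn as nn

class Modifier:
    def __init__(self, node):
        self.node = node

class CatModifier(Modifier): pass     # hypothetical modifier types
class ConvModifier(Modifier): pass

CHANNEL_MODIFIERS = {
    'aten::cat': CatModifier,  # matched against n.kind(), a string
    nn.Conv2d: ConvModifier,   # matched via isinstance on n.module
}

def create_channel_modifier(n):
    for key, modifier in CHANNEL_MODIFIERS.items():
        if type(key) is str:             # string keys name node kinds
            if n.kind() == key:
                return modifier(n)
        elif isinstance(n.module, key):  # class keys match module types
            return modifier(n)
```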
2 changes: 1 addition & 1 deletion tinynn/graph/quantization/fused_modules.py
@@ -13,7 +13,7 @@
class ConvTransposeBn2d(_FusedModule):
def __init__(self, conv, bn):
assert (
- type(conv) == nn.ConvTranspose2d and type(bn) == nn.BatchNorm2d
+ type(conv) is nn.ConvTranspose2d and type(bn) is nn.BatchNorm2d
), 'Incorrect types for input modules{}{}'.format(type(conv), type(bn))
super(ConvTransposeBn2d, self).__init__(conv, bn)

10 changes: 5 additions & 5 deletions tinynn/graph/quantization/qat_modules.py
@@ -95,12 +95,12 @@ def from_float(cls, mod):
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
- assert type(mod) == cls._FLOAT_MODULE, (
+ assert type(mod) is cls._FLOAT_MODULE, (
'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
- if type(mod) == ConvReLU1d:
+ if type(mod) is ConvReLU1d:
mod = mod[0]
qconfig = mod.qconfig
qat_conv = cls(
@@ -224,7 +224,7 @@ def from_float(cls, mod):
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
- assert type(mod) == cls._FLOAT_MODULE, (
+ assert type(mod) is cls._FLOAT_MODULE, (
'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
@@ -335,7 +335,7 @@ def from_float(cls, mod):
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
- assert type(mod) == cls._FLOAT_MODULE, (
+ assert type(mod) is cls._FLOAT_MODULE, (
'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
@@ -566,7 +566,7 @@ def from_float(cls, mod):
"""
# The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
# has no __name__ (code is fine though)
- assert type(mod) == cls._FLOAT_MODULE, (
+ assert type(mod) is cls._FLOAT_MODULE, (
'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
) # type: ignore[attr-defined]
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
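
Every `from_float` in this file opens with the same exact-type gate: the incoming float module must be precisely the declared `_FLOAT_MODULE`, since a subclass could carry behavior the QAT wrapper cannot reproduce. A condensed sketch of the pattern (the wrapper class shown is an illustrative stand-in, not one of the repo's classes):

```python
import torch.nn as nn

class QATConvTranspose2d(nn.ConvTranspose2d):  # illustrative stand-in
    _FLOAT_MODULE = nn.ConvTranspose2d

    @classmethod
    def from_float(cls, mod):
        # `is`, not `==` (E721), and not isinstance: exact type only.
        assert type(mod) is cls._FLOAT_MODULE, (
            'qat.' + cls.__name__ + '.from_float only works for '
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        # ... construct the QAT module from mod's parameters here ...
```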
2 changes: 1 addition & 1 deletion tinynn/graph/quantization/quantizable/gru.py
@@ -132,7 +132,7 @@ def from_params(cls, wi, wh, bi=None, bh=None):

@classmethod
def from_float(cls, other):
- assert type(other) == cls._FLOAT_MODULE
+ assert type(other) is cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
observed = cls.from_params(other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh)
observed.qconfig = other.qconfig