Skip to content

Commit

Permalink
[Converter] Support GRU operator conversion with separated_rnn_gate_c…
Browse files Browse the repository at this point in the history
…alc=False (#323)

* Update aten.py

* Update aten.py

* Update converter_op_test.py

* Update aten.py

* minor fix

* minor fix

* Update converter_op_test.py
  • Loading branch information
Juelianqvq authored Jun 3, 2024
1 parent 1e74c0a commit 7de51e6
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 46 deletions.
2 changes: 1 addition & 1 deletion docs/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ You may also try out static quantization for LSTMs when you have PyTorch 1.13+.
#### What if my model runs slower when dynamic quantization is enabled?
Please refer to [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) for selective dynamic quantization.

#### I need LSTMs with separated gate calculation when `unroll_rnn=True`.
#### I need LSTMs/GRUs with separated gate calculation when `unroll_rnn=True`.
Please set `separated_rnn_gate_calc=True`.

#### How to add state inputs and outputs for LSTMs/GRUs/RNNs with `unroll_rnn=True`?
Expand Down
2 changes: 1 addition & 1 deletion docs/FAQ_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ Note: 这些状态变量都是二维的,维度为`[batch_size, hidden_size或
#### 我的模型开了动态量化变得更慢了?
请参考 [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) 选择性的开启动态量化。

#### 在设置了`unroll_rnn=True`后,LSTM中多个门的计算被融合了。有没有办法分开?
#### 在设置了`unroll_rnn=True`后,LSTM/GRU中多个门的计算被融合了。有没有办法分开?
尝试设置`separated_rnn_gate_calc=True`

#### `unroll_rnn=True`的情况下,怎么为包含LSTM、RNN和GRU的网络添加状态输入输出?
Expand Down
176 changes: 176 additions & 0 deletions tests/converter_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2952,6 +2952,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_unroll_unseparated(self):
    """Check conversion of an unrolled GRU with `separated_rnn_gate_calc=False`.

    A single-layer `nn.GRU(10, 20)` is converted with `unroll_rnn=True`, then the
    TFLite result is compared against the PyTorch result for the same input.
    """
    inputs = torch.randn(9, 1, 10, dtype=torch.float32)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gru = nn.GRU(10, 20)

        def forward(self, x):
            # Only the output sequence is returned; the final hidden state is dropped.
            output, _ = self.gru(x)
            return output

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual)

def test_gru_batch_first_unroll_separated(self):
dummy_input = torch.randn(1, 9, 10, dtype=torch.float32)

Expand All @@ -2976,6 +3000,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)

def test_gru_batch_first_unroll_unseparated(self):
    """Check conversion of an unrolled `batch_first` GRU with `separated_rnn_gate_calc=False`.

    The TFLite result is compared against the PyTorch result with
    `check_stride=False`.
    """
    inputs = torch.randn(1, 9, 10, dtype=torch.float32)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gru = nn.GRU(10, 20, batch_first=True)

        def forward(self, x):
            # Only the output sequence is returned; the final hidden state is dropped.
            output, _ = self.gru(x)
            return output

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual, check_stride=False)

def test_gru_with_state_tensor_unroll_separated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
Expand Down Expand Up @@ -3004,6 +3052,34 @@ def forward(self, x, hx):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_with_state_tensor_unroll_unseparated(self):
    """Check conversion of an unrolled GRU that takes and returns a hidden state.

    The model receives the input sequence plus an initial hidden state and
    returns both GRU results; conversion uses `separated_rnn_gate_calc=False`.
    """
    inputs = [
        torch.randn(9, 1, 10, dtype=torch.float32),
        torch.randn(1, 1, 20, dtype=torch.float32),
    ]

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gru = nn.GRU(10, 20)

        def forward(self, x, hx):
            # Pass the (output, hidden state) pair from the GRU straight through.
            output, state = self.gru(x, hx)
            return output, state

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(*inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual)

def test_gru_multi_layer_unroll_separated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand All @@ -3028,6 +3104,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_unroll_unseparated(self):
    """Check conversion of an unrolled two-layer GRU with `separated_rnn_gate_calc=False`.

    The TFLite result is compared against the PyTorch result for the same input.
    """
    inputs = torch.randn(9, 1, 10, dtype=torch.float32)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Third positional argument of nn.GRU is num_layers.
            self.gru = nn.GRU(10, 20, 2)

        def forward(self, x):
            # Only the output sequence is returned; the final hidden state is dropped.
            output, _ = self.gru(x)
            return output

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual)

def test_gru_multi_layer_with_state_tensor_unroll_separated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
Expand Down Expand Up @@ -3056,6 +3156,34 @@ def forward(self, x, hx):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_with_state_tensor_unroll_unseparated(self):
    """Check conversion of an unrolled two-layer GRU that takes and returns a hidden state.

    The model receives the input sequence plus an initial hidden state and
    returns both GRU results; conversion uses `separated_rnn_gate_calc=False`.
    """
    inputs = [
        torch.randn(9, 1, 10, dtype=torch.float32),
        torch.randn(2, 1, 20, dtype=torch.float32),
    ]

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Third positional argument of nn.GRU is num_layers.
            self.gru = nn.GRU(10, 20, 2)

        def forward(self, x, hx):
            # Pass the (output, hidden state) pair from the GRU straight through.
            output, state = self.gru(x, hx)
            return output, state

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(*inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual)

def test_bigru(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand Down Expand Up @@ -3278,6 +3406,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru_unroll_unseparated(self):
    """Check conversion of an unrolled bidirectional GRU with `separated_rnn_gate_calc=False`.

    The TFLite result is compared against the PyTorch result for the same input.
    """
    inputs = torch.randn(9, 1, 10, dtype=torch.float32)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gru = nn.GRU(10, 20, bidirectional=True)

        def forward(self, x):
            # Only the output sequence is returned; the final hidden state is dropped.
            output, _ = self.gru(x)
            return output

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual)

def test_bigru_multi_layer_unroll_separated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand All @@ -3302,6 +3454,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru_multi_layer_unroll_unseparated(self):
    """Check conversion of an unrolled two-layer bidirectional GRU with `separated_rnn_gate_calc=False`.

    The TFLite result is compared against the PyTorch result for the same input.
    """
    inputs = torch.randn(9, 1, 10, dtype=torch.float32)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Third positional argument of nn.GRU is num_layers.
            self.gru = nn.GRU(10, 20, 2, bidirectional=True)

        def forward(self, x):
            # Only the output sequence is returned; the final hidden state is dropped.
            output, _ = self.gru(x)
            return output

    net = Model().eval()

    tflite_path = get_model_path()
    TFLiteConverter(
        net,
        inputs,
        tflite_path,
        nchw_transpose=False,
        unroll_rnn=True,
        separated_rnn_gate_calc=False,
    ).convert()

    expected = net(inputs)
    actual = tfl_run_model(tflite_path, inputs, expected)
    assert_close(expected, actual)

def test_lstm(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand Down
Loading

0 comments on commit 7de51e6

Please sign in to comment.