Commit 7cc3d2e

[examples] add bidirectional LSTMs (#347)

peterjc123 authored Aug 20, 2024
1 parent f6703c8 commit 7cc3d2e
Showing 4 changed files with 43 additions and 12 deletions.
12 changes: 8 additions & 4 deletions examples/converter/dynamic.py
@@ -13,10 +13,11 @@


 class SimpleLSTM(nn.Module):
-    def __init__(self, in_dim, out_dim, layers, num_classes):
+    def __init__(self, in_dim, out_dim, layers, num_classes, bidirectional):
         super(SimpleLSTM, self).__init__()
-        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers)
-        self.fc = torch.nn.Linear(out_dim, num_classes)
+        num_directions = 2 if bidirectional else 1
+        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers, bidirectional=bidirectional)
+        self.fc = torch.nn.Linear(out_dim * num_directions, num_classes)
         self.relu = torch.nn.ReLU()

     def forward(self, inputs):
@@ -27,7 +28,7 @@ def forward(self, inputs):


 def main_worker(args):
-    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes)
+    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes, args.bidirectional)

     # Provide a viable input for the model
     dummy_input = torch.rand((args.steps, args.batch_size, args.input_size))
@@ -53,6 +54,8 @@ def main_worker(args):
         hybrid_asymmetric_inputs=True,
         # Enable hybrid per-channel quantization for `Conv2d` and `DepthwiseConv2d`
         hybrid_conv=True,
+        # Whether to rewrite bidirectional LSTMs to unidirectional LSTMs
+        map_bilstm_to_lstm=False,
     )
     converter.convert()

@@ -65,6 +68,7 @@ def main_worker(args):
     parser.add_argument('--input-size', type=int, default=128)
     parser.add_argument('--num-layers', type=int, default=1)
     parser.add_argument('--num-classes', type=int, default=10)
+    parser.add_argument('--bidirectional', action='store_true')

     args = parser.parse_args()
     main_worker(args)
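Note (not part of the commit): the `out_dim * num_directions` sizing of the fully-connected layer follows from how `torch.nn.LSTM` reports its output when `bidirectional=True` — the forward and backward hidden states are concatenated along the feature dimension. A minimal sketch with illustrative sizes (the step and batch values are not taken from the example defaults):

import torch
import torch.nn as nn

in_dim, out_dim = 128, 128
lstm = nn.LSTM(in_dim, out_dim, num_layers=1, bidirectional=True)

# (steps, batch, input_size), matching the dummy_input layout used in the example
inputs = torch.rand(100, 1, in_dim)
outputs, _ = lstm(inputs)

print(outputs.shape)  # torch.Size([100, 1, 256]) -> out_dim * 2, hence the wider Linear layer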
13 changes: 9 additions & 4 deletions examples/converter/dynamic_with_selection.py
@@ -20,10 +20,11 @@


 class SimpleLSTM(nn.Module):
-    def __init__(self, in_dim, out_dim, layers, num_classes):
+    def __init__(self, in_dim, out_dim, layers, num_classes, bidirectional):
         super(SimpleLSTM, self).__init__()
-        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers)
-        self.fc = torch.nn.Linear(out_dim, num_classes)
+        num_directions = 2 if bidirectional else 1
+        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers, bidirectional=bidirectional)
+        self.fc = torch.nn.Linear(out_dim * num_directions, num_classes)
         self.relu = torch.nn.ReLU()

     def forward(self, inputs):
@@ -76,7 +77,7 @@ def benchmark_model_adb(path):


 def main_worker(args):
-    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes)
+    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes, args.bidirectional)

     # Provide a viable input for the model
     dummy_input = torch.rand((args.steps, args.batch_size, args.input_size))
@@ -106,6 +107,8 @@ def main_worker(args):
         hybrid_conv=True,
         # Generate single op models for hybrid quantizable ops
         hybrid_gen_single_op_models=True,
+        # Whether to rewrite bidirectional LSTMs to unidirectional LSTMs
+        map_bilstm_to_lstm=False,
     )

     converter.convert()
@@ -162,6 +165,8 @@ def main_worker(args):
         hybrid_conv=converter.hybrid_conv,
         # Hybrid configurations
         hybrid_config=hybrid_config,
+        # Whether to rewrite bidirectional LSTMs to unidirectional LSTMs
+        map_bilstm_to_lstm=False,
     )

     converter.convert()
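For orientation, a condensed sketch of where the new `map_bilstm_to_lstm` option slots into a conversion call. It is not taken from the example files; it assumes TinyNN's `TFLiteConverter` takes the model, a dummy input, and an output path, and it uses a hypothetical stand-in model. In the two converter examples above the option is left at `False`, so the bidirectional LSTM is kept as-is; the quantization example below sets it to `True` instead.

import torch
import torch.nn as nn

from tinynn.converter import TFLiteConverter


class ToyBiLSTM(nn.Module):  # stand-in model, not the SimpleLSTM from the examples
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(128, 128, 1, bidirectional=True)
        self.fc = nn.Linear(128 * 2, 10)

    def forward(self, inputs):
        out, _ = self.lstm(inputs)
        return self.fc(out)


model = ToyBiLSTM()
model.eval()

dummy_input = torch.rand((100, 1, 128))  # (steps, batch, input_size)

converter = TFLiteConverter(
    model,
    dummy_input,
    'toy_bilstm.tflite',  # hypothetical output path
    # Whether to rewrite the bidirectional LSTM into unidirectional LSTMs
    map_bilstm_to_lstm=False,
)
converter.convert()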
12 changes: 8 additions & 4 deletions examples/quantization/ptq_with_dynamic_q_lstm.py
@@ -14,10 +14,11 @@


 class SimpleLSTM(nn.Module):
-    def __init__(self, in_dim, out_dim, layers, num_classes):
+    def __init__(self, in_dim, out_dim, layers, num_classes, bidirectional):
         super(SimpleLSTM, self).__init__()
-        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers)
-        self.fc = torch.nn.Linear(out_dim, num_classes)
+        num_directions = 2 if bidirectional else 1
+        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers, bidirectional=bidirectional)
+        self.fc = torch.nn.Linear(out_dim * num_directions, num_classes)
         self.relu = torch.nn.ReLU()

     def forward(self, inputs):
@@ -28,7 +29,7 @@ def forward(self, inputs):


 def main_worker(args):
-    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes)
+    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes, args.bidirectional)

     # Provide a viable input for the model
     dummy_input = torch.rand((args.steps, args.batch_size, args.input_size))
@@ -69,6 +70,8 @@ def main_worker(args):
         hybrid_asymmetric_inputs=False,
         # Enable int16 hybrid lstm quantization
         hybrid_int16_lstm=True,
+        # Enable the rewrite of bidirectional LSTMs to unidirectional LSTMs
+        map_bilstm_to_lstm=True,
     )
     converter.convert()

@@ -81,6 +84,7 @@ def main_worker(args):
     parser.add_argument('--input-size', type=int, default=128)
     parser.add_argument('--num-layers', type=int, default=1)
    parser.add_argument('--num-classes', type=int, default=10)
+    parser.add_argument('--bidirectional', action='store_true')

     args = parser.parse_args()
     main_worker(args)
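Background for the `map_bilstm_to_lstm=True` setting above (a sketch added for context, not code from the commit): a single-layer bidirectional LSTM can be decomposed into one forward LSTM plus one LSTM run over the time-reversed sequence, with the two outputs concatenated on the feature axis, which is the kind of unidirectional decomposition the option name refers to. Assuming stock PyTorch parameter names (`weight_ih_l0_reverse` and friends):

import torch
import torch.nn as nn

torch.manual_seed(0)

bi = nn.LSTM(16, 32, num_layers=1, bidirectional=True)
fwd = nn.LSTM(16, 32, num_layers=1)
bwd = nn.LSTM(16, 32, num_layers=1)

# Copy the per-direction weights out of the bidirectional module
with torch.no_grad():
    for name in ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'):
        getattr(fwd, name).copy_(getattr(bi, name))
        getattr(bwd, name).copy_(getattr(bi, name + '_reverse'))

x = torch.rand(20, 1, 16)                       # (steps, batch, input_size)
ref, _ = bi(x)
out_f, _ = fwd(x)
out_b, _ = bwd(x.flip(0))                       # second LSTM runs on the reversed sequence
rebuilt = torch.cat([out_f, out_b.flip(0)], dim=-1)
print(torch.allclose(ref, rebuilt, atol=1e-6))  # expected: True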
18 changes: 18 additions & 0 deletions tinynn/graph/quantization/quantizer.py
@@ -2587,6 +2587,9 @@ def _is_rewritable_lstm_node(node, custom_data):
         rewritable_lstm_nodes = graph.filter_forward_nodes(_is_rewritable_lstm_node)
         fake_dequant_cls = torch_q.DeQuantStub
         for idx, node in enumerate(rewritable_lstm_nodes):
+            assert node.module.num_layers == 1, "Quantization rewrite for multi-layer LSTM is not yet supported"
+            assert not node.module.bidirectional, "Quantization rewrite for bidirectional LSTM is not yet supported"
+
             cell_state = node.next_tensors[1][1]

             fake_dequant = fake_dequant_cls()
@@ -2617,6 +2620,21 @@ def _is_rewritable_lstm_node(node, custom_data):
             )
             size_len = len(node.next_tensors[0].shape)

+            if node.module.bidirectional:
+                with override_current_trace_graph(graph):
+                    size_func = TraceFunction(
+                        'torch.Tensor.__mul__', is_class=True, prefix='fake_dequant_rewrite_'
+                    ).parse_args(size_node.next_tensors[0], 2)
+
+                size_node = graph.insert_new_after(
+                    size_node,
+                    size_func,
+                    [size_node.next_tensors[0]],
+                    [None],
+                    next_tensors=[size_node.next_tensors[0] * 2],
+                    before_node=node.next_nodes[0],
+                )
+
             with override_current_trace_graph(graph):
                 expand_func = TraceFunction(
                     'torch.Tensor.expand', is_class=True, prefix='fake_dequant_rewrite_'
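Context for the `* 2` in the rewrite above (my reading of the change, not text from the commit): for a bidirectional `nn.LSTM`, the output's feature dimension is `hidden_size * num_directions`, while the hidden and cell state tensors keep `hidden_size` per direction, so a size derived from the state side has to be doubled before it can describe the output width. A quick shape check in plain PyTorch:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, bidirectional=True)
out, (h, c) = lstm(torch.rand(5, 1, 8))

print(out.shape)  # torch.Size([5, 1, 32]) -> hidden_size * num_directions
print(h.shape)    # torch.Size([2, 1, 16]) -> (num_layers * num_directions, batch, hidden_size)
print(c.shape)    # torch.Size([2, 1, 16])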
