Skip to content

Commit

Permalink
Merge pull request opencv#24672 from dkurt:adjust_slice_optional_inputs
Browse files Browse the repository at this point in the history
Replace Slice optional inputs removal to adjustment
  • Loading branch information
asmorkalov authored Dec 10, 2023
2 parents 8b577ab + ac4b26a commit 098efb6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 36 deletions.
58 changes: 24 additions & 34 deletions modules/dnn/src/onnx/onnx_graph_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
}

/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
Slice with optional inputs of default values, some of them don't. This Subgraph removes
all optional inputs of Slice if values are default.
Slice with optional inputs of default values, some of them don't. This Subgraph adjusts
all optional inputs of Slice up to 5.
*/
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
class AdjustSliceAllOptionalInputsSubgraph : public Subgraph {
public:
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
AdjustSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
num_inputs_ = num_inputs;

int input = addNodeToMatch("");
Expand All @@ -200,35 +200,17 @@ class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {

slice_id = addNodeToMatch("Slice", inputs);

setFusedNode("Slice", std::vector<int>{input, starts, ends});
setFusedNode("Slice", inputs);
}

virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
if (num_inputs_ >= 4) { // with axes
// Check if axes are -1 or last axis
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);

auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
for (size_t i = 0; i < axes.total(); i++) {
const int axis = *(axes.ptr<const int>() + i);
if (axis != -1 && axis != shape_size - 1) {
return false;
}
}
}
if (num_inputs_ == 5) {
// Check if steps are 1
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
if (countNonZero(steps != 1)) {
return false;
}
}
return true;
virtual void finalize(const Ptr<ImportGraphWrapper>&,
const Ptr<ImportNodeWrapper>& fusedNode,
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
{
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = num_inputs_; i < 5; ++i) {
node->add_input("");
}
return false;
}

private:
Expand Down Expand Up @@ -1119,7 +1101,11 @@ class ResizeSubgraph1 : public ExtractScalesSubgraph
ResizeSubgraph1() : ExtractScalesSubgraph()
{
int shape = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int slice = addNodeToMatch("Slice", {shape,
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch("")});

int castConcat = addNodeToMatch("Cast", concatId);
int concat = addNodeToMatch("Concat", slice, castConcat);
Expand Down Expand Up @@ -1163,7 +1149,11 @@ class ResizeSubgraph3 : public Subgraph
int cast = addNodeToMatch("Cast", concat1);

int shape2 = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int slice = addNodeToMatch("Slice", {shape2,
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch("")});
int concat2 = addNodeToMatch("Concat", slice, cast);
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);

Expand Down Expand Up @@ -1235,8 +1225,8 @@ class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(3));
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<GeluSubGraph>());
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
subgraphs.push_back(makePtr<LayerNormSubGraph>());
Expand Down
4 changes: 2 additions & 2 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());

if (inp_size > 3)
if (inp_size > 3 && !getBlob(node_proto, 3).empty())
{
Mat axes_blob = getBlob(node_proto, 3);
CV_Assert(axes_blob.total() == start_blob.total());
Expand All @@ -1244,7 +1244,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
has_axes = true;
}

if (inp_size == 5)
if (inp_size == 5 && !getBlob(node_proto, 4).empty())
{
Mat step_blob = getBlob(node_proto, 4);
CV_Assert(step_blob.total() == start_blob.total());
Expand Down

0 comments on commit 098efb6

Please sign in to comment.