Skip to content

Commit

Permalink
Merge pull request opencv#24655 from fengyuentau:graph_simplifier_opt…
Browse files Browse the repository at this point in the history
…ional_input

dnn onnx graph simplifier: handle optional inputs of Slice opencv#24655

Resolves opencv#24609

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
  • Loading branch information
fengyuentau authored Dec 6, 2023
1 parent 22edfd2 commit f5ec92e
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion modules/dnn/src/onnx/onnx_graph_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ class ONNXGraphWrapper : public ImportGraphWrapper
return makePtr<ONNXNodeWrapper>(node);
}

int getTensorShapeSize(int node_id, int node_input_id) {
const auto node = getNode(node_id);
const auto &input_name = node->getInputName(node_input_id);
for (int i = 0; i < net.value_info_size(); i++) {
const auto value_info = net.value_info(i);
if (value_info.name() == input_name) {
if (value_info.has_type() && value_info.type().has_tensor_type() &&
value_info.type().tensor_type().has_shape()) {
return value_info.type().tensor_type().shape().dim_size();
} else {
return -1;
}
}
}
return -1;
}

int getInputInitializerId(int node_id, int node_input_id)
{
auto node = getNode(node_id);
Expand Down Expand Up @@ -164,6 +181,61 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
}
}

/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
Slice with optional inputs of default values, some of them don't. This Subgraph removes
all optional inputs of Slice if values are default.
*/
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
public:
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
num_inputs_ = num_inputs;

int input = addNodeToMatch("");
int starts = addNodeToMatch("");
int ends = addNodeToMatch("");
std::vector<int> inputs{input, starts, ends};
for (size_t i = 3; i < num_inputs_; i++) { // axes and steps
inputs.push_back(addNodeToMatch(""));
}

slice_id = addNodeToMatch("Slice", inputs);

setFusedNode("Slice", std::vector<int>{input, starts, ends});
}

virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
if (num_inputs_ >= 4) { // with axes
// Check if axes are -1 or last axis
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);

auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
for (size_t i = 0; i < axes.total(); i++) {
const int axis = *(axes.ptr<const int>() + i);
if (axis != -1 && axis != shape_size - 1) {
return false;
}
}
}
if (num_inputs_ == 5) {
// Check if steps are 1
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
if (countNonZero(steps != 1)) {
return false;
}
}
return true;
}
return false;
}

private:
int slice_id;
size_t num_inputs_;
};

/* Fusion for Gelu.
Graph before fusion:
Expand Down Expand Up @@ -1091,7 +1163,7 @@ class ResizeSubgraph3 : public Subgraph
int cast = addNodeToMatch("Cast", concat1);

int shape2 = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int concat2 = addNodeToMatch("Concat", slice, cast);
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);

Expand Down Expand Up @@ -1163,6 +1235,8 @@ class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
subgraphs.push_back(makePtr<GeluSubGraph>());
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
subgraphs.push_back(makePtr<LayerNormSubGraph>());
Expand Down

0 comments on commit f5ec92e

Please sign in to comment.