Skip to content

Commit

Permalink
segment-anything init.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Sep 27, 2023
0 parents commit 4fed07e
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 0 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# mnn-segment-anything

## Model
| model | onnx | mnn |
|:---------:|:------:|:------:|
| sam_vit_b | `TODO` | `TODO` |

- [github](https://github.com/facebookresearch/segment-anything)

## Demo
- [Python](./python/)
- [C++](./cpp)
23 changes: 23 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
cmake_minimum_required(VERSION 3.0)
project(sam)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# include dir
include_directories(${CMAKE_CURRENT_LIST_DIR}/include/)

# libs dir
link_directories(${CMAKE_CURRENT_LIST_DIR}/libs)

# source files
FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/*.cpp)

# target
add_executable(sam_demo ${SRCS})

# link
if (MSVC)
target_link_libraries(sam_demo MNN)
else()
target_link_libraries(sam_demo MNN MNN_Express MNNOpenCV)
endif()
51 changes: 51 additions & 0 deletions cpp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Usage

## Compile MNN library
### Linx/Mac
```bash
git clone https://github.com/alibaba/MNN.git
# copy header file
cp -r MNN/include .
cp -r MNN/tools/cv/include .
cd MNN
mkdir build
cmake -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..
make -j8
cd ..
cp MNN/build/libMNN.so MNN/build/express/libMNN_Express.so MNN/build/tools/cv/libMNNOpenCV.so ./libs
```

### Windows
```bash
# Visual Studio xxxx Developer Command Prompt
powershell
git clone https://github.com/alibaba/MNN.git
# copy header file
cp -r MNN/include .
cp -r MNN/tools/cv/include .
cd MNN
mkdir build
cmake -G "Ninja" -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..
ninja
cd ..
cp MNN.dll MNN.lib ./libs
```

## Build and Run

#### Linux/Mac
```bash
mkdir build && cd build
cmake ..
make -j4
./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg
```
#### Windows
```bash
# Visual Studio xxxx Developer Command Prompt
powershell
mkdir build && cd build
cmake -G "Ninja" ..
ninja
./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg
```
Empty file added cpp/include/.gitkeep
Empty file.
Empty file added cpp/libs/.gitkeep
Empty file.
112 changes: 112 additions & 0 deletions cpp/sam_demo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include <stdio.h>
#include <MNN/ImageProcess.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Executor.hpp>

#include <cv/cv.hpp>

using namespace MNN;
using namespace MNN::Express;
using namespace MNN::CV;

int main(int argc, const char* argv[]) {
if (argc < 4) {
MNN_PRINT("Usage: ./sam_demo.out embed.mnn sam.mnn input.jpg [forwardType] [precision] [thread]\n");
return 0;
}
int thread = 4;
int precision = 0;
int forwardType = MNN_FORWARD_CPU;
if (argc >= 5) {
forwardType = atoi(argv[4]);
}
if (argc >= 6) {
precision = atoi(argv[5]);
}
if (argc >= 7) {
thread = atoi(argv[6]);
}
float mask_threshold = 0;
MNN::ScheduleConfig sConfig;
sConfig.type = static_cast<MNNForwardType>(forwardType);
sConfig.numThread = thread;
BackendConfig bConfig;
bConfig.precision = static_cast<BackendConfig::PrecisionMode>(precision);
sConfig.backendConfig = &bConfig;
std::shared_ptr<Executor::RuntimeManager> rtmgr = std::shared_ptr<Executor::RuntimeManager>(Executor::RuntimeManager::createRuntimeManager(sConfig));
if(rtmgr == nullptr) {
MNN_ERROR("Empty RuntimeManger\n");
return 0;
}
// rtmgr->setCache(".cachefile");
std::shared_ptr<Module> embed(Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], rtmgr));
std::shared_ptr<Module> sam(Module::load(
{"point_coords", "point_labels", "image_embeddings", "has_mask_input", "mask_input", "orig_im_size"},
{"iou_predictions", "low_res_masks", "masks"}, argv[2], rtmgr));
auto image = imread(argv[3]);
// 1. preprocess
auto dims = image->getInfo()->dim;
int origin_h = dims[0];
int origin_w = dims[1];
int length = 1024;
int new_h, new_w;
if (origin_h > origin_w) {
new_w = round(origin_w * (float)length / origin_h);
new_h = length;
} else {
new_h = round(origin_h * (float)length / origin_w);
new_w = length;
}
float scale_w = (float)new_w / origin_w;
float scale_h = (float)new_h / origin_h;
auto input_var = resize(image, Size(new_w, new_h), 0, 0, INTER_LINEAR, -1, {123.675, 116.28, 103.53}, {1/58.395, 1/57.12, 1/57.375});
std::vector<int> padvals { 0, length - new_h, 0, length - new_w, 0, 0 };
auto pads = _Const(static_cast<void*>(padvals.data()), {3, 2}, NCHW, halide_type_of<int>());
input_var = _Pad(input_var, pads, CONSTANT);
input_var = _Unsqueeze(input_var, {0});
// 2. image embedding
input_var = _Convert(input_var, NC4HW4);
auto outputs = embed->onForward({input_var});
auto image_embedding = _Convert(outputs[0], NCHW);

// 3. segment
auto build_input = [](std::vector<float> data, std::vector<int> shape) {
return _Const(static_cast<void*>(data.data()), shape, NCHW, halide_type_of<float>());
};
// build inputs
std::vector<float> points = {500, 375};
auto scale_points = points;
for (int i = 0; i < scale_points.size() / 2; i++) {
scale_points[2 * i] = scale_points[2 * i] * scale_w;
scale_points[2 * i + 1] = scale_points[2 * i + 1] * scale_h;
}
scale_points.push_back(0);
scale_points.push_back(0);
auto point_coords = build_input(scale_points, {1, 2, 2});
auto point_labels = build_input({1, -1}, {1, 2});
auto orig_im_size = build_input({static_cast<float>(origin_h), static_cast<float>(origin_w)}, {2});
auto has_mask_input = build_input({0}, {1});
std::vector<float> zeros(256*256, 0.f);
auto mask_input = build_input(zeros, {1, 1, 256, 256});
auto output_vars = sam->onForward({point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size});
auto masks = _Convert(output_vars[2], NCHW);
// 4. postprocess: draw mask and point
masks = _Greater(masks, _Scalar(mask_threshold));
masks = _Reshape(masks, {origin_h, origin_w, 1});
std::vector<int> color_vec {30, 144, 255};
auto color = _Const(static_cast<void*>(color_vec.data()), {1, 1, 3}, NCHW, halide_type_of<int>());
image = _Cast<uint8_t>(_Cast<int>(image) + masks * color);
auto ptr = image->readMap<uint8_t>();
for (int i = 0; i < points.size() / 2; i++) {
float x = points[2 * i];
float y = points[2 * i + 1];
circle(image, {x, y}, 10, {0, 0, 255}, 5);
}
if (imwrite("res.jpg", image)) {
MNN_PRINT("result image write to `res.jpg`.\n");
}
// rtmgr->updateCache();
return 0;
}
11 changes: 11 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Usage

## Install MNN
```
pip install MNN
```

## Run Demo
```
python segment_anything_example.py --embed embed.mnn --sam segment.mnn --img ../resource/truck.jpg
```
68 changes: 68 additions & 0 deletions python/segment_anything_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#-- coding:utf8 --
import argparse

import MNN
import MNN.numpy as np
import MNN.cv as cv2

def inference(emed, sam, img, precision, backend, thread):
mask_threshold = 0.0
config = {}
config['precision'] = precision
config['backend'] = backend
config['numThread'] = thread
rt = MNN.nn.create_runtime_manager((config,))
embed = MNN.nn.load_module_from_file(emed, [], [], runtime_manager=rt)
sam = MNN.nn.load_module_from_file(sam,
['point_coords', 'point_labels', 'image_embeddings', 'has_mask_input', 'mask_input', 'orig_im_size'],
['iou_predictions', 'low_res_masks', 'masks'], runtime_manager=rt)
image = cv2.imread(img)
origin_h, origin_w, _ = image.shape
length = 1024
if origin_h > origin_w:
new_w = round(origin_w * float(length) / origin_h)
new_h = length
else:
new_h = round(origin_h * float(length) / origin_w)
new_w = length
scale_w = new_w / origin_w
sclae_h = new_h / origin_h
input_var = cv2.resize(image, (new_w, new_h), 0., 0., cv2.INTER_LINEAR, -1, [123.675, 116.28, 103.53], [1/58.395, 1/57.12, 1/57.375])
input_var = np.pad(input_var, [[0, length - new_h], [0, length - new_w], [0, 0]], 'constant')
input_var = np.expand_dims(input_var, 0)
input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4)
output_var = embed.forward(input_var)
image_embedding = MNN.expr.convert(output_var, MNN.expr.NCHW)
points = [[500, 375]]
sclaes = [scale_w, sclae_h]
input_point = np.array(points) * sclaes
input_label = np.array([1])
point_coords = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
point_labels = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
orig_im_size = np.array([float(origin_h), float(origin_w)], dtype=np.float32)
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
has_mask_input = np.zeros(1, dtype=np.float32)

print('point_coords: ', point_coords, point_coords.shape)
print('point_labels: ', point_labels, point_labels.shape)
print('orig_im_size: ', orig_im_size, orig_im_size.shape)
output_vars = sam.onForward([point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size])
masks = MNN.expr.convert(output_vars[2], MNN.expr.NCHW)
masks = (masks > mask_threshold).reshape([origin_h, origin_w, 1])
# draw masks and point
color = np.array([30, 144, 255]).reshape([1, 1, -1])
image = (image + masks * color).astype(np.uint8)
for point in points:
cv2.circle(image, point, 10, (0, 0, 255), 5)
cv2.imwrite('res.jpg', image)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--embed', type=str, required=True, help='the embedding model path')
parser.add_argument('--sam', type=str, required=True, help='the sam model path')
parser.add_argument('--img', type=str, required=True, help='the input image path')
parser.add_argument('--precision', type=str, default='normal', help='inference precision: normal, low, high, lowBF')
parser.add_argument('--backend', type=str, default='CPU', help='inference backend: CPU, OPENCL, OPENGL, NN, VULKAN, METAL, TRT, CUDA, HIAI')
parser.add_argument('--thread', type=int, default=4, help='inference using thread: int')
args = parser.parse_args()
inference(args.embed, args.sam, args.img, args.precision, args.backend, args.thread)

0 comments on commit 4fed07e

Please sign in to comment.