Commit 4fed07e (initial commit, 0 parents)
Showing 8 changed files with 277 additions and 0 deletions.
@@ -0,0 +1,12 @@
# mnn-segment-anything

## Model
|   model   |  onnx  |  mnn   |
|:---------:|:------:|:------:|
| sam_vit_b | `TODO` | `TODO` |

- [github](https://github.com/facebookresearch/segment-anything)

## Demo
- [Python](./python/)
- [C++](./cpp)
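The `onnx` and `mnn` entries in the table are still `TODO`. As a hedged sketch (the checkpoint name, output paths, and converter location below are placeholders, not part of this commit), the models would typically be obtained by exporting SAM to ONNX with the official segment-anything export script and then converting each ONNX file with MNN's `MNNConvert` tool:

```bash
# Hypothetical flow; adjust checkpoint, model type, and paths to your setup.
# 1. Export the SAM prompt/mask decoder to ONNX with the official script
#    (the image-embedding encoder is usually exported separately via torch.onnx.export).
python scripts/export_onnx_model.py --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --output segment.onnx

# 2. Convert each ONNX model to MNN with the MNNConvert tool built from the MNN repo.
./MNNConvert -f ONNX --modelFile embed.onnx   --MNNModel embed.mnn   --bizCode biz
./MNNConvert -f ONNX --modelFile segment.onnx --MNNModel segment.mnn --bizCode biz
```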
@@ -0,0 +1,23 @@
cmake_minimum_required(VERSION 3.0)
project(sam)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# include dir
include_directories(${CMAKE_CURRENT_LIST_DIR}/include/)

# libs dir
link_directories(${CMAKE_CURRENT_LIST_DIR}/libs)

# source files
FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/*.cpp)

# target
add_executable(sam_demo ${SRCS})

# link
if (MSVC)
    target_link_libraries(sam_demo MNN)
else()
    target_link_libraries(sam_demo MNN MNN_Express MNNOpenCV)
endif()
@@ -0,0 +1,51 @@
# Usage

## Compile MNN library
### Linux/Mac
```bash
git clone https://github.com/alibaba/MNN.git
# copy header files
cp -r MNN/include .
cp -r MNN/tools/cv/include .
cd MNN
mkdir build && cd build
cmake -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..
make -j8
cd ../..
cp MNN/build/libMNN.so MNN/build/express/libMNN_Express.so MNN/build/tools/cv/libMNNOpenCV.so ./libs
```

### Windows
```bash
# Visual Studio xxxx Developer Command Prompt
powershell
git clone https://github.com/alibaba/MNN.git
# copy header files
cp -r MNN/include .
cp -r MNN/tools/cv/include .
cd MNN
mkdir build && cd build
cmake -G "Ninja" -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..
ninja
cd ../..
cp MNN/build/MNN.dll MNN/build/MNN.lib ./libs
```

## Build and Run

#### Linux/Mac
```bash
mkdir build && cd build
cmake ..
make -j4
./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg
```
#### Windows
```bash
# Visual Studio xxxx Developer Command Prompt
powershell
mkdir build && cd build
cmake -G "Ninja" ..
ninja
./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg
```
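The demo also accepts the optional `forwardType`, `precision`, and `thread` arguments parsed in the C++ source further below, so a run can pin the backend, precision mode, and thread count explicitly, for example:

```bash
# forwardType=0 (MNN_FORWARD_CPU), precision=0 (normal), 4 threads
./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg 0 0 4
```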
Empty file.
Empty file.
@@ -0,0 +1,112 @@
#include <stdio.h>
#include <MNN/ImageProcess.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExprCreator.hpp>

#include <cv/cv.hpp>

using namespace MNN;
using namespace MNN::Express;
using namespace MNN::CV;

int main(int argc, const char* argv[]) {
    if (argc < 4) {
        MNN_PRINT("Usage: ./sam_demo.out embed.mnn sam.mnn input.jpg [forwardType] [precision] [thread]\n");
        return 0;
    }
    int thread = 4;
    int precision = 0;
    int forwardType = MNN_FORWARD_CPU;
    if (argc >= 5) {
        forwardType = atoi(argv[4]);
    }
    if (argc >= 6) {
        precision = atoi(argv[5]);
    }
    if (argc >= 7) {
        thread = atoi(argv[6]);
    }
    float mask_threshold = 0;
    MNN::ScheduleConfig sConfig;
    sConfig.type = static_cast<MNNForwardType>(forwardType);
    sConfig.numThread = thread;
    BackendConfig bConfig;
    bConfig.precision = static_cast<BackendConfig::PrecisionMode>(precision);
    sConfig.backendConfig = &bConfig;
    std::shared_ptr<Executor::RuntimeManager> rtmgr = std::shared_ptr<Executor::RuntimeManager>(Executor::RuntimeManager::createRuntimeManager(sConfig));
    if (rtmgr == nullptr) {
        MNN_ERROR("Empty RuntimeManager\n");
        return 0;
    }
    // rtmgr->setCache(".cachefile");
    std::shared_ptr<Module> embed(Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], rtmgr));
    std::shared_ptr<Module> sam(Module::load(
        {"point_coords", "point_labels", "image_embeddings", "has_mask_input", "mask_input", "orig_im_size"},
        {"iou_predictions", "low_res_masks", "masks"}, argv[2], rtmgr));
    auto image = imread(argv[3]);
    // 1. preprocess
    auto dims = image->getInfo()->dim;
    int origin_h = dims[0];
    int origin_w = dims[1];
    int length = 1024;
    int new_h, new_w;
    if (origin_h > origin_w) {
        new_w = round(origin_w * (float)length / origin_h);
        new_h = length;
    } else {
        new_h = round(origin_h * (float)length / origin_w);
        new_w = length;
    }
    float scale_w = (float)new_w / origin_w;
    float scale_h = (float)new_h / origin_h;
    auto input_var = resize(image, Size(new_w, new_h), 0, 0, INTER_LINEAR, -1, {123.675, 116.28, 103.53}, {1/58.395, 1/57.12, 1/57.375});
    std::vector<int> padvals { 0, length - new_h, 0, length - new_w, 0, 0 };
    auto pads = _Const(static_cast<void*>(padvals.data()), {3, 2}, NCHW, halide_type_of<int>());
    input_var = _Pad(input_var, pads, CONSTANT);
    input_var = _Unsqueeze(input_var, {0});
    // 2. image embedding
    input_var = _Convert(input_var, NC4HW4);
    auto outputs = embed->onForward({input_var});
    auto image_embedding = _Convert(outputs[0], NCHW);

    // 3. segment
    auto build_input = [](std::vector<float> data, std::vector<int> shape) {
        return _Const(static_cast<void*>(data.data()), shape, NCHW, halide_type_of<float>());
    };
    // build inputs
    std::vector<float> points = {500, 375};
    auto scale_points = points;
    for (int i = 0; i < scale_points.size() / 2; i++) {
        scale_points[2 * i]     = scale_points[2 * i] * scale_w;
        scale_points[2 * i + 1] = scale_points[2 * i + 1] * scale_h;
    }
    scale_points.push_back(0);
    scale_points.push_back(0);
    auto point_coords   = build_input(scale_points, {1, 2, 2});
    auto point_labels   = build_input({1, -1}, {1, 2});
    auto orig_im_size   = build_input({static_cast<float>(origin_h), static_cast<float>(origin_w)}, {2});
    auto has_mask_input = build_input({0}, {1});
    std::vector<float> zeros(256 * 256, 0.f);
    auto mask_input = build_input(zeros, {1, 1, 256, 256});
    auto output_vars = sam->onForward({point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size});
    auto masks = _Convert(output_vars[2], NCHW);
    // 4. postprocess: draw mask and point
    masks = _Greater(masks, _Scalar(mask_threshold));
    masks = _Reshape(masks, {origin_h, origin_w, 1});
    std::vector<int> color_vec {30, 144, 255};
    auto color = _Const(static_cast<void*>(color_vec.data()), {1, 1, 3}, NCHW, halide_type_of<int>());
    image = _Cast<uint8_t>(_Cast<int>(image) + masks * color);
    auto ptr = image->readMap<uint8_t>();
    for (int i = 0; i < points.size() / 2; i++) {
        float x = points[2 * i];
        float y = points[2 * i + 1];
        circle(image, {x, y}, 10, {0, 0, 255}, 5);
    }
    if (imwrite("res.jpg", image)) {
        MNN_PRINT("result image written to `res.jpg`.\n");
    }
    // rtmgr->updateCache();
    return 0;
}
@@ -0,0 +1,11 @@
# Usage

## Install MNN
```
pip install MNN
```

## Run Demo
```
python segment_anything_example.py --embed embed.mnn --sam segment.mnn --img ../resource/truck.jpg
```
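Beyond the required model and image paths, `segment_anything_example.py` also defines `--precision`, `--backend`, and `--thread` flags (see the argparse block at the end of the script below). A run that states the defaults explicitly looks like:

```bash
# CPU backend, normal precision, 4 inference threads (the script's defaults, shown explicitly)
python segment_anything_example.py --embed embed.mnn --sam segment.mnn --img ../resource/truck.jpg \
    --precision normal --backend CPU --thread 4
```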
@@ -0,0 +1,68 @@ | ||
#-- coding:utf8 -- | ||
import argparse | ||
|
||
import MNN | ||
import MNN.numpy as np | ||
import MNN.cv as cv2 | ||
|
||
def inference(emed, sam, img, precision, backend, thread): | ||
mask_threshold = 0.0 | ||
config = {} | ||
config['precision'] = precision | ||
config['backend'] = backend | ||
config['numThread'] = thread | ||
rt = MNN.nn.create_runtime_manager((config,)) | ||
embed = MNN.nn.load_module_from_file(emed, [], [], runtime_manager=rt) | ||
sam = MNN.nn.load_module_from_file(sam, | ||
['point_coords', 'point_labels', 'image_embeddings', 'has_mask_input', 'mask_input', 'orig_im_size'], | ||
['iou_predictions', 'low_res_masks', 'masks'], runtime_manager=rt) | ||
image = cv2.imread(img) | ||
origin_h, origin_w, _ = image.shape | ||
length = 1024 | ||
if origin_h > origin_w: | ||
new_w = round(origin_w * float(length) / origin_h) | ||
new_h = length | ||
else: | ||
new_h = round(origin_h * float(length) / origin_w) | ||
new_w = length | ||
scale_w = new_w / origin_w | ||
sclae_h = new_h / origin_h | ||
input_var = cv2.resize(image, (new_w, new_h), 0., 0., cv2.INTER_LINEAR, -1, [123.675, 116.28, 103.53], [1/58.395, 1/57.12, 1/57.375]) | ||
input_var = np.pad(input_var, [[0, length - new_h], [0, length - new_w], [0, 0]], 'constant') | ||
input_var = np.expand_dims(input_var, 0) | ||
input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4) | ||
output_var = embed.forward(input_var) | ||
image_embedding = MNN.expr.convert(output_var, MNN.expr.NCHW) | ||
points = [[500, 375]] | ||
sclaes = [scale_w, sclae_h] | ||
input_point = np.array(points) * sclaes | ||
input_label = np.array([1]) | ||
point_coords = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] | ||
point_labels = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) | ||
orig_im_size = np.array([float(origin_h), float(origin_w)], dtype=np.float32) | ||
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) | ||
has_mask_input = np.zeros(1, dtype=np.float32) | ||
|
||
print('point_coords: ', point_coords, point_coords.shape) | ||
print('point_labels: ', point_labels, point_labels.shape) | ||
print('orig_im_size: ', orig_im_size, orig_im_size.shape) | ||
output_vars = sam.onForward([point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size]) | ||
masks = MNN.expr.convert(output_vars[2], MNN.expr.NCHW) | ||
masks = (masks > mask_threshold).reshape([origin_h, origin_w, 1]) | ||
# draw masks and point | ||
color = np.array([30, 144, 255]).reshape([1, 1, -1]) | ||
image = (image + masks * color).astype(np.uint8) | ||
for point in points: | ||
cv2.circle(image, point, 10, (0, 0, 255), 5) | ||
cv2.imwrite('res.jpg', image) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--embed', type=str, required=True, help='the embedding model path') | ||
parser.add_argument('--sam', type=str, required=True, help='the sam model path') | ||
parser.add_argument('--img', type=str, required=True, help='the input image path') | ||
parser.add_argument('--precision', type=str, default='normal', help='inference precision: normal, low, high, lowBF') | ||
parser.add_argument('--backend', type=str, default='CPU', help='inference backend: CPU, OPENCL, OPENGL, NN, VULKAN, METAL, TRT, CUDA, HIAI') | ||
parser.add_argument('--thread', type=int, default=4, help='inference using thread: int') | ||
args = parser.parse_args() | ||
inference(args.embed, args.sam, args.img, args.precision, args.backend, args.thread) |