Skip to content

Commit

Permalink
Merge pull request #5 from Techainer/fix_#4
Browse files Browse the repository at this point in the history
Fix #4: Allow to set maxAge and IoUThreshold
  • Loading branch information
lamhoangtung authored Apr 26, 2021
2 parents b8f09a2 + 90c8984 commit 8aa0ad2
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 18 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ yarn add sort-node@npm:@techainer1t/sort-node

The `sort-node` package contain the object `SortNode` that can be use to track object detected from a single video or camera.

The `SortNode` object can be initialize with 2 arguments:
- `kMinHits`: (int) Minimum number of hits before a bounding box was assigned a new track ID
The `SortNode` object can be initialize with 4 arguments in the following order:
- `kMinHits`: (int) Minimum number of hits before a bounding box was assigned a new track ID (should be 3)
- `kMaxAge`: (int) Maximum number of frames to keep alive a track without associated detections
- `kIoUThreshold`: (float between 0 and 1) Minimum IOU for match (should be 0.3)
- `kMinConfidence`: (float between 0 and 1) Bouding boxes with confidence score less than this value will be ignored

With each frame, you will need to call `update` method.
Expand All @@ -53,8 +55,10 @@ Please noted that the number of returned object might not be the same as the num
```javascript
const sortnode = require("@techainer1t/sort-node");
const kMinHits = 3;
const kMaxAge = 1;
const kIoUThreshold = 0.3;
const kMinConfidence = 0.3;
const tracker = sortnode.SortNode(kMinHists, kMinConfidence);
const tracker = sortnode.SortNode(kMinHits, kMaxAge, kIoUThreshold, kMinConfidence);
while (true){
// Call the object detector
...
Expand Down
2 changes: 1 addition & 1 deletion include/tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Tracker {
std::vector<std::pair<cv::Rect, std::vector<float>>> &unmatched_det,
float iou_threshold = 0.3);

void Run(const std::vector<std::pair<cv::Rect, std::vector<float>>> &detections);
void Run(const std::vector<std::pair<cv::Rect, std::vector<float>>> &detections, int kMaxAge, float kIoUThreshold);

std::map<int, Track> GetTracks();

Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
"name": "@techainer1t/sort-node",
"description": "Node binding of SORT: Simple, online, and real-time tracking of multiple objects in a video sequence.",
"version": "1.1.0",
"version": "1.1.1",
"directories": {
"doc": "docs"
},
Expand Down
38 changes: 31 additions & 7 deletions src/sort_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ namespace sortnode
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (info.Length() < 2 || info.Length() > 2)
if (info.Length() < 4 || info.Length() > 4)
{
Napi::TypeError::New(env, "SortTracker constructor received wrong number of arguments: kMinHits, kMinConfidence")
Napi::TypeError::New(env, "SortTracker constructor received wrong number of arguments, expect: kMinHits, kMaxAge, kIoUThreshold, kMinConfidence")
.ThrowAsJavaScriptException();
return;
}
Expand All @@ -37,22 +37,46 @@ namespace sortnode
}

auto kMinHits = info[0].As<Napi::Number>().DoubleValue();
if (fmod(kMinHits, 1) != 0)
if (fmod(kMinHits, 1) != 0 || kMinHits < 0)
{
Napi::TypeError::New(env, "kMinHits must be an interger")
Napi::TypeError::New(env, "kMinHits must be an interger greater than 0")
.ThrowAsJavaScriptException();
return;
}

if (!info[1].IsNumber())
{
Napi::TypeError::New(env, "kMaxAge must be an interger")
.ThrowAsJavaScriptException();
return;
}

auto kMaxAge = info[1].As<Napi::Number>().DoubleValue();
if (fmod(kMaxAge, 1) != 0 || kMaxAge < 0)
{
Napi::TypeError::New(env, "kMaxAge must be an interger greater than 0")
.ThrowAsJavaScriptException();
return;
}

if (!info[2].IsNumber() || info[2].As<Napi::Number>().DoubleValue() > 1 || info[2].As<Napi::Number>().DoubleValue() < 0)
{
Napi::TypeError::New(env, "kIoUThreshold must be a float between 0 and 1")
.ThrowAsJavaScriptException();
return;
}

if (!info[1].IsNumber() || info[1].As<Napi::Number>().DoubleValue() > 1 || info[1].As<Napi::Number>().DoubleValue() < 0)
if (!info[3].IsNumber() || info[3].As<Napi::Number>().DoubleValue() > 1 || info[3].As<Napi::Number>().DoubleValue() < 0)
{
Napi::TypeError::New(env, "kMinConfidence must be a float between 0 and 1")
.ThrowAsJavaScriptException();
return;
}

this->kMinHits = int(kMinHits);
this->kMinConfidence = float(info[1].As<Napi::Number>().DoubleValue());
this->kMaxAge = int(kMaxAge);
this->kIoUThreshold = float(info[2].As<Napi::Number>().DoubleValue());
this->kMinConfidence = float(info[3].As<Napi::Number>().DoubleValue());
}

Napi::Value SortNode::update(const Napi::CallbackInfo& info)
Expand Down Expand Up @@ -136,7 +160,7 @@ namespace sortnode
}

// Run SORT tracker
this->tracker.Run(bbox_per_frame);
this->tracker.Run(bbox_per_frame, this->kMaxAge, this->kIoUThreshold);
const auto tracks = this->tracker.GetTracks();

// Convert results from cv::Rect to normal float vector
Expand Down
2 changes: 2 additions & 0 deletions src/sort_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ namespace sortnode
{
public:
int kMinHits = 3;
int kMaxAge = 1;
int kMaxCoastCycles = 1;
float kIoUThreshold = 0.3;
float kMinConfidence = 0.6;
int frame_index = 0;
Tracker tracker;
Expand Down
6 changes: 3 additions & 3 deletions src/tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ void Tracker::AssociateDetectionsToTrackers(const std::vector<std::pair<cv::Rect
}


void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& detections) {
void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& detections, int kMaxAge, float kIoUThreshold) {

/*** Predict internal tracks from previous frame ***/
for (auto &track : tracks_) {
Expand All @@ -149,7 +149,7 @@ void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& de

// return values - matched, unmatched_det
if (!detections.empty()) {
AssociateDetectionsToTrackers(detections, tracks_, matched, unmatched_det);
AssociateDetectionsToTrackers(detections, tracks_, matched, unmatched_det, kIoUThreshold);
}

/*** Update tracks with associated bbox ***/
Expand All @@ -168,7 +168,7 @@ void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& de

/*** Delete lose tracked tracks ***/
for (auto it = tracks_.begin(); it != tracks_.end();) {
if (it->second.coast_cycles_ > kMaxCoastCycles) {
if (it->second.coast_cycles_ > kMaxAge) {
it = tracks_.erase(it);
} else {
it++;
Expand Down
8 changes: 5 additions & 3 deletions test/test_binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ assert(sortnode.SortNode, "The expected module is undefined");
function testBasic() {
console.log("Running testBasic");
const kMinHits = 3;
const kMaxAge = 1;
const kMinConfidence = 0.3;
const instance = new sortnode.SortNode(kMinHits, kMinConfidence);
const kIoUThreshold = 0.3;
const instance = new sortnode.SortNode(kMinHits, kMaxAge, kIoUThreshold, kMinConfidence);
assert(instance.update, "The expected method is not defined");
}

Expand Down Expand Up @@ -63,7 +65,7 @@ function testAccuracyWithoutLandmark() {

const total_frames = all_detections.length;

const tracker = new sortnode.SortNode(3, 0.6);
const tracker = new sortnode.SortNode(3, 1, 0.3, 0.6);
let frame_index = 0
let predicted = [];
const t1 = Date.now()
Expand Down Expand Up @@ -98,7 +100,7 @@ function testAccuracyWithoutLandmark() {

function testKeepLandmark(){
console.log("Running testKeepLandmark")
const tracker = new sortnode.SortNode(3, 0.3);
const tracker = new sortnode.SortNode(3, 1, 0.3, 0);

let input = [
[120, 240, 50, 70, 0.9, 23, 24, 25, 26, 27, 28, 29, 30],
Expand Down

0 comments on commit 8aa0ad2

Please sign in to comment.