diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 28bce6f..c35fadf 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -11,193 +11,73 @@ env: jobs: - build-on-linux: - name: build / linux / ffmpeg ${{ matrix.ffmpeg_version }} - runs-on: ubuntu-latest - container: jrottenberg/ffmpeg:${{ matrix.ffmpeg_version }}-ubuntu - + check: + name: Check + runs-on: ${{ matrix.os }} strategy: matrix: - ffmpeg_version: ["4.3", "4.4", "5.0", "5.1", "6.0", "6.1", "7.0"] - fail-fast: false - + os: [ubuntu-latest, macOS-latest, windows-latest] + rust: [stable] steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - apt update - apt install -y --no-install-recommends clang curl pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: - toolchain: stable - - - name: Build - run: cargo build - - build-on-macos: - name: build / macos / ffmpeg latest - runs-on: macos-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - brew install ffmpeg pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + - uses: actions-rs/cargo@v1 with: - toolchain: stable - - - name: Build - run: cargo build - - - build-on-windows: - name: build / windows / ffmpeg latest - runs-on: windows-latest - - env: - FFMPEG_DOWNLOAD_URL: https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-full-shared.7z - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - $VCINSTALLDIR = $(& "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -latest -property installationPath) - Add-Content $env:GITHUB_ENV "LIBCLANG_PATH=${VCINSTALLDIR}\VC\Tools\LLVM\x64\bin`n" - Invoke-WebRequest "${env:FFMPEG_DOWNLOAD_URL}" -OutFile ffmpeg-release-full-shared.7z - 7z x 
ffmpeg-release-full-shared.7z - mkdir ffmpeg - mv ffmpeg-*/* ffmpeg/ - Add-Content $env:GITHUB_ENV "FFMPEG_DIR=${pwd}\ffmpeg`n" - Add-Content $env:GITHUB_PATH "${pwd}\ffmpeg\bin`n" - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 - with: - toolchain: stable - - - name: Build - run: cargo build - - - test-on-linux: - name: test / linux / ffmpeg ${{ matrix.ffmpeg_version }} - runs-on: ubuntu-latest - container: jrottenberg/ffmpeg:${{ matrix.ffmpeg_version }}-ubuntu + command: check + args: --all + test: + name: Test + runs-on: ${{ matrix.os }} strategy: matrix: - ffmpeg_version: ["4.3", "4.4", "5.0", "5.1", "6.0", "6.1", "7.0"] - fail-fast: false - + os: [ubuntu-latest, macOS-latest, windows-latest] + rust: [stable] steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - apt update - apt install -y --no-install-recommends clang curl pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: - toolchain: stable - - - name: Run Tests with All Features - run: cargo test --all-features - - - name: Run Tests in Release Mode - run: cargo test --release - - test-on-macos: - name: test / macos / ffmpeg latest - runs-on: macos-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - brew install ffmpeg pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + - uses: actions-rs/cargo@v1 with: - toolchain: stable - - - name: Run Tests with All Features - run: cargo test --all-features - - - name: Run Tests in Release Mode - run: cargo test --release - - test-on-windows: - name: test / windows / ffmpeg latest - runs-on: windows-latest - - env: - FFMPEG_DOWNLOAD_URL: https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-full-shared.7z + command: test + args: --all + fmt: + name: Rustfmt + runs-on: ubuntu-latest steps: - - 
name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - $VCINSTALLDIR = $(& "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -latest -property installationPath) - Add-Content $env:GITHUB_ENV "LIBCLANG_PATH=${VCINSTALLDIR}\VC\Tools\LLVM\x64\bin`n" - Invoke-WebRequest "${env:FFMPEG_DOWNLOAD_URL}" -OutFile ffmpeg-release-full-shared.7z - 7z x ffmpeg-release-full-shared.7z - mkdir ffmpeg - mv ffmpeg-*/* ffmpeg/ - Add-Content $env:GITHUB_ENV "FFMPEG_DIR=${pwd}\ffmpeg`n" - Add-Content $env:GITHUB_PATH "${pwd}\ffmpeg\bin`n" - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: + profile: minimal toolchain: stable - - - name: Run Tests with All Features - run: cargo test --all-features - - - name: Run Tests in Release Mode - run: cargo test --release - + override: true + - run: rustup component add rustfmt + - uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check - lints: + clippy: + name: Clippy runs-on: ubuntu-latest - container: jrottenberg/ffmpeg:6-ubuntu - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - apt update - apt install -y --no-install-recommends clang curl pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: + profile: minimal toolchain: stable - components: rustfmt, clippy - - - name: Rustfmt - run: cargo fmt --all -- --check + override: true + - run: rustup component add clippy + - uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all --all-targets -- -D warnings - - name: Clippy - run: cargo clippy --all --all-targets --all-features -- -D warnings diff --git a/.gitignore b/.gitignore index b99985e..e1a526e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ debug/ target/ +**/*.DS_Store + # Remove Cargo.lock from gitignore if creating an executable, leave it for 
libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html Cargo.lock @@ -13,7 +15,6 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb - .debug .vscode runs/ diff --git a/Cargo.toml b/Cargo.toml index c6c15d7..9e9179b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,64 +1,66 @@ [package] name = "usls" -version = "0.0.20" +version = "0.1.0" +rust-version = "1.79" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" authors = ["Jamjamjon "] license = "MIT" readme = "README.md" -exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"] +exclude = ["assets/*", "examples/*", "runs/*", "benches/*"] [dependencies] -clap = { version = "4.2.4", features = ["derive"] } +aksr = { version = "0.0.2" } +image = { version = "0.25.2" } +imageproc = { version = "0.24" } ndarray = { version = "0.16.1", features = ["rayon"] } -ort = { version = "2.0.0-rc.9", default-features = false } +rayon = { version = "1.10.0" } anyhow = { version = "1.0.75" } regex = { version = "1.5.4" } rand = { version = "0.8.5" } chrono = { version = "0.4.30" } -half = { version = "2.3.1" } -dirs = { version = "5.0.1" } -ureq = { version = "2.9.1", default-features = true, features = [ - "socks-proxy", -] } tokenizers = { version = "0.15.2" } -rayon = "1.10.0" +log = { version = "0.4.22" } indicatif = "0.17.8" -image = "0.25.2" -imageproc = { version = "0.24" } -ab_glyph = "0.2.23" -geo = "0.28.0" -prost = "0.12.4" -fast_image_resize = { version = "4.2.1", features = ["image"] } -serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } +ort = { version = "2.0.0-rc.9", default-features = false} +prost = "0.12.4" +ab_glyph = "0.2.23" +dirs = { version = "5.0.1" } tempfile = "3.12.0" -video-rs = { version = "0.9.0", features = ["ndarray"] } +geo = 
"0.28.0" +half = { version = "2.3.1" } +ureq = { version = "2.12.1", default-features = false, features = [ "tls" ] } +fast_image_resize = { version = "4.2.1", features = ["image"]} natord = "1.0.9" -tracing = "0.1.40" -tracing-subscriber = "0.3.18" -minifb = "0.27.0" +video-rs = { version = "0.10.0", features = ["ndarray"], optional = true } +minifb = { version = "0.27.0", optional = true } +sha2 = "0.10.8" +[dev-dependencies] +argh = "0.1.13" +tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] } + +[[example]] +name = "viewer" +required-features = ["ffmpeg"] [features] default = [ - "ort/load-dynamic", - "ort/copy-dylibs", - "ort/half", - "ort/ndarray", - "ort/cuda", - "ort/tensorrt", - "ort/coreml", + "ort/ndarray", + "ort/copy-dylibs", + "ort/load-dynamic", + "ort/half", ] auto = ["ort/download-binaries"] +ffmpeg = ["dep:video-rs", "dep:minifb"] +cuda = [ "ort/cuda" ] +trt = [ "ort/tensorrt" ] +mps = [ "ort/coreml" ] -[dev-dependencies] -criterion = "0.5.1" - -[[bench]] -name = "yolo" -harness = false - -[lib] -bench = false +[profile.release] +# lto = true +strip = true +panic = "abort" diff --git a/README.md b/README.md index 7211724..fb953e7 100644 --- a/README.md +++ b/README.md @@ -1,221 +1,161 @@ -

-

usls

-

+

usls

- Documentation -
-
+ + Rust Continuous Integration Badge + + + usls Version + + + Rust MSRV + - ONNXRuntime Release Page + ONNXRuntime MSRV - CUDA Toolkit Page + CUDA MSRV - TensorRT Page + TensorRT MSRV + + + Crates.io Total Downloads

-

- - Crates Page - - - - - - Crates.io Total Downloads - - + + Examples + + + usls documentation +

-**`usls`** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including: +**usls** is a Rust library integrated with **ONNXRuntime**, offering a suite of advanced models for **Computer Vision** and **Vision-Language** tasks, including: -- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLOv11](https://github.com/ultralytics/ultralytics) +- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLO11](https://github.com/ultralytics/ultralytics) - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) -- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro) -- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), 
[BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242) +- **Vision Models**: [RT-DETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro), [FastViT](https://github.com/apple/ml-fastvit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MobileOne](https://github.com/apple/ml-mobileone) +- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242) +- **OCR Models**: [FAST](https://github.com/czczup/FAST), [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947), [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159), [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html), [TrOCR](https://huggingface.co/microsoft/trocr-base-printed), [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
-Click to expand Supported Models - -## Supported Models - -| Model | Task / Type | Example | CUDA f32 | CUDA f16 | TensorRT f32 | TensorRT f16 | -|---------------------------------------------------------------------|----------------------------------------------------------------------------------------------|----------------------------|----------|----------|--------------|--------------| -| [YOLOv5](https://github.com/ultralytics/yolov5) | Classification
Object Detection
Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection
Instance Segmentation
Classification
Oriented Object Detection
Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [YOLOv11](https://github.com/ultralytics/ultralytics) | Object Detection
Instance Segmentation
Classification
Oriented Object Detection
Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | -| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | -| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | -| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | -| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | -| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | -| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision-Self-Supervised | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ | -| [CLIP](https://github.com/openai/CLIP) | Vision-Language | [demo](examples/clip) | ✅ | ✅ | ✅ Visual
❌ Textual | ✅ Visual
❌ Textual | -| [BLIP](https://github.com/salesforce/BLIP) | Vision-Language | [demo](examples/blip) | ✅ | ✅ | ✅ Visual
❌ Textual | ✅ Visual
❌ Textual | -| [DB](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | ✅ | ✅ | ✅ | ✅ | -| [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ | -| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ | -| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ | -| [Depth-Anything v1 & v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ | -| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | -| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | | -| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | | -| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | | -| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | | | +👉 More Supported Models + +| Model | Task / Description | Example | CoreML | CUDA
FP32 | CUDA
FP16 | TensorRT
FP32 | TensorRT
FP16 | +| -------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------- | ------ | -------------- | -------------- | ------------------ | ------------------ | +| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) | ✅ | ✅ | ✅ | | | +| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) | ✅ | ✅ | ✅ | | | +| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) | ✅ | ✅ | ✅ | | | +| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) | ✅ | ✅ | ✅ | | | +| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) | ✅ | ✅ | ✅ | | | +| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification
Object Detection
Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLOv8
YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection
Instance Segmentation
Image Classification
Oriented Object Detection
Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | | | +| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) | ✅ | ✅ | ✅ | | | +| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) | ✅ | ✅ | ✅ | | | +| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) | ✅ | ✅ | ✅ | | | +| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) | ✅ | ✅ | ✅ | | | +| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ✅ | ❌ | ❌ | +| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | | +| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | | +| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | | +| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | | +| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | | +| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | 
Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | ✅ | | | +| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) | ✅ | ✅ | ✅ | ❌ | ❌ | +| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) | ✅ | ✅ | ✅ | ❌ | ❌ | +| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) | ✅ | ✅ | ✅ | ❌ | ❌ | +| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Table Recognition | [demo](examples/slanet) | ✅ | ✅ | ✅ | | | +| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) | ✅ | ✅ | ✅ | | | +| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DepthAnything v1
DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ✅ | ❌ | ❌ | +| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | ✅ | | | +| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | ✅ | ✅ | ✅ | | | +| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | | + +
+## ⛳️ Cargo Features +By default, **none of the following features are enabled**. You can enable them as needed: - +- **`auto`**: Automatically downloads prebuilt ONNXRuntime binaries from Pyke’s CDN for supported platforms. + - If disabled, you'll need to [compile `ONNXRuntime` from source](https://github.com/microsoft/onnxruntime) or [download a precompiled package](https://github.com/microsoft/onnxruntime/releases), and then [link it manually](https://ort.pyke.io/setup/linking). -## ⛳️ ONNXRuntime Linking +
+ 👉 For Linux or macOS Users + + - Download from the [Releases page](https://github.com/microsoft/onnxruntime/releases). + - Set up the library path by exporting the `ORT_DYLIB_PATH` environment variable: + ```shell + export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so.1.20.1 + ``` + +
+- **`ffmpeg`**: Adds support for video streams, real-time frame visualization, and video export. + + - Powered by [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb). For any issues related to `ffmpeg` features, please refer to the issues of these two crates. +- **`cuda`**: Enables the NVIDIA CUDA provider. +- **`trt`**: Enables the NVIDIA TensorRT provider. +- **`mps`**: Enables the Apple CoreML provider. + +## 🎈 Example + +* **Using `CUDA`** -
-You have two options to link the ONNXRuntime library - -- ### Option 1: Manual Linking - - - #### For detailed setup instructions, refer to the [ORT documentation](https://ort.pyke.io/setup/linking). - - - #### For Linux or macOS Users: - - Download the ONNX Runtime package from the [Releases page](https://github.com/microsoft/onnxruntime/releases). - - Set up the library path by exporting the `ORT_DYLIB_PATH` environment variable: - ```shell - export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so.1.19.0 - ``` - -- ### Option 2: Automatic Download - Just use `--features auto` - ```shell - cargo run -r --example yolo --features auto ``` + cargo run -r -F cuda --example yolo -- --device cuda:0 + ``` +* **Using Apple `CoreML`** -
+ ``` + cargo run -r -F mps --example yolo -- --device mps + ``` +* **Using `TensorRT`** + + ``` + cargo run -r -F trt --example yolo -- --device trt + ``` +* **Using `CPU`** + + ``` + cargo run -r --example yolo + ``` -## 🎈 Demo +All examples are located in the [examples](./examples/) directory. + +## 🥂 Integrate Into Your Own Project + +Add `usls` as a dependency to your project's `Cargo.toml` ```Shell -cargo run -r --example yolo # blip, clip, yolop, svtr, db, ... +cargo add usls -F cuda ``` -## 🥂 Integrate Into Your Own Project +Or use a specific commit: -- #### Add `usls` as a dependency to your project's `Cargo.toml` - ```Shell - cargo add usls - ``` - - Or use a specific commit: - ```Toml - [dependencies] - usls = { git = "https://github.com/jamjamjon/usls", rev = "commit-sha" } - ``` - -- #### Follow the pipeline - - Build model with the provided `models` and `Options` - - Load images, video and stream with `DataLoader` - - Do inference - - Retrieve inference results from `Vec` - - Annotate inference results with `Annotator` - - Display images and write them to video with `Viewer` - -
-
- example code - - ```rust - use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion}; - - fn main() -> anyhow::Result<()> { - // Build model with Options - let options = Options::new() - .with_trt(0) - .with_model("yolo/v8-m-dyn.onnx")? - .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR - .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb - .with_ixx(0, 0, (1, 2, 4).into()) - .with_ixx(0, 2, (0, 640, 640).into()) - .with_ixx(0, 3, (0, 640, 640).into()) - .with_confs(&[0.2]); - let mut model = YOLO::new(options)?; - - // Build DataLoader to load image(s), video, stream - let dl = DataLoader::new( - // "./assets/bus.jpg", // local image - // "images/bus.jpg", // remote image - // "../images-folder", // local images (from folder) - // "../demo.mp4", // local video - // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video - "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream - )? 
- .with_batch(2) // iterate with batch_size = 2 - .build()?; - - // Build annotator - let annotator = Annotator::new() - .with_bboxes_thickness(4) - .with_saveout("YOLO-DataLoader"); - - // Build viewer - let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true); - - // Run and annotate results - for (xs, _) in dl { - let ys = model.forward(&xs, false)?; - // annotator.annotate(&xs, &ys); - let images_plotted = annotator.plot(&xs, &ys, false)?; - - // show image - viewer.imshow(&images_plotted)?; - - // check out window and key event - if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { - break; - } - - // write video - viewer.write_batch(&images_plotted)?; - - // Retrieve inference results - for y in ys { - // bboxes - if let Some(bboxes) = y.bboxes() { - for bbox in bboxes { - println!( - "Bbox: {}, {}, {}, {}, {}, {}", - bbox.xmin(), - bbox.ymin(), - bbox.xmax(), - bbox.ymax(), - bbox.confidence(), - bbox.id(), - ); - } - } - } - } - - // finish video write - viewer.finish_write()?; - - Ok(()) - } - ``` - -
-
+```Toml +[dependencies] +usls = { git = "https://github.com/jamjamjon/usls", rev = "commit-sha" } +``` + +## 🥳 If you find this helpful, please give it a star ⭐ ## 📌 License + This project is licensed under [LICENSE](LICENSE). diff --git a/benches/yolo.rs b/benches/yolo.rs deleted file mode 100644 index 9ee3196..0000000 --- a/benches/yolo.rs +++ /dev/null @@ -1,94 +0,0 @@ -use anyhow::Result; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; - -use usls::{models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; - -enum Stage { - Pre, - Run, - Post, - Pipeline, -} - -fn yolo_stage_bench( - model: &mut YOLO, - x: &[image::DynamicImage], - stage: Stage, - n: u64, -) -> std::time::Duration { - let mut t_pre = std::time::Duration::new(0, 0); - let mut t_run = std::time::Duration::new(0, 0); - let mut t_post = std::time::Duration::new(0, 0); - let mut t_pipeline = std::time::Duration::new(0, 0); - for _ in 0..n { - let t0 = std::time::Instant::now(); - let xs = model.preprocess(x).unwrap(); - t_pre += t0.elapsed(); - - let t = std::time::Instant::now(); - let xs = model.inference(xs).unwrap(); - t_run += t.elapsed(); - - let t = std::time::Instant::now(); - let _ys = black_box(model.postprocess(xs, x).unwrap()); - t_post += t.elapsed(); - t_pipeline += t0.elapsed(); - } - match stage { - Stage::Pre => t_pre, - Stage::Run => t_run, - Stage::Post => t_post, - Stage::Pipeline => t_pipeline, - } -} - -pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> { - let mut group = c.benchmark_group(format!("YOLO ({}-{})", w, h)); - group - .significance_level(0.05) - .sample_size(80) - .measurement_time(std::time::Duration::new(20, 0)); - - let options = Options::default() - .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR - .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb - .with_model("yolo/v8-m-dyn.onnx")? 
- .with_cuda(0) - // .with_cpu() - .with_dry_run(0) - .with_ixx(0, 2, (320, h, 1280).into()) - .with_ixx(0, 3, (320, w, 1280).into()) - .with_confs(&[0.2, 0.15]); - let mut model = YOLO::new(options)?; - - let xs = [DataLoader::try_read("./assets/bus.jpg")?]; - - group.bench_function("pre-process", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pre, n)) - }); - - group.bench_function("run", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Run, n)) - }); - - group.bench_function("post-process", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Post, n)) - }); - - group.bench_function("pipeline", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pipeline, n)) - }); - - group.finish(); - Ok(()) -} - -pub fn criterion_benchmark(c: &mut Criterion) { - // benchmark_cuda(c, 416, 416).unwrap(); - benchmark_cuda(c, 640, 640).unwrap(); - benchmark_cuda(c, 448, 768).unwrap(); - // benchmark_cuda(c, 800, 800).unwrap(); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/examples/beit/README.md b/examples/beit/README.md new file mode 100644 index 0000000..d9eddd8 --- /dev/null +++ b/examples/beit/README.md @@ -0,0 +1,6 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example beit -- --device cuda --dtype fp16 +``` + diff --git a/examples/beit/main.rs b/examples/beit/main.rs new file mode 100644 index 0000000..aad67bd --- /dev/null +++ b/examples/beit/main.rs @@ -0,0 +1,52 @@ +use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh( + option, + default = "vec![ + String::from(\"images/dog.jpg\"), + String::from(\"images/siamese.png\"), + String::from(\"images/ailurus-fulgens.jpg\"), + ]" + )] + 
source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = Options::beit_base() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = ImageClassifier::try_from(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/blip/README.md b/examples/blip/README.md index e0dfe3e..6121661 100644 --- a/examples/blip/README.md +++ b/examples/blip/README.md @@ -3,20 +3,12 @@ This demo shows how to use [BLIP](https://arxiv.org/abs/2201.12086) to do condit ## Quick Start ```shell -cargo run -r --example blip +cargo run -r -F cuda --example blip -- --device cuda:0 --source images/dog.jpg --source ./assets/bus.jpg --source images/green-car.jpg ``` ## Results ```shell -[Unconditional]: a group of people walking around a bus -[Conditional]: three man walking in front of a bus -Some(["three man walking in front of a bus"]) +Unconditional: Ys([Y { Texts: [Text("a dog running through a field of grass")] }, Y { Texts: [Text("a group of people walking around a bus")] }, Y { Texts: [Text("a green volkswagen beetle parked in front of a yellow building")] }]) +Conditional: Ys([Y { Texts: [Text("this image depicting a dog running in a field")] }, Y { Texts: [Text("this image depict a bus in barcelona")] }, Y { Texts: [Text("this image depict a blue volkswagen beetle parked in a street in havana, cuba")] }]) ``` - -## TODO - -* [ ] Multi-batch inference for image caption -* [ ] VQA -* [ ] Retrival -* [ ] TensorRT support for textual model diff --git 
a/examples/blip/main.rs b/examples/blip/main.rs index da7fc89..d5e3e21 100644 --- a/examples/blip/main.rs +++ b/examples/blip/main.rs @@ -1,28 +1,44 @@ -use usls::{models::Blip, DataLoader, Options}; - -fn main() -> Result<(), Box> { - // visual - let options_visual = Options::default() - .with_model("blip/visual-base.onnx")? - // .with_ixx(0, 2, 384.into()) - // .with_ixx(0, 3, 384.into()) - .with_profile(false); - - // textual - let options_textual = Options::default() - .with_model("blip/textual-base.onnx")? - .with_tokenizer("blip/tokenizer.json")? - .with_profile(false); - - // build model - let mut model = Blip::new(options_visual, options_textual)?; - - // image caption (this demo use batch_size=1) - let xs = [DataLoader::try_read("images/bus.jpg")?]; - let image_embeddings = model.encode_images(&xs)?; - let _y = model.caption(&image_embeddings, None, true)?; // unconditional - let y = model.caption(&image_embeddings, Some("three man"), true)?; // conditional - println!("{:?}", y[0].texts()); - - Ok(()) -} +use usls::{models::Blip, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// BLIP Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options_visual = Options::blip_v1_base_caption_visual() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let options_textual = Options::blip_v1_base_caption_textual() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let mut model = Blip::new(options_visual, options_textual)?; + + // image caption + let xs = DataLoader::try_read_batch(&args.source)?; + + // unconditional caption + let ys = model.forward(&xs, None)?; + println!("Unconditional: {:?}", ys); + + // conditional caption + let ys = model.forward(&xs, Some("this image depict"))?; + println!("Conditional: {:?}", ys); + + Ok(()) +} diff --git a/examples/clip/README.md b/examples/clip/README.md index d85a682..71fe94e 100644 --- a/examples/clip/README.md +++ b/examples/clip/README.md @@ -3,18 +3,13 @@ This demo showcases how to use [CLIP](https://github.com/openai/CLIP) to compute ## Quick Start ```shell -cargo run -r --example clip +cargo run -r -F cuda --example clip -- --device cuda:0 ``` ## Results ```shell -(90.11472%) ./examples/clip/images/carrot.jpg => 几个胡萝卜 -[0.04573484, 0.0048218793, 0.0011618224, 0.90114725, 0.0036694852, 0.031348046, 0.0121166315] - -(94.07785%) ./examples/clip/images/peoples.jpg => Some people holding wine glasses in a restaurant -[0.050406333, 0.0011632168, 0.0019338318, 0.0013227565, 0.003916758, 0.00047858112, 0.9407785] - -(86.59852%) ./examples/clip/images/doll.jpg => There is a doll with red hair and a clock on a table -[0.07032883, 0.00053773675, 0.0006372929, 0.06066096, 0.0007378078, 0.8659852, 0.0011121632] -``` \ No newline at end of file +(99.9675%) ./examples/clip/images/carrot.jpg => Some carrots +(99.93718%) ./examples/clip/images/doll.jpg => There is a doll with red hair and a clock on a table +(100.0%) ./examples/clip/images/drink.jpg => Some people holding wine glasses in a restaurant +``` diff --git a/examples/clip/images/peoples.jpg b/examples/clip/images/drink.jpg similarity index 100% rename from examples/clip/images/peoples.jpg rename to examples/clip/images/drink.jpg diff --git a/examples/clip/main.rs b/examples/clip/main.rs index 0fd03ce..e213c31 100644 --- a/examples/clip/main.rs +++ b/examples/clip/main.rs @@ -1,43 +1,54 @@ -use usls::{models::Clip, 
DataLoader, Options}; +use anyhow::Result; +use usls::{models::Clip, DataLoader, Ops, Options}; -fn main() -> Result<(), Box> { - // visual - let options_visual = Options::default().with_model("clip/visual-base-dyn.onnx")?; +#[derive(argh::FromArgs)] +/// CLIP Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} - // textual - let options_textual = Options::default() - .with_model("clip/textual-base-dyn.onnx")? - .with_tokenizer("clip/tokenizer.json")?; +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + let args: Args = argh::from_env(); // build model + let options_visual = Options::jina_clip_v1_visual() + // clip_vit_b32_visual() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let options_textual = Options::jina_clip_v1_textual() + // clip_vit_b32_textual() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; let mut model = Clip::new(options_visual, options_textual)?; // texts let texts = vec![ - "A photo of a dinosaur ".to_string(), - "A photo of a cat".to_string(), - "A photo of a dog".to_string(), - "几个胡萝卜".to_string(), - "There are some playing cards on a striped table cloth".to_string(), - "There is a doll with red hair and a clock on a table".to_string(), - "Some people holding wine glasses in a restaurant".to_string(), + "A photo of a dinosaur", + "A photo of a cat", + "A photo of a dog", + "Some carrots", + "There are some playing cards on a striped table cloth", + "There is a doll with red hair and a clock on a table", + "Some people holding wine glasses in a restaurant", ]; let feats_text = model.encode_texts(&texts)?; // [n, ndim] - // load image + // load images let dl = DataLoader::new("./examples/clip/images")?.build()?; - // loop + // run for (images, paths) in dl { - let feats_image = model.encode_images(&images).unwrap(); + let feats_image = model.encode_images(&images)?; // use image to query texts - let matrix = match feats_image.embedding() { - Some(x) => x.dot2(feats_text.embedding().unwrap())?, - None => continue, - }; + let matrix = Ops::dot2(&feats_image, &feats_text)?; - // summary for i in 0..paths.len() { let probs = &matrix[i]; let (id, &score) = probs @@ -52,7 +63,6 @@ fn main() -> Result<(), Box> { paths[i].display(), &texts[id] ); - println!("{:?}\n", probs); } } diff --git a/examples/convnext/README.md b/examples/convnext/README.md new file mode 100644 index 0000000..fe6d945 --- /dev/null +++ b/examples/convnext/README.md @@ -0,0 +1,6 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example convnext -- --device cuda --dtype fp16 +``` + diff --git a/examples/convnext/main.rs b/examples/convnext/main.rs new file mode 100644 index 0000000..6480a07 --- /dev/null +++ b/examples/convnext/main.rs @@ -0,0 +1,52 @@ +use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example 
+struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh( + option, + default = "vec![ + String::from(\"images/dog.jpg\"), + String::from(\"images/siamese.png\"), + String::from(\"images/ailurus-fulgens.jpg\"), + ]" + )] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = Options::convnext_v2_atto() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = ImageClassifier::try_from(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/d-fine/README.md b/examples/d-fine/README.md new file mode 100644 index 0000000..61eb5ba --- /dev/null +++ b/examples/d-fine/README.md @@ -0,0 +1,5 @@ +## Quick Start + +```shell +cargo run -r --example d-fine +``` diff --git a/examples/d-fine/main.rs b/examples/d-fine/main.rs new file mode 100644 index 0000000..2726232 --- /dev/null +++ b/examples/d-fine/main.rs @@ -0,0 +1,28 @@ +use anyhow::Result; +use usls::{models::RTDETR, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + // options + let options = Options::d_fine_n_coco().commit()?; + let mut model = RTDETR::new(options)?; + + // load + let x = 
[DataLoader::try_read("./assets/bus.jpg")?]; + + // run + let y = model.forward(&x)?; + println!("{:?}", y); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/dataloader/README.md b/examples/dataloader/README.md new file mode 100644 index 0000000..29d81b9 --- /dev/null +++ b/examples/dataloader/README.md @@ -0,0 +1,5 @@ +## Quick Start + +```shell +cargo run -r --example dataloader +``` diff --git a/examples/dataloader/main.rs b/examples/dataloader/main.rs index dbc27fe..eacb3b4 100644 --- a/examples/dataloader/main.rs +++ b/examples/dataloader/main.rs @@ -1,66 +1,45 @@ -use usls::{ - models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOTask, YOLOVersion, -}; +use usls::DataLoader; fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() - .with_max_level(tracing::Level::ERROR) + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - let options = Options::new() - .with_device(Device::Cuda(0)) - .with_model("yolo/v8-m-dyn.onnx")? - .with_yolo_version(YOLOVersion::V8) - .with_yolo_task(YOLOTask::Detect) - .with_batch(2) - .with_ixx(0, 2, (416, 640, 800).into()) - .with_ixx(0, 3, (416, 640, 800).into()) - .with_confs(&[0.2]); - let mut model = YOLO::new(options)?; - - // build annotator - let annotator = Annotator::new() - .with_bboxes_thickness(4) - .with_saveout("YOLO-DataLoader"); - - // build dataloader - let dl = DataLoader::new( + // 1. 
iterator + let dl = DataLoader::try_from( // "images/bus.jpg", // remote image // "../images", // image folder // "../demo.mp4", // local video // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream - // "./assets/bus.jpg", // local image - "../7.mp4", + "./assets/bus.jpg", // local image )? .with_batch(1) + .with_progress_bar(true) .build()?; - let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true); - - // iteration - for (xs, _) in dl { - // inference & annotate - let ys = model.run(&xs)?; - let images_plotted = annotator.plot(&xs, &ys, false)?; - - // show image - viewer.imshow(&images_plotted)?; - - // check out window and key event - if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { - break; - } - - // write video - viewer.write_batch(&images_plotted)?; + for (_xs, _paths) in dl { + println!("Paths: {:?}", _paths); } - // finish video write - viewer.finish_write()?; - - // images -> video - // DataLoader::is2v("runs/YOLO-DataLoader", &["runs", "is2v"], 24)?; + // 2. read one image + let image = DataLoader::try_read("./assets/bus.jpg")?; + println!( + "Read one image. Height: {}, Width: {}", + image.height(), + image.width() + ); + + // 3. read several images + let images = DataLoader::try_read_batch(&[ + "./assets/bus.jpg", + "./assets/bus.jpg", + "./assets/bus.jpg", + "./assets/bus.jpg", + "./assets/bus.jpg", + ])?; + println!("Read {} images.", images.len()); Ok(()) } diff --git a/examples/db/README.md b/examples/db/README.md index 6da1cfc..9e19375 100644 --- a/examples/db/README.md +++ b/examples/db/README.md @@ -4,15 +4,6 @@ cargo run -r --example db ``` -### Speed test - -| Model | Image size | TensorRT
f16
batch=1
(ms) | TensorRT
f32
batch=1
(ms) | CUDA
f32
batch=1
(ms) | -| --------------- | ---------- | ---------------------------------------- | ---------------------------------------- | ------------------------------------ | -| ppocr-v3-db-dyn | 640x640 | 1.8585 | 2.5739 | 4.3314 | -| ppocr-v4-db-dyn | 640x640 | 2.0507 | 2.8264 | 6.6064 | - -***Test on RTX3060*** - ## Results ![](https://github.com/jamjamjon/assets/releases/download/db/demo-paper.png) diff --git a/examples/db/main.rs b/examples/db/main.rs index b133216..13bdb87 100644 --- a/examples/db/main.rs +++ b/examples/db/main.rs @@ -1,35 +1,48 @@ +use anyhow::Result; use usls::{models::DB, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - // build model - let options = Options::default() - .with_ixx(0, 0, (1, 4, 8).into()) - .with_ixx(0, 2, (608, 960, 1280).into()) - .with_ixx(0, 3, (608, 960, 1280).into()) - // .with_trt(0) - .with_confs(&[0.4]) - .with_min_width(5.0) - .with_min_height(12.0) - .with_model("db/ppocr-v4-db-dyn.onnx")?; +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + // build model + let options = Options::ppocr_det_v4_server_ch() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; let mut model = DB::new(options)?; // load image - let x = [ - DataLoader::try_read("images/db.png")?, - DataLoader::try_read("images/street.jpg")?, - ]; + let x = DataLoader::try_read_batch(&[ + "images/table.png", + "images/table1.jpg", + "images/table2.png", + "images/table-ch.jpg", + "images/db.png", + "images/street.jpg", + ])?; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .without_bboxes(true) + .without_mbrs(true) .with_polygons_alpha(60) .with_contours_color([255, 105, 180, 255]) - .without_mbrs(true) - .with_saveout("DB"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/deim/README.md b/examples/deim/README.md new file mode 100644 index 0000000..08e833c --- /dev/null +++ b/examples/deim/README.md @@ -0,0 +1,7 @@ +## Quick Start + +```shell +cargo run -r --example deim +``` + + diff --git a/examples/deim/main.rs b/examples/deim/main.rs new file mode 100644 index 0000000..cf8d4e5 --- /dev/null +++ b/examples/deim/main.rs @@ -0,0 +1,28 @@ +use anyhow::Result; +use usls::{models::RTDETR, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + // options + let options = Options::deim_dfine_s_coco().commit()?; + let mut model = RTDETR::new(options)?; + + // load + let x = [DataLoader::try_read("./assets/bus.jpg")?]; + + // run + let y = model.forward(&x)?; + println!("{:?}", y); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/deit/README.md b/examples/deit/README.md new file mode 100644 index 0000000..962781f --- /dev/null +++ b/examples/deit/README.md @@ -0,0 +1,7 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example 
deit -- --device cuda --dtype fp16 +``` + + diff --git a/examples/deit/main.rs b/examples/deit/main.rs new file mode 100644 index 0000000..98d7c12 --- /dev/null +++ b/examples/deit/main.rs @@ -0,0 +1,52 @@ +use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh( + option, + default = "vec![ + String::from(\"images/dog.jpg\"), + String::from(\"images/siamese.png\"), + String::from(\"images/ailurus-fulgens.jpg\"), + ]" + )] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = Options::deit_tiny_distill() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = ImageClassifier::try_from(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs index d339ff3..f1deeea 100644 --- a/examples/depth-anything/main.rs +++ b/examples/depth-anything/main.rs @@ -1,24 +1,26 @@ +use anyhow::Result; use usls::{models::DepthAnything, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - // options - let options = Options::default() - // .with_model("depth-anything/v1-s-dyn.onnx")? - .with_model("depth-anything/v2-s.onnx")? 
- .with_ixx(0, 2, (384, 512, 1024).into()) - .with_ixx(0, 3, (384, 512, 1024).into()); +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + // build model + let options = Options::depth_anything_v2_small().commit()?; let mut model = DepthAnything::new(options)?; // load let x = [DataLoader::try_read("images/street.jpg")?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .with_colormap("Turbo") - .with_saveout("Depth-Anything"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/depth-pro/README.md b/examples/depth-pro/README.md new file mode 100644 index 0000000..52c1418 --- /dev/null +++ b/examples/depth-pro/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example depth-pro -- --device cuda +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/depth-pro/demo-depth-pro.png) diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs index eb72a9a..8919f93 100644 --- a/examples/depth-pro/main.rs +++ b/examples/depth-pro/main.rs @@ -1,25 +1,47 @@ +use anyhow::Result; use usls::{models::DepthPro, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - // options - let options = Options::default() - .with_model("depth-pro/q4f16.onnx")? // bnb4, f16 - .with_ixx(0, 0, 1.into()) // batch. 
Note: now only support batch_size = 1 - .with_ixx(0, 1, 3.into()) // channel - .with_ixx(0, 2, 1536.into()) // height - .with_ixx(0, 3, 1536.into()); // width +#[derive(argh::FromArgs)] +/// BLIP Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// dtype + #[argh(option, default = "String::from(\"q4f16\")")] + dtype: String, + + /// source image + #[argh(option, default = "String::from(\"images/street.jpg\")")] + source: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // model + let options = Options::depth_pro() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; let mut model = DepthPro::new(options)?; // load - let x = [DataLoader::try_read("images/street.jpg")?]; + let x = [DataLoader::try_read(&args.source)?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .with_colormap("Turbo") - .with_saveout("Depth-Pro"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/dinov2/main.rs b/examples/dinov2/main.rs index 4cc7732..5168785 100644 --- a/examples/dinov2/main.rs +++ b/examples/dinov2/main.rs @@ -1,40 +1,25 @@ -use usls::{models::Dinov2, DataLoader, Options}; +use anyhow::Result; +use usls::{models::DINOv2, DataLoader, Options}; -fn main() -> Result<(), Box> { - // build model - let options = Options::default() - .with_model("dinov2/s-dyn.onnx")? 
- .with_ixx(0, 2, 224.into()) - .with_ixx(0, 3, 224.into()); - let mut model = Dinov2::new(options)?; - let x = [DataLoader::try_read("images/bus.jpg")?]; - let y = model.run(&x)?; - println!("{y:?}"); +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); - // TODO: - // query from vector - // let ys = model.query_from_vec( - // "./assets/bus.jpg", - // &[ - // "./examples/dinov2/images/bus.jpg", - // "./examples/dinov2/images/1.jpg", - // "./examples/dinov2/images/2.jpg", - // ], - // Metric::L2, - // )?; + // images + let xs = [ + DataLoader::try_read("./assets/bus.jpg")?, + DataLoader::try_read("./assets/bus.jpg")?, + ]; - // or query from folder - // let ys = model.query_from_folder("./assets/bus.jpg", "./examples/dinov2/images", Metric::IP)?; + // model + let options = Options::dinov2_small().with_batch_size(xs.len()).commit()?; + let mut model = DINOv2::new(options)?; - // results - // for (i, y) in ys.iter().enumerate() { - // println!( - // "Top-{:<3}{:.7} {}", - // i + 1, - // y.1, - // y.2.canonicalize()?.display() - // ); - // } + // encode images + let y = model.encode_images(&xs)?; + println!("Feat shape: {:?}", y.shape()); Ok(()) } diff --git a/examples/doclayout-yolo/README.md b/examples/doclayout-yolo/README.md new file mode 100644 index 0000000..b9b233f --- /dev/null +++ b/examples/doclayout-yolo/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example doclayout-yolo -- --device cuda +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/yolo/demo-doclayout-yolo.png) diff --git a/examples/doclayout-yolo/main.rs b/examples/doclayout-yolo/main.rs new file mode 100644 index 0000000..99a945b --- /dev/null +++ b/examples/doclayout-yolo/main.rs @@ -0,0 +1,42 @@ +use anyhow::Result; +use usls::{models::YOLO, Annotator, DataLoader, Options}; + 
+#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let config = Options::doclayout_yolo_docstructbench() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = YOLO::new(config)?; + + // load images + let xs = [DataLoader::try_read("images/academic.jpg")?]; + + // run + let ys = model.forward(&xs)?; + // println!("{:?}", ys); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout("doclayout-yolo"); + annotator.annotate(&xs, &ys); + + model.summary(); + + Ok(()) +} diff --git a/examples/fast/README.md b/examples/fast/README.md new file mode 100644 index 0000000..89227df --- /dev/null +++ b/examples/fast/README.md @@ -0,0 +1,6 @@ +## Quick Start + +```shell +cargo run -r --example fast +``` + diff --git a/examples/fast/main.rs b/examples/fast/main.rs new file mode 100644 index 0000000..84872d1 --- /dev/null +++ b/examples/fast/main.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use usls::{models::DB, Annotator, DataLoader, Options, Scale}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"t\")")] + scale: String, + + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = match 
args.scale.as_str().try_into()? { + Scale::T => Options::fast_tiny(), + Scale::S => Options::fast_small(), + Scale::B => Options::fast_base(), + _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), + }; + let mut model = DB::new( + options + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?, + )?; + + // load image + let x = DataLoader::try_read_batch(&[ + "images/table.png", + "images/table1.jpg", + "images/table2.png", + "images/table-ch.jpg", + "images/db.png", + "images/street.jpg", + ])?; + + // run + let y = model.forward(&x)?; + + // annotate + let annotator = Annotator::default() + .without_bboxes(true) + .without_mbrs(true) + .with_polygons_alpha(60) + .with_contours_color([255, 105, 180, 255]) + .with_saveout(model.spec()); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/fastsam/README.md b/examples/fastsam/README.md new file mode 100644 index 0000000..b2984e1 --- /dev/null +++ b/examples/fastsam/README.md @@ -0,0 +1,5 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example fastsam -- --device cuda +``` diff --git a/examples/fastsam/main.rs b/examples/fastsam/main.rs new file mode 100644 index 0000000..0050fda --- /dev/null +++ b/examples/fastsam/main.rs @@ -0,0 +1,45 @@ +use anyhow::Result; +use usls::{models::YOLO, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"fp16\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let config = Options::fastsam_s() + .with_model_dtype(args.dtype.as_str().try_into()?) 
+ .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = YOLO::new(config)?; + + // load images + let xs = DataLoader::try_read_batch(&["./assets/bus.jpg"])?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = Annotator::default() + .without_masks(true) + .with_bboxes_thickness(3) + .with_saveout("fastsam"); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/fastvit/README.md b/examples/fastvit/README.md new file mode 100644 index 0000000..ca00fdf --- /dev/null +++ b/examples/fastvit/README.md @@ -0,0 +1,13 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example mobileone -- --device cuda --dtype fp16 +``` + + +```shell +0: Y { Probs: { Top5: [(263, 0.6109131, Some("Pembroke, Pembroke Welsh corgi")), (264, 0.2062352, Some("Cardigan, Cardigan Welsh corgi")), (231, 0.028572788, Some("collie")), (273, 0.015174894, Some("dingo, warrigal, warragal, Canis dingo")), (248, 0.014367299, Some("Eskimo dog, husky"))] } } +1: Y { Probs: { Top5: [(284, 0.9907692, Some("siamese cat, Siamese")), (285, 0.0015794479, Some("Egyptian cat")), (174, 0.0015189401, Some("Norwegian elkhound, elkhound")), (225, 0.00031838714, Some("malinois")), (17, 0.00027021166, Some("jay"))] } } +2: Y { Probs: { Top5: [(387, 0.94238573, Some("lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens")), (368, 0.0029994072, Some("gibbon, Hylobates lar")), (277, 0.0016564301, Some("red fox, Vulpes vulpes")), (356, 0.0015081967, Some("weasel")), (295, 0.001427932, Some("American black bear, black bear, Ursus americanus, Euarctos americanus"))] } } + +``` diff --git a/examples/fastvit/main.rs b/examples/fastvit/main.rs new file mode 100644 index 0000000..cb93886 --- /dev/null +++ b/examples/fastvit/main.rs @@ -0,0 +1,57 @@ +use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + 
dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh( + option, + default = "vec![ + String::from(\"images/dog.jpg\"), + String::from(\"images/siamese.png\"), + String::from(\"images/ailurus-fulgens.jpg\"), + ]" + )] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = Options::fastvit_t8_distill() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = ImageClassifier::try_from(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // results + for (i, y) in ys.iter().enumerate() { + println!("{}: {:?}", i, y); + } + + // annotate + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/florence2/README.md b/examples/florence2/README.md new file mode 100644 index 0000000..6078515 --- /dev/null +++ b/examples/florence2/README.md @@ -0,0 +1,30 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example florence2 -- --device cuda --scale base --dtype fp16 +``` + + +```Shell +Task: Caption(0) +Ys([Y { Texts: [Text("A green car parked in front of a yellow building.")] }, Y { Texts: [Text("A group of people walking down a street next to a bus.")] }]) + +Task: Caption(1) +Ys([Y { Texts: [Text("The image shows a green car parked in front of a yellow building with two brown doors. The car is on the road, and the building has a wall and a tree in the background.")] }, Y { Texts: [Text("The image shows a group of people walking down a street next to a bus, with a building in the background. 
The bus is likely part of the World Electric Emission Bus, which is a new bus that will be launched in Madrid. The people are walking on the road, and there are trees and a sign board to the left of the bus.")] }]) + +Task: Caption(2) +Ys([Y { Texts: [Text("The image shows a vintage Volkswagen Beetle car parked on a cobblestone street in front of a yellow building with two wooden doors. The car is a light blue color with silver rims and appears to be in good condition. The building has a sloping roof and is painted in a bright yellow color. The sky is blue and there are trees in the background. The overall mood of the image is peaceful and serene.")] }, Y { Texts: [Text("The image shows a blue and white bus with the logo of the Brazilian football club, Cero Emisiones, on the side. The bus is parked on a street with a building in the background. There are several people walking on the sidewalk in front of the bus, some of them are carrying bags and one person is holding a camera. The sky is blue and there are trees and a traffic light visible in the top right corner of the image. The image appears to be taken during the day.")] }]) +``` + +## Results + +| Task | Demo | +| -----| ------| +|Caption-To-Phrase-Grounding | | +| Ocr-With-Region | | +| Dense-Region-Caption | | +| Object-Detection | | +| Region-Proposal | | +| Referring-Expression-Segmentation | | + + diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs index 07cc7d1..7248faf 100644 --- a/examples/florence2/main.rs +++ b/examples/florence2/main.rs @@ -1,157 +1,176 @@ -use usls::{models::Florence2, Annotator, DataLoader, Options, Task}; - -fn main() -> Result<(), Box> { - let batch_size = 3; - - // vision encoder - let options_vision_encoder = Options::default() - .with_model("florence2/base-vision-encoder-f16.onnx")? 
- .with_ixx(0, 2, (512, 768, 800).into()) - .with_ixx(0, 3, 768.into()) - .with_ixx(0, 0, (1, batch_size as _, 8).into()); - - // text embed - let options_text_embed = Options::default() - .with_model("florence2/base-embed-tokens-f16.onnx")? - .with_tokenizer("florence2/tokenizer.json")? - .with_batch(batch_size); - - // transformer encoder - let options_encoder = Options::default() - .with_model("florence2/base-encoder-f16.onnx")? - .with_batch(batch_size); - - // transformer decoder - let options_decoder = Options::default() - .with_model("florence2/base-decoder-f16.onnx")? - .with_batch(batch_size); - - // transformer decoder merged - let options_decoder_merged = Options::default() - .with_model("florence2/base-decoder-merged-f16.onnx")? - .with_batch(batch_size); - - // build model - let mut model = Florence2::new( - options_vision_encoder, - options_text_embed, - options_encoder, - options_decoder, - options_decoder_merged, - )?; - - // load images - let xs = [ - // DataLoader::try_read("florence2/car.jpg")?, // for testing region-related tasks - DataLoader::try_read("florence2/car.jpg")?, - // DataLoader::try_read("images/db.png")?, - DataLoader::try_read("assets/bus.jpg")?, - ]; - - // region-related tasks - let quantizer = usls::Quantizer::default(); - // let coords = [449., 270., 556., 372.]; // wheel - let coords = [31., 156., 581., 373.]; // car - let (width_car, height_car) = (xs[0].width(), xs[0].height()); - let quantized_coords = quantizer.quantize(&coords, (width_car as _, height_car as _)); - - // run with tasks - let ys = model.run_with_tasks( - &xs, - &[ - // w/ inputs - Task::Caption(0), - Task::Caption(1), - Task::Caption(2), - Task::Ocr, - Task::OcrWithRegion, - Task::RegionProposal, - Task::ObjectDetection, - Task::DenseRegionCaption, - // w/o inputs - Task::OpenSetDetection("a vehicle".into()), - Task::CaptionToPhraseGrounding( - "A vehicle with two wheels parked in front of a building.".into(), - ), - 
Task::ReferringExpressionSegmentation("a vehicle".into()), - Task::RegionToSegmentation( - quantized_coords[0], - quantized_coords[1], - quantized_coords[2], - quantized_coords[3], - ), - Task::RegionToCategory( - quantized_coords[0], - quantized_coords[1], - quantized_coords[2], - quantized_coords[3], - ), - Task::RegionToDescription( - quantized_coords[0], - quantized_coords[1], - quantized_coords[2], - quantized_coords[3], - ), - ], - )?; - - // annotator - let annotator = Annotator::new() - .without_bboxes_conf(true) - .with_bboxes_thickness(3) - .with_saveout_subs(&["Florence2"]); - for (task, ys_) in ys.iter() { - match task { - Task::Caption(_) - | Task::Ocr - | Task::RegionToCategory(..) - | Task::RegionToDescription(..) => { - println!("Task: {:?}\n{:?}\n", task, ys_) - } - Task::DenseRegionCaption => { - let annotator = annotator.clone().with_saveout("Dense-Region-Caption"); - annotator.annotate(&xs, ys_); - } - Task::RegionProposal => { - let annotator = annotator - .clone() - .without_bboxes_name(false) - .with_saveout("Region-Proposal"); - - annotator.annotate(&xs, ys_); - } - Task::ObjectDetection => { - let annotator = annotator.clone().with_saveout("Object-Detection"); - annotator.annotate(&xs, ys_); - } - Task::OpenSetDetection(_) => { - let annotator = annotator.clone().with_saveout("Open-Set-Detection"); - annotator.annotate(&xs, ys_); - } - Task::CaptionToPhraseGrounding(_) => { - let annotator = annotator - .clone() - .with_saveout("Caption-To-Phrase-Grounding"); - annotator.annotate(&xs, ys_); - } - Task::ReferringExpressionSegmentation(_) => { - let annotator = annotator - .clone() - .with_saveout("Referring-Expression-Segmentation"); - annotator.annotate(&xs, ys_); - } - Task::RegionToSegmentation(..) 
=> { - let annotator = annotator.clone().with_saveout("Region-To-Segmentation"); - annotator.annotate(&xs, ys_); - } - Task::OcrWithRegion => { - let annotator = annotator.clone().with_saveout("Ocr-With-Region"); - annotator.annotate(&xs, ys_); - } - - _ => (), - } - } - - Ok(()) -} +use anyhow::Result; +use usls::{models::Florence2, Annotator, DataLoader, Options, Scale, Task}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"base\")")] + scale: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // load images + let xs = [ + DataLoader::try_read("images/green-car.jpg")?, + DataLoader::try_read("assets/bus.jpg")?, + ]; + + // build model + let ( + options_vision_encoder, + options_text_embed, + options_encoder, + options_decoder, + options_decoder_merged, + ) = match args.scale.as_str().try_into()? { + Scale::B => ( + Options::florence2_visual_encoder_base(), + Options::florence2_textual_embed_base(), + Options::florence2_texual_encoder_base(), + Options::florence2_texual_decoder_base(), + Options::florence2_texual_decoder_merged_base(), + ), + Scale::L => todo!(), + _ => anyhow::bail!("Unsupported Florence2 scale."), + }; + + let mut model = Florence2::new( + options_vision_encoder + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_text_embed + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .with_batch_size(xs.len()) + .commit()?, + options_encoder + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder_merged + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + )?; + + // tasks + let tasks = [ + // w inputs + Task::Caption(0), + Task::Caption(1), + Task::Caption(2), + Task::Ocr, + // Task::OcrWithRegion, + Task::RegionProposal, + Task::ObjectDetection, + Task::DenseRegionCaption, + // w/o inputs + Task::OpenSetDetection("a vehicle"), + Task::CaptionToPhraseGrounding("A vehicle with two wheels parked in front of a building."), + Task::ReferringExpressionSegmentation("a vehicle"), + Task::RegionToSegmentation( + // 31, 156, 581, 373, // car + 449, 270, 556, 372, // wheel + ), + Task::RegionToCategory( + // 31, 156, 581, 373, + 449, 270, 556, 372, + ), + Task::RegionToDescription( + // 31, 156, 581, 373, + 449, 270, 556, 372, + ), + ]; + + // annotator + let annotator = Annotator::new() + .without_bboxes_conf(true) + .with_bboxes_thickness(3) + .with_saveout_subs(&["Florence2"]); + + // inference + for task in tasks.iter() { + let ys = model.forward(&xs, task)?; + + // annotate + match task { + Task::Caption(_) + | Task::Ocr + | Task::RegionToCategory(..) + | Task::RegionToDescription(..) 
=> { + println!("Task: {:?}\n{:?}\n", task, &ys) + } + Task::DenseRegionCaption => { + let annotator = annotator.clone().with_saveout("Dense-Region-Caption"); + annotator.annotate(&xs, &ys); + } + Task::RegionProposal => { + let annotator = annotator + .clone() + .without_bboxes_name(false) + .with_saveout("Region-Proposal"); + + annotator.annotate(&xs, &ys); + } + Task::ObjectDetection => { + let annotator = annotator.clone().with_saveout("Object-Detection"); + annotator.annotate(&xs, &ys); + } + Task::OpenSetDetection(_) => { + let annotator = annotator.clone().with_saveout("Open-Set-Detection"); + annotator.annotate(&xs, &ys); + } + Task::CaptionToPhraseGrounding(_) => { + let annotator = annotator + .clone() + .with_saveout("Caption-To-Phrase-Grounding"); + annotator.annotate(&xs, &ys); + } + Task::ReferringExpressionSegmentation(_) => { + let annotator = annotator + .clone() + .with_saveout("Referring-Expression-Segmentation"); + annotator.annotate(&xs, &ys); + } + Task::RegionToSegmentation(..) 
=> { + let annotator = annotator.clone().with_saveout("Region-To-Segmentation"); + annotator.annotate(&xs, &ys); + } + Task::OcrWithRegion => { + let annotator = annotator.clone().with_saveout("Ocr-With-Region"); + annotator.annotate(&xs, &ys); + } + + _ => (), + } + } + + model.summary(); + + Ok(()) +} diff --git a/examples/grounding-dino/README.md b/examples/grounding-dino/README.md index a94cb0b..f97321f 100644 --- a/examples/grounding-dino/README.md +++ b/examples/grounding-dino/README.md @@ -1,7 +1,7 @@ ## Quick Start ```shell -cargo run -r --example grounding-dino +cargo run -r -F cuda --example grounding-dino -- --device cuda --dtype fp16 ``` diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs index 2ceb61c..78c6493 100644 --- a/examples/grounding-dino/main.rs +++ b/examples/grounding-dino/main.rs @@ -1,41 +1,72 @@ +use anyhow::Result; use usls::{models::GroundingDINO, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - let opts = Options::default() - .with_ixx(0, 0, (1, 1, 4).into()) - .with_ixx(0, 2, (640, 800, 1200).into()) - .with_ixx(0, 3, (640, 1200, 1200).into()) - // .with_i10((1, 1, 4).into()) - // .with_i11((256, 256, 512).into()) - // .with_i20((1, 1, 4).into()) - // .with_i21((256, 256, 512).into()) - // .with_i30((1, 1, 4).into()) - // .with_i31((256, 256, 512).into()) - // .with_i40((1, 1, 4).into()) - // .with_i41((256, 256, 512).into()) - // .with_i50((1, 1, 4).into()) - // .with_i51((256, 256, 512).into()) - // .with_i52((256, 256, 512).into()) - .with_model("grounding-dino/swint-ogc-dyn-u8.onnx")? // TODO: current onnx model does not support bs > 1 - // .with_model("grounding-dino/swint-ogc-dyn-f32.onnx")? - .with_tokenizer("grounding-dino/tokenizer.json")? 
- .with_confs(&[0.2]) - .with_profile(false); - let mut model = GroundingDINO::new(opts)?; - - // Load images and set class names - let x = [DataLoader::try_read("images/bus.jpg")?]; - let texts = [ - "person", "hand", "shoes", "bus", "dog", "cat", "sign", "tie", "monitor", "window", - "glasses", "tree", "head", - ]; - - // Run and annotate - let y = model.run(&x, &texts)?; +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")] + source: Vec, + + /// open class names + #[argh( + option, + default = "vec![ + String::from(\"person\"), + String::from(\"hand\"), + String::from(\"shoes\"), + String::from(\"bus\"), + String::from(\"dog\"), + String::from(\"cat\"), + String::from(\"sign\"), + String::from(\"tie\"), + String::from(\"monitor\"), + String::from(\"glasses\"), + String::from(\"tree\"), + String::from(\"head\"), + ]" + )] + labels: Vec, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + let options = Options::grounding_dino_tiny() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) + .commit()?; + + let mut model = GroundingDINO::new(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // annotate let annotator = Annotator::default() .with_bboxes_thickness(4) - .with_saveout("GroundingDINO"); - annotator.annotate(&x, &y); + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + // summary + model.summary(); Ok(()) } diff --git a/examples/hub/README.md b/examples/hub/README.md new file mode 100644 index 0000000..7cddfbc --- /dev/null +++ b/examples/hub/README.md @@ -0,0 +1,5 @@ +## Quick Start + +```shell +RUST_LOG=usls=info cargo run -r --example hub +``` diff --git a/examples/hub/main.rs b/examples/hub/main.rs new file mode 100644 index 0000000..45cc7b2 --- /dev/null +++ b/examples/hub/main.rs @@ -0,0 +1,26 @@ +use usls::Hub; + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + // 1. Download from default github release + let path = Hub::default().try_fetch("images/bus.jpg")?; + println!("Fetch one image: {:?}", path); + + // 2. Download from specific github release url + let path = Hub::default() + .try_fetch("https://github.com/jamjamjon/assets/releases/download/images/bus.jpg")?; + println!("Fetch one file: {:?}", path); + + // 3. 
Fetch tags and files + let hub = Hub::default().with_owner("jamjamjon").with_repo("usls"); + for tag in hub.tags().iter() { + let files = hub.files(tag); + println!("{} => {:?}", tag, files); // Should be empty + } + + Ok(()) +} diff --git a/examples/linknet/README.md b/examples/linknet/README.md new file mode 100644 index 0000000..89227df --- /dev/null +++ b/examples/linknet/README.md @@ -0,0 +1,6 @@ +## Quick Start + +```shell +cargo run -r --example linknet +``` + diff --git a/examples/linknet/main.rs b/examples/linknet/main.rs new file mode 100644 index 0000000..4fc3841 --- /dev/null +++ b/examples/linknet/main.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use usls::{models::DB, Annotator, DataLoader, Options, Scale}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"t\")")] + scale: String, + + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = match args.scale.as_str().try_into()? { + Scale::T => Options::linknet_r18(), + Scale::S => Options::linknet_r34(), + Scale::B => Options::linknet_r50(), + _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), + }; + let mut model = DB::new( + options + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?, + )?; + + // load image + let x = DataLoader::try_read_batch(&[ + "images/table.png", + "images/table1.jpg", + "images/table2.png", + "images/table-ch.jpg", + "images/db.png", + "images/street.jpg", + ])?; + + // run + let y = model.forward(&x)?; + + // annotate + let annotator = Annotator::default() + .without_bboxes(true) + .without_mbrs(true) + .with_polygons_alpha(60) + .with_contours_color([255, 105, 180, 255]) + .with_saveout(model.spec()); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/mobileone/README.md b/examples/mobileone/README.md new file mode 100644 index 0000000..ca00fdf --- /dev/null +++ b/examples/mobileone/README.md @@ -0,0 +1,13 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example mobileone -- --device cuda --dtype fp16 +``` + + +```shell +0: Y { Probs: { Top5: [(263, 0.6109131, Some("Pembroke, Pembroke Welsh corgi")), (264, 0.2062352, Some("Cardigan, Cardigan Welsh corgi")), (231, 0.028572788, Some("collie")), (273, 0.015174894, Some("dingo, warrigal, warragal, Canis dingo")), (248, 0.014367299, Some("Eskimo dog, husky"))] } } +1: Y { Probs: { Top5: [(284, 0.9907692, Some("siamese cat, Siamese")), (285, 0.0015794479, Some("Egyptian cat")), (174, 0.0015189401, Some("Norwegian elkhound, elkhound")), (225, 0.00031838714, Some("malinois")), (17, 0.00027021166, Some("jay"))] } } +2: Y { Probs: { Top5: [(387, 0.94238573, Some("lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens")), (368, 0.0029994072, Some("gibbon, Hylobates lar")), (277, 0.0016564301, Some("red fox, Vulpes vulpes")), (356, 0.0015081967, Some("weasel")), (295, 0.001427932, Some("American black bear, black bear, Ursus americanus, Euarctos americanus"))] } } + +``` diff --git a/examples/mobileone/main.rs b/examples/mobileone/main.rs new file mode 100644 index 0000000..36238c2 --- /dev/null +++ b/examples/mobileone/main.rs @@ -0,0 +1,57 @@ +use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; + 
+#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh( + option, + default = "vec![ + String::from(\"images/dog.jpg\"), + String::from(\"images/siamese.png\"), + String::from(\"images/ailurus-fulgens.jpg\"), + ]" + )] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = Options::mobileone_s0() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = ImageClassifier::try_from(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // results + for (i, y) in ys.iter().enumerate() { + println!("{}: {:?}", i, y); + } + + // annotate + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/modnet/main.rs b/examples/modnet/main.rs index 660ded5..39691b3 100644 --- a/examples/modnet/main.rs +++ b/examples/modnet/main.rs @@ -1,22 +1,24 @@ use usls::{models::MODNet, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + // build model - let options = Options::default() - .with_model("modnet/dyn-f32.onnx")? 
- .with_ixx(0, 2, (416, 512, 800).into()) - .with_ixx(0, 3, (416, 512, 800).into()); + let options = Options::modnet_photographic().commit()?; let mut model = MODNet::new(options)?; // load image - let x = [DataLoader::try_read("images/liuyifei.png")?]; + let xs = [DataLoader::try_read("images/liuyifei.png")?]; // run - let y = model.run(&x)?; + let ys = model.forward(&xs)?; // annotate - let annotator = Annotator::default().with_saveout("MODNet"); - annotator.annotate(&x, &y); + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); Ok(()) } diff --git a/examples/picodet-layout/README.md b/examples/picodet-layout/README.md new file mode 100644 index 0000000..8e29d70 --- /dev/null +++ b/examples/picodet-layout/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r --example picodet-layout +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/picodet/demo-layout-1x.png) diff --git a/examples/picodet-layout/main.rs b/examples/picodet-layout/main.rs new file mode 100644 index 0000000..fca0bcb --- /dev/null +++ b/examples/picodet-layout/main.rs @@ -0,0 +1,31 @@ +use anyhow::Result; +use usls::{models::PicoDet, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + // options + let options = Options::picodet_layout_1x() + // picodet_l_layout_3cls() + // picodet_l_layout_17cls() + .commit()?; + let mut model = PicoDet::new(options)?; + + // load + let xs = [DataLoader::try_read("images/academic.jpg")?]; + + // annotator + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + + // run + let ys = model.forward(&xs)?; + println!("{:?}", ys); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/rtdetr/README.md b/examples/rtdetr/README.md new file 
mode 100644 index 0000000..711c097 --- /dev/null +++ b/examples/rtdetr/README.md @@ -0,0 +1,17 @@ +## Quick Start + +```shell +cargo run -r --example rtdetr +``` + +## Results + +``` +[Bboxes]: Found 5 objects +0: Bbox { xyxy: [47.969677, 397.81808, 246.22426, 904.8823], class_id: 0, name: Some("person"), confidence: 0.94432133 } +1: Bbox { xyxy: [668.0796, 399.28854, 810.3779, 880.7412], class_id: 0, name: Some("person"), confidence: 0.93386495 } +2: Bbox { xyxy: [20.852705, 229.30482, 807.43494, 729.51196], class_id: 5, name: Some("bus"), confidence: 0.9319465 } +3: Bbox { xyxy: [223.28226, 405.37265, 343.92603, 859.50366], class_id: 0, name: Some("person"), confidence: 0.9130827 } +4: Bbox { xyxy: [0.0, 552.6165, 65.99908, 868.00525], class_id: 0, name: Some("person"), confidence: 0.7910869 } + +``` diff --git a/examples/rtdetr/main.rs b/examples/rtdetr/main.rs new file mode 100644 index 0000000..590b218 --- /dev/null +++ b/examples/rtdetr/main.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use usls::{models::RTDETR, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + // options + let options = Options::rtdetr_v2_s_coco() + // rtdetr_v1_r18vd_coco() + // rtdetr_v2_ms_coco() + // rtdetr_v2_m_coco() + // rtdetr_v2_l_coco() + // rtdetr_v2_x_coco() + .commit()?; + let mut model = RTDETR::new(options)?; + + // load + let xs = [DataLoader::try_read("./assets/bus.jpg")?]; + + // run + let ys = model.forward(&xs)?; + + // extract bboxes + for y in ys.iter() { + if let Some(bboxes) = y.bboxes() { + println!("[Bboxes]: Found {} objects", bboxes.len()); + for (i, bbox) in bboxes.iter().enumerate() { + println!("{}: {:?}", i, bbox) + } + } + } + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + 
+ Ok(()) +} diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs index aae1706..efe198a 100644 --- a/examples/rtmo/main.rs +++ b/examples/rtmo/main.rs @@ -1,25 +1,26 @@ +use anyhow::Result; use usls::{models::RTMO, Annotator, DataLoader, Options, COCO_SKELETONS_16}; -fn main() -> Result<(), Box> { +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + // build model - let options = Options::default() - .with_model("rtmo/s-dyn.onnx")? - .with_nk(17) - .with_confs(&[0.3]) - .with_kconfs(&[0.5]); - let mut model = RTMO::new(options)?; + let mut model = RTMO::new(Options::rtmo_s().commit()?)?; // load image - let x = [DataLoader::try_read("images/bus.jpg")?]; + let xs = [DataLoader::try_read("images/bus.jpg")?]; // run - let y = model.run(&x)?; + let ys = model.forward(&xs)?; // annotate let annotator = Annotator::default() - .with_saveout("RTMO") + .with_saveout(model.spec()) .with_skeletons(&COCO_SKELETONS_16); - annotator.annotate(&x, &y); + annotator.annotate(&xs, &ys); Ok(()) } diff --git a/examples/sam/README.md b/examples/sam/README.md index 92af792..34db1e3 100644 --- a/examples/sam/README.md +++ b/examples/sam/README.md @@ -3,19 +3,18 @@ ```Shell # SAM -cargo run -r --example sam +cargo run -r -F cuda --example sam -- --device cuda --kind sam # MobileSAM -cargo run -r --example sam -- --kind mobile-sam +cargo run -r -F cuda --example sam -- --device cuda --kind mobile-sam # EdgeSAM -cargo run -r --example sam -- --kind edge-sam +cargo run -r -F cuda --example sam -- --device cuda --kind edge-sam # SAM-HQ -cargo run -r --example sam -- --kind sam-hq +cargo run -r -F cuda --example sam -- --device cuda --kind sam-hq ``` - ## Results ![](https://github.com/jamjamjon/assets/releases/download/sam/demo-car.png) diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 72ed218..ca009c7 100644 --- 
a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -1,97 +1,73 @@ -use clap::Parser; - +use anyhow::Result; use usls::{ models::{SamKind, SamPrompt, SAM}, - Annotator, DataLoader, Options, + Annotator, DataLoader, Options, Scale, }; -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -pub struct Args { - #[arg(long, value_enum, default_value_t = SamKind::Sam)] - pub kind: SamKind, +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, - #[arg(long, default_value_t = 0)] - pub device_id: usize, + /// scale + #[argh(option, default = "String::from(\"t\")")] + scale: String, - #[arg(long)] - pub use_low_res_mask: bool, + /// SAM kind + #[argh(option, default = "String::from(\"sam\")")] + kind: String, } -fn main() -> Result<(), Box> { - let args = Args::parse(); - - // Options - let (options_encoder, options_decoder, saveout) = match args.kind { - SamKind::Sam => { - let options_encoder = Options::default() - // .with_model("sam/sam-vit-b-encoder.onnx")?; - .with_model("sam/sam-vit-b-encoder-u8.onnx")?; +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); - let options_decoder = Options::default() - .with_sam_kind(SamKind::Sam) - // .with_model("sam/sam-vit-b-decoder.onnx")?; - // .with_model("sam/sam-vit-b-decoder-singlemask.onnx")?; - .with_model("sam/sam-vit-b-decoder-u8.onnx")?; - (options_encoder, options_decoder, "SAM") - } - SamKind::Sam2 => { - let options_encoder = Options::default() - // .with_model("sam/sam2-hiera-tiny-encoder.onnx")?; - // .with_model("sam/sam2-hiera-small-encoder.onnx")?; - .with_model("sam/sam2-hiera-base-plus-encoder.onnx")?; - let options_decoder = Options::default() - .with_sam_kind(SamKind::Sam2) - // .with_model("sam/sam2-hiera-tiny-decoder.onnx")?; - // 
.with_model("sam/sam2-hiera-small-decoder.onnx")?; - .with_model("sam/sam2-hiera-base-plus-decoder.onnx")?; - (options_encoder, options_decoder, "SAM2") - } - SamKind::MobileSam => { - let options_encoder = - Options::default().with_model("sam/mobile-sam-vit-t-encoder.onnx")?; - - let options_decoder = Options::default() - .with_sam_kind(SamKind::MobileSam) - .with_model("sam/mobile-sam-vit-t-decoder.onnx")?; - (options_encoder, options_decoder, "Mobile-SAM") - } - SamKind::SamHq => { - let options_encoder = Options::default().with_model("sam/sam-hq-vit-t-encoder.onnx")?; + let args: Args = argh::from_env(); + // Build model + let (options_encoder, options_decoder) = match args.kind.as_str().try_into()? { + SamKind::Sam => ( + Options::sam_v1_base_encoder(), + Options::sam_v1_base_decoder(), + ), + SamKind::Sam2 => match args.scale.as_str().try_into()? { + Scale::T => (Options::sam2_tiny_encoder(), Options::sam2_tiny_decoder()), + Scale::S => (Options::sam2_small_encoder(), Options::sam2_small_decoder()), + Scale::B => ( + Options::sam2_base_plus_encoder(), + Options::sam2_base_plus_decoder(), + ), + _ => unimplemented!("Unsupported model scale: {:?}. 
Try b, s, t.", args.scale), + }, - let options_decoder = Options::default() - .with_sam_kind(SamKind::SamHq) - .with_model("sam/sam-hq-vit-t-decoder.onnx")?; - (options_encoder, options_decoder, "SAM-HQ") - } - SamKind::EdgeSam => { - let options_encoder = Options::default().with_model("sam/edge-sam-3x-encoder.onnx")?; - let options_decoder = Options::default() - .with_sam_kind(SamKind::EdgeSam) - .with_model("sam/edge-sam-3x-decoder.onnx")?; - (options_encoder, options_decoder, "Edge-SAM") - } + SamKind::MobileSam => ( + Options::mobile_sam_tiny_encoder(), + Options::mobile_sam_tiny_decoder(), + ), + SamKind::SamHq => ( + Options::sam_hq_tiny_encoder(), + Options::sam_hq_tiny_decoder(), + ), + SamKind::EdgeSam => ( + Options::edge_sam_3x_encoder(), + Options::edge_sam_3x_decoder(), + ), }; - let options_encoder = options_encoder - .with_cuda(args.device_id) - .with_ixx(0, 2, (800, 1024, 1024).into()) - .with_ixx(0, 3, (800, 1024, 1024).into()); - let options_decoder = options_decoder - .with_cuda(args.device_id) - .use_low_res_mask(args.use_low_res_mask) - .with_find_contours(true); - // Build model + let options_encoder = options_encoder + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let options_decoder = options_decoder.commit()?; let mut model = SAM::new(options_encoder, options_decoder)?; // Load image - let xs = [ - DataLoader::try_read("images/truck.jpg")?, - // DataLoader::try_read("images/dog.jpg")?, - ]; + let xs = [DataLoader::try_read("images/truck.jpg")?]; // Build annotator - let annotator = Annotator::default().with_saveout(saveout); + let annotator = Annotator::default().with_saveout(model.spec()); // Prompt let prompts = vec![ @@ -102,7 +78,7 @@ fn main() -> Result<(), Box> { ]; // Run & Annotate - let ys = model.run(&xs, &prompts)?; + let ys = model.forward(&xs, &prompts)?; annotator.annotate(&xs, &ys); Ok(()) diff --git a/examples/sapiens/README.md b/examples/sapiens/README.md index 6bf5cfe..7699915 100644 --- a/examples/sapiens/README.md +++ b/examples/sapiens/README.md @@ -1,10 +1,9 @@ ## Quick Start ```shell -cargo run -r --example sapiens +cargo run -r -F cuda --example sapiens -- --device cuda ``` - ## Results ![](https://github.com/jamjamjon/assets/releases/download/sapiens/demo.png) diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs index 111d90f..08d3167 100644 --- a/examples/sapiens/main.rs +++ b/examples/sapiens/main.rs @@ -1,27 +1,38 @@ -use usls::{ - models::{Sapiens, SapiensTask}, - Annotator, DataLoader, Options, BODY_PARTS_28, -}; +use anyhow::Result; +use usls::{models::Sapiens, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); // build - let options = Options::default() - .with_model("sapiens/seg-0.3b-dyn.onnx")? 
- .with_sapiens_task(SapiensTask::Seg) - .with_names(&BODY_PARTS_28); + let options = Options::sapiens_seg_0_3b() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; let mut model = Sapiens::new(options)?; // load let x = [DataLoader::try_read("images/paul-george.jpg")?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .without_masks(true) - .with_polygons_name(false) - .with_saveout("Sapiens"); + .with_polygons_name(true) + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/slanet/README.md b/examples/slanet/README.md new file mode 100644 index 0000000..9ee499a --- /dev/null +++ b/examples/slanet/README.md @@ -0,0 +1,9 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example slanet -- --device cuda +``` + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/slanet/demo.png) diff --git a/examples/slanet/main.rs b/examples/slanet/main.rs new file mode 100644 index 0000000..9707c53 --- /dev/null +++ b/examples/slanet/main.rs @@ -0,0 +1,48 @@ +use anyhow::Result; +use usls::{models::SLANet, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// source + #[argh(option, default = "String::from(\"images/table.png\")")] + source: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let options = Options::slanet_lcnet_v2_mobile_ch() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let mut model = SLANet::new(options)?; + + // load + let xs = DataLoader::try_read_batch(&[args.source])?; + + // run + let ys = model.forward(&xs)?; + println!("{:?}", ys); + + // annotate + let annotator = Annotator::default() + .with_keypoints_radius(2) + .with_skeletons(&[(0, 1), (1, 2), (2, 3), (3, 0)]) + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + // summary + model.summary(); + + Ok(()) +} diff --git a/examples/svtr/README.md b/examples/svtr/README.md index cc192bc..82c10c5 100644 --- a/examples/svtr/README.md +++ b/examples/svtr/README.md @@ -1,29 +1,21 @@ ## Quick Start ```shell -cargo run -r --example svtr +cargo run -r -F cuda --example svtr -- --device cuda ``` -### Speed test - -| Model | Width | TensorRT
f16
batch=1
(ms) | TensorRT
f32
batch=1
(ms) | CUDA
f32
batch=1
(ms) | -| --------------------------- | :---: | :--------------------------------------: | :--------------------------------------: | :----------------------------------: | -| ppocr-v4-server-svtr-ch-dyn | 1500 | 4.2116 | 13.0013 | 20.8673 | -| ppocr-v4-svtr-ch-dyn | 1500 | 2.0435 | 3.1959 | 10.1750 | -| ppocr-v3-svtr-ch-dyn | 1500 | 1.8596 | 2.9401 | 6.8210 | - -***Test on RTX3060*** - ## Results ```shell -["./examples/svtr/images/5.png"]: Some(["are closely jointed. Some examples are illustrated in Fig.7."]) -["./examples/svtr/images/6.png"]: Some(["小菊儿胡同71号"]) -["./examples/svtr/images/4.png"]: Some(["我在南锣鼓捣猫呢"]) -["./examples/svtr/images/1.png"]: Some(["你有这么高速运转的机械进入中国,记住我给出的原理"]) -["./examples/svtr/images/2.png"]: Some(["冀B6G000"]) -["./examples/svtr/images/9.png"]: Some(["from the background, but also separate text instances which"]) -["./examples/svtr/images/8.png"]: Some(["110022345"]) -["./examples/svtr/images/3.png"]: Some(["粤A·68688"]) -["./examples/svtr/images/7.png"]: Some(["Please lower your volume"]) +["./examples/svtr/images/license-ch-2.png"]: Ys([Y { Texts: [Text("粤A·68688")] }]) +["./examples/svtr/images/license-ch.png"]: Ys([Y { Texts: [Text("冀B6G000")] }]) +["./examples/svtr/images/sign-ch-2.png"]: Ys([Y { Texts: [Text("我在南锣鼓捣猫呢")] }]) +["./examples/svtr/images/sign-ch.png"]: Ys([Y { Texts: [Text("小菊儿胡同71号")] }]) +["./examples/svtr/images/text-110022345.png"]: Ys([Y { Texts: [Text("110022345")] }]) +["./examples/svtr/images/text-ch.png"]: Ys([Y { Texts: [Text("你有这么高速运转的机械进入中国,记住我给出的原理")] }]) +["./examples/svtr/images/text-en-2.png"]: Ys([Y { Texts: [Text("from the background, but also separate text instances which")] }]) +["./examples/svtr/images/text-en-dark.png"]: Ys([Y { Texts: [Text("Please lower your volume")] }]) +["./examples/svtr/images/text-en.png"]: Ys([Y { Texts: [Text("are closely jointed. 
Some examples are illustrated in Fig.7.")] }]) +["./examples/svtr/images/text-hello-rust-handwritten.png"]: Ys([Y { Texts: [Text("HeloRuSt")] }]) + ``` \ No newline at end of file diff --git a/examples/svtr/images/3.png b/examples/svtr/images/license-ch-2.png similarity index 100% rename from examples/svtr/images/3.png rename to examples/svtr/images/license-ch-2.png diff --git a/examples/svtr/images/2.png b/examples/svtr/images/license-ch.png similarity index 100% rename from examples/svtr/images/2.png rename to examples/svtr/images/license-ch.png diff --git a/examples/svtr/images/4.png b/examples/svtr/images/sign-ch-2.png similarity index 100% rename from examples/svtr/images/4.png rename to examples/svtr/images/sign-ch-2.png diff --git a/examples/svtr/images/6.png b/examples/svtr/images/sign-ch.png similarity index 100% rename from examples/svtr/images/6.png rename to examples/svtr/images/sign-ch.png diff --git a/examples/svtr/images/8.png b/examples/svtr/images/text-110022345.png similarity index 100% rename from examples/svtr/images/8.png rename to examples/svtr/images/text-110022345.png diff --git a/examples/svtr/images/1.png b/examples/svtr/images/text-ch.png similarity index 100% rename from examples/svtr/images/1.png rename to examples/svtr/images/text-ch.png diff --git a/examples/svtr/images/9.png b/examples/svtr/images/text-en-2.png similarity index 100% rename from examples/svtr/images/9.png rename to examples/svtr/images/text-en-2.png diff --git a/examples/svtr/images/7.png b/examples/svtr/images/text-en-dark.png similarity index 100% rename from examples/svtr/images/7.png rename to examples/svtr/images/text-en-dark.png diff --git a/examples/svtr/images/5.png b/examples/svtr/images/text-en.png similarity index 100% rename from examples/svtr/images/5.png rename to examples/svtr/images/text-en.png diff --git a/examples/svtr/images/text-hello-rust-handwritten.png b/examples/svtr/images/text-hello-rust-handwritten.png new file mode 100644 index 
0000000..750c634 Binary files /dev/null and b/examples/svtr/images/text-hello-rust-handwritten.png differ diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs index 43562c1..18704f8 100644 --- a/examples/svtr/main.rs +++ b/examples/svtr/main.rs @@ -1,24 +1,44 @@ +use anyhow::Result; use usls::{models::SVTR, DataLoader, Options}; -fn main() -> Result<(), Box> { +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + // build model - let options = Options::default() - .with_ixx(0, 0, (1, 2, 8).into()) - .with_ixx(0, 2, (320, 960, 1600).into()) - .with_ixx(0, 3, (320, 960, 1600).into()) - .with_confs(&[0.2]) - .with_vocab("svtr/ppocr_rec_vocab.txt")? - .with_model("svtr/ppocr-v4-svtr-ch-dyn.onnx")?; + let options = Options::ppocr_rec_v4_ch() + // svtr_v2_teacher_ch() + // .with_batch_size(2) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; let mut model = SVTR::new(options)?; // load images - let dl = DataLoader::new("./examples/svtr/images")?.build()?; + let dl = DataLoader::new("./examples/svtr/images")? 
+ .with_batch(model.batch() as _) + .with_progress_bar(false) + .build()?; // run for (xs, paths) in dl { - let ys = model.run(&xs)?; - println!("{paths:?}: {:?}", ys[0].texts()) + let ys = model.forward(&xs)?; + println!("{paths:?}: {:?}", ys) } + //summary + model.summary(); + Ok(()) } diff --git a/examples/trocr/README.md b/examples/trocr/README.md new file mode 100644 index 0000000..dba262c --- /dev/null +++ b/examples/trocr/README.md @@ -0,0 +1,13 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example trocr -- --device cuda --dtype fp16 --scale s --kind printed + +cargo run -r -F cuda --example trocr -- --device cuda --dtype fp16 --scale s --kind hand-written + +``` + + +```shell +Ys([Y { Texts: [Text("PLEASE LOWER YOUR VOLUME")] }, Y { Texts: [Text("HELLO RUST")] }]) +``` \ No newline at end of file diff --git a/examples/trocr/main.rs b/examples/trocr/main.rs new file mode 100644 index 0000000..3b7d8ea --- /dev/null +++ b/examples/trocr/main.rs @@ -0,0 +1,96 @@ +use usls::{ + models::{TrOCR, TrOCRKind}, + DataLoader, Options, Scale, +}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"s\")")] + scale: String, + + /// kind + #[argh(option, default = "String::from(\"printed\")")] + kind: String, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // load images + let xs = DataLoader::try_read_batch(&[ + "images/text-en-dark.png", + "images/text-hello-rust-handwritten.png", + ])?; + + // build model + let (options_encoder, options_decoder, options_decoder_merged) = + match args.scale.as_str().try_into()? 
{ + Scale::S => match args.kind.as_str().try_into()? { + TrOCRKind::Printed => ( + Options::trocr_encoder_small_printed(), + Options::trocr_decoder_small_printed(), + Options::trocr_decoder_merged_small_printed(), + ), + TrOCRKind::HandWritten => ( + Options::trocr_encoder_small_handwritten(), + Options::trocr_decoder_small_handwritten(), + Options::trocr_decoder_merged_small_handwritten(), + ), + }, + Scale::B => match args.kind.as_str().try_into()? { + TrOCRKind::Printed => ( + Options::trocr_encoder_base_printed(), + Options::trocr_decoder_base_printed(), + Options::trocr_decoder_merged_base_printed(), + ), + TrOCRKind::HandWritten => ( + Options::trocr_encoder_base_handwritten(), + Options::trocr_decoder_base_handwritten(), + Options::trocr_decoder_merged_base_handwritten(), + ), + }, + x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x), + }; + + let mut model = TrOCR::new( + options_encoder + .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder + .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder_merged + .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) 
+ .with_batch_size(xs.len()) + .commit()?, + )?; + + // inference + let ys = model.forward(&xs)?; + println!("{:?}", ys); + + // summary + model.summary(); + + Ok(()) +} diff --git a/examples/viewer/README.md b/examples/viewer/README.md new file mode 100644 index 0000000..0cfe0e0 --- /dev/null +++ b/examples/viewer/README.md @@ -0,0 +1,5 @@ +## Quick Start + +```shell +RUST_LOG=usls=info cargo run -F ffmpeg -r --example viewer +``` diff --git a/examples/viewer/main.rs b/examples/viewer/main.rs new file mode 100644 index 0000000..8279204 --- /dev/null +++ b/examples/viewer/main.rs @@ -0,0 +1,43 @@ +use usls::{DataLoader, Key, Viewer}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// source + #[argh( + option, + default = "String::from(\"http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4\")" + )] + source: String, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + let dl = DataLoader::new(&args.source)?.with_batch(1).build()?; + + let mut viewer = Viewer::new().with_delay(5).with_scale(1.).resizable(true); + + // run & annotate + for (xs, _paths) in dl { + // show image + viewer.imshow(&xs)?; + + // check out window and key event + if !viewer.is_open() || viewer.is_key_pressed(Key::Escape) { + break; + } + + // write video + viewer.write_batch(&xs)? 
+ } + + // finish video write + viewer.finish_write()?; + + Ok(()) +} diff --git a/examples/yolo-sam/README.md b/examples/yolo-sam/README.md index 1dfab0c..84dfb0f 100644 --- a/examples/yolo-sam/README.md +++ b/examples/yolo-sam/README.md @@ -1,7 +1,7 @@ ## Quick Start ```shell -cargo run -r --example yolo-sam +cargo run -r -F cuda --example yolo-sam -- --device cuda ``` ## Results diff --git a/examples/yolo-sam/main.rs b/examples/yolo-sam/main.rs index 3b51ace..b66fb63 100644 --- a/examples/yolo-sam/main.rs +++ b/examples/yolo-sam/main.rs @@ -1,31 +1,42 @@ +use anyhow::Result; use usls::{ - models::{SamKind, SamPrompt, YOLOTask, YOLOVersion, SAM, YOLO}, - Annotator, DataLoader, Options, Vision, + models::{SamPrompt, SAM, YOLO}, + Annotator, DataLoader, Options, Scale, }; -fn main() -> Result<(), Box> { +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + // build SAM - let options_encoder = Options::default().with_model("sam/mobile-sam-vit-t-encoder.onnx")?; - let options_decoder = Options::default() - .with_find_contours(true) - .with_sam_kind(SamKind::Sam) - .with_model("sam/mobile-sam-vit-t-decoder.onnx")?; + let (options_encoder, options_decoder) = ( + Options::mobile_sam_tiny_encoder().commit()?, + Options::mobile_sam_tiny_decoder().commit()?, + ); let mut sam = SAM::new(options_encoder, options_decoder)?; - // build YOLOv8-Det - let options_yolo = Options::default() - .with_yolo_version(YOLOVersion::V8) - .with_yolo_task(YOLOTask::Detect) - .with_model("yolo/v8-m-dyn.onnx")? 
- .with_cuda(0) - .with_ixx(0, 2, (416, 640, 800).into()) - .with_ixx(0, 3, (416, 640, 800).into()) - .with_find_contours(false) - .with_confs(&[0.45]); + // build YOLOv8 + let options_yolo = Options::yolo_detect() + .with_model_scale(Scale::N) + .with_model_version(8.0.into()) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; let mut yolo = YOLO::new(options_yolo)?; // load one image - let xs = [DataLoader::try_read("images/dog.jpg")?]; + let xs = DataLoader::try_read_batch(&["images/dog.jpg"])?; // build annotator let annotator = Annotator::default() @@ -36,11 +47,11 @@ fn main() -> Result<(), Box> { .with_saveout("YOLO-SAM"); // run & annotate - let ys_det = yolo.run(&xs)?; - for y_det in ys_det { + let ys_det = yolo.forward(&xs)?; + for y_det in ys_det.iter() { if let Some(bboxes) = y_det.bboxes() { for bbox in bboxes { - let ys_sam = sam.run( + let ys_sam = sam.forward( &xs, &[SamPrompt::default().with_bbox( bbox.xmin(), diff --git a/examples/yolo/README.md b/examples/yolo/README.md index d443f43..5151aa6 100644 --- a/examples/yolo/README.md +++ b/examples/yolo/README.md @@ -1,175 +1,65 @@

YOLO-Series

+| Detection | Instance Segmentation | Pose | +| :----------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------: | +| `` | `` | `` | -| Detection | Instance Segmentation | Pose | -| :---------------: | :------------------------: |:---------------: | -| | | | - -| Classification | Obb | -| :------------------------: |:------------------------: | -| | - -| Head Detection | Fall Detection | Trash Detection | -| :------------------------: |:------------------------: |:------------------------: | -| || - -| YOLO-World | Face Parsing | FastSAM | -| :------------------------: |:------------------------: |:------------------------: | -| || - - +| Classification | Obb | +| :----------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------: | +| `` | `` | +| Head Detection | Fall Detection | Trash Detection | +| :-----------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------: | +| `` | `` | `` | +| YOLO-World | Face Parsing | FastSAM | +| :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | +| `` | `` | `` | ## Quick Start + ```Shell -# customized -cargo run -r 
--example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8 +# Your customized YOLOv8 model +cargo run -r --example yolo -- --task detect --ver 8 --num-classes 6 --model xxx.onnx # YOLOv8 # Classify -cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5 -cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8 -cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11 +cargo run -r --example yolo -- --task classify --ver 5 --scale s --image-width 224 --image-height 224 --num-classes 1000 --use-imagenet-1k-classes # YOLOv5 +cargo run -r --example yolo -- --task classify --ver 8 --scale n --image-width 224 --image-height 224 # YOLOv8 +cargo run -r --example yolo -- --task classify --ver 11 --scale n --image-width 224 --image-height 224 # YOLOv11 # Detect -cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5 -cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6 -cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7 -cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8 -cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9 -cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10 -cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11 -cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR -cargo run -r --example yolo -- --task detect --ver v8 --model yolo/v8-s-world-v2-shoes.onnx # YOLOv8-world +cargo run -r --example yolo -- --task detect --ver 5 --scale n --use-coco-80-classes # YOLOv5 +cargo run -r --example yolo -- --task detect --ver 6 --scale n --use-coco-80-classes # YOLOv6 +cargo run -r --example yolo -- --task detect --ver 7 --scale t --use-coco-80-classes # YOLOv7 +cargo run -r --example yolo -- --task detect --ver 8 --scale n 
--use-coco-80-classes # YOLOv8 +cargo run -r --example yolo -- --task detect --ver 9 --scale t --use-coco-80-classes # YOLOv9 +cargo run -r --example yolo -- --task detect --ver 10 --scale n --use-coco-80-classes # YOLOv10 +cargo run -r --example yolo -- --task detect --ver 11 --scale n --use-coco-80-classes # YOLOv11 +cargo run -r --example yolo -- --task detect --ver 8 --model v8-s-world-v2-shoes.onnx # YOLOv8-world # Pose -cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose -cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose +cargo run -r --example yolo -- --task pose --ver 8 --scale n # YOLOv8-Pose +cargo run -r --example yolo -- --task pose --ver 11 --scale n # YOLOv11-Pose # Segment -cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment -cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment -cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv8-Segment -cargo run -r --example yolo -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM +cargo run -r --example yolo -- --task segment --ver 5 --scale n # YOLOv5-Segment +cargo run -r --example yolo -- --task segment --ver 8 --scale n # YOLOv8-Segment +cargo run -r --example yolo -- --task segment --ver 11 --scale n # YOLOv11-Segment # Obb -cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb -cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb +cargo run -r --example yolo -- --ver 8 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv8-Obb +cargo run -r --example yolo -- --ver 11 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv11-Obb ``` **`cargo run -r --example yolo -- --help` for more options** - -## YOLOs configs with `Options` - -
-Use official YOLO Models - -```Rust -let options = Options::default() - .with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR - .with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb - .with_model("xxxx.onnx")?; - -``` -
- -
-Cutomized your own YOLO model - -```Rust -// This config is for YOLOv8-Segment -use usls::{AnchorsPosition, BoxType, ClssType, YOLOPreds}; - -let options = Options::default() - .with_yolo_preds( - YOLOPreds { - bbox: Some(BoxType::Cxcywh), - clss: ClssType::Clss, - coefs: Some(true), - anchors: Some(AnchorsPosition::After), - ..Default::default() - } - ) - // .with_nc(80) - // .with_names(&COCO_CLASS_NAMES_80) - .with_model("xxxx.onnx")?; -``` -
- ## Other YOLOv8 Solution Models -| Model | Weights | Datasets| -|:---------------------: | :--------------------------: | :-------------------------------: | -| Face-Landmark Detection | [yolov8-face-dyn-f16](https://github.com/jamjamjon/assets/releases/download/yolo/v8-n-face-dyn-f16.onnx) | | -| Head Detection | [yolov8-head-f16](https://github.com/jamjamjon/assets/releases/download/yolo/v8-head-f16.onnx) | | -| Fall Detection | [yolov8-falldown-f16](https://github.com/jamjamjon/assets/releases/download/yolo/v8-falldown-f16.onnx) | | -| Trash Detection | [yolov8-plastic-bag-f16](https://github.com/jamjamjon/assets/releases/download/yolo/v8-plastic-bag-f16.onnx) | | -| FaceParsing | [yolov8-face-parsing-dyn](https://github.com/jamjamjon/assets/releases/download/yolo/v8-face-parsing-dyn.onnx) | [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing)
[[Processed YOLO labels]](https://github.com/jamjamjon/assets/releases/download/yolo/CelebAMask-HQ-YOLO-Labels.zip)[[Python Script]](../../scripts/CelebAMask-HQ-To-YOLO-Labels.py) | - - - - -## Export ONNX Models - - -
-YOLOv5 - -[Here](https://docs.ultralytics.com/yolov5/tutorials/model_export/) - -
- - -
-YOLOv6 - -[Here](https://github.com/meituan/YOLOv6/tree/main/deploy/ONNX) - -
- - -
-YOLOv7 - -[Here](https://github.com/WongKinYiu/yolov7?tab=readme-ov-file#export) - -
- -
-YOLOv8, YOLOv11 - -```Shell -pip install -U ultralytics - -# export onnx model with dynamic shapes -yolo export model=yolov8m.pt format=onnx simplify dynamic -yolo export model=yolov8m-cls.pt format=onnx simplify dynamic -yolo export model=yolov8m-pose.pt format=onnx simplify dynamic -yolo export model=yolov8m-seg.pt format=onnx simplify dynamic -yolo export model=yolov8m-obb.pt format=onnx simplify dynamic - -# export onnx model with fixed shapes -yolo export model=yolov8m.pt format=onnx simplify -yolo export model=yolov8m-cls.pt format=onnx simplify -yolo export model=yolov8m-pose.pt format=onnx simplify -yolo export model=yolov8m-seg.pt format=onnx simplify -yolo export model=yolov8m-obb.pt format=onnx simplify -``` -
- - -
-YOLOv9 - -[Here](https://github.com/WongKinYiu/yolov9/blob/main/export.py) - -
- -
-YOLOv10 - -[Here](https://github.com/THU-MIG/yolov10#export) - -
+| Model | Weights | +| :---------------------: | :------------------------------------------------------: | +| Face-Landmark Detection | [yolov8-n-face](https://github.com/jamjamjon/assets/releases/download/yolo/v8-n-face-fp16.onnx) | +| Head Detection | [yolov8-head](https://github.com/jamjamjon/assets/releases/download/yolo/v8-head-fp16.onnx) | +| Fall Detection | [yolov8-falldown](https://github.com/jamjamjon/assets/releases/download/yolo/v8-falldown-fp16.onnx) | +| Trash Detection | [yolov8-plastic-bag](https://github.com/jamjamjon/assets/releases/download/yolo/v8-plastic-bag-fp16.onnx) | +| FaceParsing | [yolov8-face-parsing-seg](https://github.com/jamjamjon/assets/releases/download/yolo/v8-face-parsing.onnx) | diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 96c51c0..71ec5fb 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -1,171 +1,213 @@ use anyhow::Result; -use clap::Parser; - use usls::{ - models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask, - YOLOVersion, COCO_SKELETONS_16, + models::YOLO, Annotator, DataLoader, Options, COCO_CLASS_NAMES_80, COCO_SKELETONS_16, + IMAGENET_NAMES_1K, }; -#[derive(Parser, Clone)] -#[command(author, version, about, long_about = None)] -pub struct Args { - /// Path to the model - #[arg(long)] - pub model: Option, +#[derive(argh::FromArgs, Debug)] +/// Example +struct Args { + /// model file + #[argh(option)] + model: Option, + + /// source + #[argh(option, default = "String::from(\"./assets/bus.jpg\")")] + source: String, + + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, - /// Input source path - #[arg(long, default_value_t = String::from("./assets/bus.jpg"))] - pub source: String, + /// task + #[argh(option, default = "String::from(\"det\")")] + task: String, - /// YOLO Task - #[arg(long, value_enum, default_value_t = YOLOTask::Detect)] - pub task: YOLOTask, + /// version + #[argh(option, default = "8.0")] + ver: f32, - 
/// YOLO Version - #[arg(long, value_enum, default_value_t = YOLOVersion::V8)] - pub ver: YOLOVersion, + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, - /// YOLO Scale - #[arg(long, value_enum, default_value_t = YOLOScale::N)] - pub scale: YOLOScale, + /// scale + #[argh(option, default = "String::from(\"n\")")] + scale: String, - /// Batch size - #[arg(long, default_value_t = 1)] - pub batch_size: usize, + /// trt_fp16 + #[argh(option, default = "true")] + trt_fp16: bool, - /// Minimum input width - #[arg(long, default_value_t = 224)] - pub width_min: isize, + /// find_contours + #[argh(option, default = "true")] + find_contours: bool, - /// Input width - #[arg(long, default_value_t = 640)] - pub width: isize, + /// batch_size + #[argh(option, default = "1")] + batch_size: usize, - /// Maximum input width - #[arg(long, default_value_t = 1024)] - pub width_max: isize, + /// min_batch_size + #[argh(option, default = "1")] + min_batch_size: usize, - /// Minimum input height - #[arg(long, default_value_t = 224)] - pub height_min: isize, + /// max_batch_size + #[argh(option, default = "4")] + max_batch_size: usize, - /// Input height - #[arg(long, default_value_t = 640)] - pub height: isize, + /// min_image_width + #[argh(option, default = "224")] + min_image_width: isize, - /// Maximum input height - #[arg(long, default_value_t = 1024)] - pub height_max: isize, + /// image_width + #[argh(option, default = "640")] + image_width: isize, - /// Number of classes - #[arg(long, default_value_t = 80)] - pub nc: usize, + /// max_image_width + #[argh(option, default = "1280")] + max_image_width: isize, - /// Class confidence - #[arg(long)] - pub confs: Vec, + /// min_image_height + #[argh(option, default = "224")] + min_image_height: isize, - /// Enable TensorRT support - #[arg(long)] - pub trt: bool, + /// image_height + #[argh(option, default = "640")] + image_height: isize, - /// Enable CUDA support - #[arg(long)] - pub cuda: bool, + 
/// max_image_height + #[argh(option, default = "1280")] + max_image_height: isize, - /// Enable CoreML support - #[arg(long)] - pub coreml: bool, + /// num_classes + #[argh(option)] + num_classes: Option, - /// Use TensorRT half precision - #[arg(long)] - pub half: bool, + /// num_keypoints + #[argh(option)] + num_keypoints: Option, - /// Device ID to use - #[arg(long, default_value_t = 0)] - pub device_id: usize, + /// use_coco_80_classes + #[argh(switch)] + use_coco_80_classes: bool, - /// Enable performance profiling - #[arg(long)] - pub profile: bool, + /// use_imagenet_1k_classes + #[argh(switch)] + use_imagenet_1k_classes: bool, - /// Disable contour drawing - #[arg(long)] - pub no_contours: bool, + /// confs + #[argh(option)] + confs: Vec, - /// Show result - #[arg(long)] - pub view: bool, + /// keypoint_confs + #[argh(option)] + keypoint_confs: Vec, - /// Do not save output - #[arg(long)] - pub nosave: bool, + /// exclude_classes + #[argh(option)] + exclude_classes: Vec, + + /// retain_classes + #[argh(option)] + retain_classes: Vec, + + /// class_names + #[argh(option)] + class_names: Vec, + + /// keypoint_names + #[argh(option)] + keypoint_names: Vec, } fn main() -> Result<()> { - let args = Args::parse(); - - // model path - let path = match &args.model { - None => format!( - "yolo/{}-{}-{}.onnx", - args.ver.name(), - args.scale.name(), - args.task.name() - ), - Some(x) => x.to_string(), - }; - - // saveout - let saveout = match &args.model { - None => format!( - "{}-{}-{}", - args.ver.name(), - args.scale.name(), - args.task.name() - ), - Some(x) => { - let p = std::path::PathBuf::from(&x); - p.file_stem().unwrap().to_str().unwrap().to_string() - } - }; - - // device - let device = if args.cuda { - Device::Cuda(args.device_id) - } else if args.trt { - Device::Trt(args.device_id) - } else if args.coreml { - Device::CoreML(args.device_id) - } else { - Device::Cpu(args.device_id) - }; - - // build options - let options = Options::new() - 
.with_model(&path)? - .with_yolo_version(args.ver) - .with_yolo_task(args.task) - .with_device(device) - .with_trt_fp16(args.half) - .with_ixx(0, 0, (1, args.batch_size as _, 4).into()) - .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into()) - .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into()) - .with_confs(if args.confs.is_empty() { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + let mut options = Options::yolo() + .with_model_file(&args.model.unwrap_or_default()) + .with_model_task(args.task.as_str().try_into()?) + .with_model_version(args.ver.into()) + .with_model_scale(args.scale.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_trt_fp16(args.trt_fp16) + .with_model_ixx( + 0, + 0, + (args.min_batch_size, args.batch_size, args.max_batch_size).into(), + ) + .with_model_ixx( + 0, + 2, + ( + args.min_image_height, + args.image_height, + args.max_image_height, + ) + .into(), + ) + .with_model_ixx( + 0, + 3, + (args.min_image_width, args.image_width, args.max_image_width).into(), + ) + .with_class_confs(if args.confs.is_empty() { &[0.2, 0.15] } else { &args.confs }) - .with_nc(args.nc) - // .with_names(&COCO_CLASS_NAMES_80) - // .with_names2(&COCO_KEYPOINTS_17) - .with_find_contours(!args.no_contours) // find contours or not - .exclude_classes(&[0]) - // .retain_classes(&[0, 5]) - .with_profile(args.profile); + .with_keypoint_confs(if args.keypoint_confs.is_empty() { + &[0.5] + } else { + &args.keypoint_confs + }) + .with_find_contours(args.find_contours) + .retain_classes(&args.retain_classes) + .exclude_classes(&args.exclude_classes); + + if args.use_coco_80_classes { + options = options.with_class_names(&COCO_CLASS_NAMES_80); + } + + if args.use_imagenet_1k_classes { + 
options = options.with_class_names(&IMAGENET_NAMES_1K); + } + + if let Some(nc) = args.num_classes { + options = options.with_nc(nc); + } + + if let Some(nk) = args.num_keypoints { + options = options.with_nk(nk); + } + + if !args.class_names.is_empty() { + options = options.with_class_names( + &args + .class_names + .iter() + .map(|x| x.as_str()) + .collect::>(), + ); + } + + if !args.keypoint_names.is_empty() { + options = options.with_keypoint_names( + &args + .keypoint_names + .iter() + .map(|x| x.as_str()) + .collect::>(), + ); + } // build model - let mut model = YOLO::new(options)?; + let mut model = YOLO::try_from(options.commit()?)?; // build dataloader let dl = DataLoader::new(&args.source)? @@ -175,56 +217,28 @@ fn main() -> Result<()> { // build annotator let annotator = Annotator::default() .with_skeletons(&COCO_SKELETONS_16) - .without_masks(true) // No masks plotting when doing segment task. + .without_masks(true) .with_bboxes_thickness(3) - .with_keypoints_name(false) // Enable keypoints names - .with_saveout_subs(&["YOLO"]) - .with_saveout(&saveout); - - // build viewer - let mut viewer = if args.view { - Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true)) - } else { - None - }; + .with_saveout(model.spec()); // run & annotate for (xs, _paths) in dl { - // let ys = model.run(&xs)?; // way one - let ys = model.forward(&xs, args.profile)?; // way two - let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?; - - // show image - match &mut viewer { - Some(viewer) => viewer.imshow(&images_plotted)?, - None => continue, - } - - // check out window and key event - match &mut viewer { - Some(viewer) => { - if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { - break; - } - } - None => continue, - } - - // write video - if !args.nosave { - match &mut viewer { - Some(viewer) => viewer.write_batch(&images_plotted)?, - None => continue, - } - } + let ys = model.forward(&xs)?; + // extract bboxes + // for y in ys.iter() { + // 
if let Some(bboxes) = y.bboxes() { + // println!("[Bboxes]: Found {} objects", bboxes.len()); + // for (i, bbox) in bboxes.iter().enumerate() { + // println!("{}: {:?}", i, bbox) + // } + // } + // } + + // plot + annotator.annotate(&xs, &ys); } - // finish video write - if !args.nosave { - if let Some(viewer) = &mut viewer { - viewer.finish_write()?; - } - } + model.summary(); Ok(()) } diff --git a/examples/yolop/main.rs b/examples/yolop/main.rs index 2e338cc..ed8283d 100644 --- a/examples/yolop/main.rs +++ b/examples/yolop/main.rs @@ -1,22 +1,26 @@ +use anyhow::Result; use usls::{models::YOLOPv2, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + // build model - let options = Options::default() - .with_model("yolop/v2-dyn-480x800.onnx")? - .with_confs(&[0.3]); + let options = Options::yolop_v2_480x800().commit()?; let mut model = YOLOPv2::new(options)?; // load image - let x = [DataLoader::try_read("images/car.jpg")?]; + let x = [DataLoader::try_read("images/car-view.jpg")?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .with_polygons_name(true) - .with_saveout("YOLOPv2"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/yolov8-rtdetr/README.md b/examples/yolov8-rtdetr/README.md new file mode 100644 index 0000000..78eabd8 --- /dev/null +++ b/examples/yolov8-rtdetr/README.md @@ -0,0 +1,9 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example yolov8-rtdetr -- --device cuda +``` + +```shell +Ys([Y { BBoxes: [Bbox { xyxy: [668.71356, 395.4159, 809.01587, 879.3043], class_id: 0, name: Some("person"), confidence: 0.950527 }, Bbox { xyxy: [48.866394, 399.50665, 248.22641, 904.7525], class_id: 0, name: Some("person"), confidence: 0.9504415 
}, Bbox { xyxy: [20.197449, 230.00304, 805.026, 730.3445], class_id: 5, name: Some("bus"), confidence: 0.94705224 }, Bbox { xyxy: [221.3088, 405.65436, 345.44052, 860.2628], class_id: 0, name: Some("person"), confidence: 0.93062377 }, Bbox { xyxy: [0.34117508, 549.8391, 76.50758, 868.87646], class_id: 0, name: Some("person"), confidence: 0.71064234 }, Bbox { xyxy: [282.12543, 484.14166, 296.43207, 520.96246], class_id: 27, name: Some("tie"), confidence: 0.40305245 }] }]) +``` diff --git a/examples/yolov8-rtdetr/main.rs b/examples/yolov8-rtdetr/main.rs new file mode 100644 index 0000000..87f611f --- /dev/null +++ b/examples/yolov8-rtdetr/main.rs @@ -0,0 +1,45 @@ +use anyhow::Result; +use usls::{models::YOLO, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build model + let config = Options::yolo_v8_rtdetr_l() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let mut model = YOLO::new(config)?; + + // load images + let xs = DataLoader::try_read_batch(&["./assets/bus.jpg"])?; + + // run + let ys = model.forward(&xs)?; + println!("{:?}", ys); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index c6e4d7d..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "1.79" diff --git a/scripts/CelebAMask-HQ-To-YOLO-Labels.py b/scripts/CelebAMask-HQ-To-YOLO-Labels.py deleted file mode 100644 index 95babb6..0000000 --- a/scripts/CelebAMask-HQ-To-YOLO-Labels.py +++ /dev/null @@ -1,63 +0,0 @@ -import cv2 -import numpy as np -from pathlib import Path -from tqdm import tqdm - - -mapping = { - 'background': 0, - 'skin': 1, - 'nose': 2, - 'eye_g': 3, - 'l_eye': 4, - 'r_eye': 5, - 'l_brow': 6, - 'r_brow': 7, - 'l_ear': 8, - 'r_ear': 9, - 'mouth': 10, - 'u_lip': 11, - 'l_lip': 12, - 'hair': 13, - 'hat': 14, - 'ear_r': 15, - 'neck_l': 16, - 'neck': 17, - 'cloth': 18 -} - - - -def main(): - saveout_dir = Path("labels") - if not saveout_dir.exists(): - saveout_dir.mkdir() - else: - import shutil - shutil.rmtree(saveout_dir) - saveout_dir.mkdir() - - - image_list = [x for x in Path("CelebAMask-HQ-mask-anno/").rglob("*.png")] - for image_path in tqdm(image_list, total=len(image_list)): - image_gray = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE) - stem = image_path.stem - name, cls_ = stem.split("_", 1) - segments = cv2.findContours(image_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] - - saveout = saveout_dir / f"{int(name)}.txt" - with open(saveout, 'a+') as f: - for segment in segments: - line = f"{mapping[cls_]}" - segment = segment / 512 - for seg in segment: - xn, yn = seg[0] - line += f" {xn} {yn}" - f.write(line + "\n") - - - - -if __name__ == "__main__": - main() - diff --git 
a/scripts/convert2f16.py b/scripts/convert2f16.py deleted file mode 100644 index 6b9eec3..0000000 --- a/scripts/convert2f16.py +++ /dev/null @@ -1,8 +0,0 @@ -import onnx -from pathlib import Path -from onnxconverter_common import float16 - -model_f32 = "onnx_model.onnx" -model_f16 = float16.convert_float_to_float16(onnx.load(model_f32)) -saveout = Path(model_f32).with_name(Path(model_f32).stem + "-f16.onnx") -onnx.save(model_f16, saveout) diff --git a/src/core/device.rs b/src/core/device.rs deleted file mode 100644 index 583df16..0000000 --- a/src/core/device.rs +++ /dev/null @@ -1,14 +0,0 @@ -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub enum Device { - Auto(usize), - Cpu(usize), - Cuda(usize), - Trt(usize), - CoreML(usize), - // Cann(usize), - // Acl(usize), - // Rocm(usize), - // Rknpu(usize), - // Openvino(usize), - // Onednn(usize), -} diff --git a/src/core/hub.rs b/src/core/hub.rs deleted file mode 100644 index c3f3e90..0000000 --- a/src/core/hub.rs +++ /dev/null @@ -1,426 +0,0 @@ -use anyhow::{Context, Result}; -use indicatif::{ProgressBar, ProgressStyle}; -use serde::{Deserialize, Serialize}; -use std::io::{Read, Write}; -use std::path::{Path, PathBuf}; - -use crate::Dir; - -/// Represents a downloadable asset in a release -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Asset { - pub name: String, - pub browser_download_url: String, - pub size: u64, -} - -/// Represents a GitHub release -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Release { - pub tag_name: String, - pub assets: Vec, -} - -/// Manages interactions with a GitHub repository's releases -pub struct Hub { - /// github api - _gh_api_release: String, - - /// GitHub repository owner - owner: String, - - /// GitHub repository name - repo: String, - - /// Optional list of releases fetched from GitHub - releases: Option>, - - /// Path to cache file - cache: PathBuf, - - /// Optional release tag to be used - tag: Option, - - /// Filename for the asset, used in 
cache management - file_name: Option, - file_size: Option, - - /// Full URL constructed for downloading the asset - url: Option, - - /// Local path where the asset will be stored - path: PathBuf, - - /// Directory to store the downloaded file - to: Dir, - - /// Download timeout in seconds - timeout: u64, - - /// Time to live (cache duration) - ttl: std::time::Duration, - - /// Maximum attempts for downloading - max_attempts: u32, -} - -impl std::fmt::Debug for Hub { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Hub") - .field("owner", &self.owner) - .field("repo", &self.repo) - .field("cache", &self.cache) - .field("path", &self.path) - .field("releases", &self.releases.as_ref().map(|x| x.len())) - .field("ttl", &self.ttl) - .field("max_attempts", &self.max_attempts) - .finish() - } -} - -impl Default for Hub { - fn default() -> Self { - let owner = "jamjamjon".to_string(); - let repo = "assets".to_string(); - let _gh_api_release = format!("https://api.github.com/repos/{}/{}/releases", owner, repo); - - Self { - owner, - repo, - _gh_api_release, - url: None, - path: PathBuf::new(), - to: Dir::Cache, - tag: None, - file_name: None, - file_size: None, - releases: None, - cache: PathBuf::new(), - timeout: 3000, - max_attempts: 3, - ttl: std::time::Duration::from_secs(10 * 60), - } - } -} - -impl Hub { - pub fn new() -> Result { - let mut to = Dir::Cache; - let cache = to - .path() - .or_else(|_| { - to = Dir::Home; - to.path() - })? 
- .join("cache_releases"); - - Ok(Self { - to, - cache, - ..Default::default() - }) - } - - pub fn with_owner(mut self, owner: &str) -> Self { - self.owner = owner.to_string(); - self - } - - pub fn with_repo(mut self, repo: &str) -> Self { - self.repo = repo.to_string(); - self - } - - pub fn with_ttl(mut self, x: u64) -> Self { - self.ttl = std::time::Duration::from_secs(x); - self - } - - pub fn with_timeout(mut self, x: u64) -> Self { - self.timeout = x; - self - } - - pub fn with_max_attempts(mut self, x: u32) -> Self { - self.max_attempts = x; - self - } - - pub fn fetch(mut self, s: &str) -> Result { - // try to fetch from hub or local cache - let p = PathBuf::from(s); - match p.exists() { - true => self.path = p, - false => { - // check remote - match s.split_once('/') { - Some((tag, file_name)) => { - // Extract tag and file from input string - self.tag = Some(tag.to_string()); - self.file_name = Some(file_name.to_string()); - - // Check if releases are already loaded in memory - if self.releases.is_none() { - self.releases = Some(self.connect_remote()?); - } - - if let Some(releases) = &self.releases { - // Validate the tag - let tags: Vec<&str> = - releases.iter().map(|x| x.tag_name.as_str()).collect(); - if !tags.contains(&tag) { - anyhow::bail!( - "Hub tag '{}' not found in releases. Available tags: {:?}", - tag, - tags - ); - } - - // Validate the file - if let Some(release) = releases.iter().find(|r| r.tag_name == tag) { - let files: Vec<&str> = - release.assets.iter().map(|x| x.name.as_str()).collect(); - if !files.contains(&file_name) { - anyhow::bail!( - "Hub file '{}' not found in tag '{}'. 
Available files: {:?}", - file_name, - tag, - files - ); - } else { - for f_ in release.assets.iter() { - if f_.name.as_str() == file_name { - self.url = Some(f_.browser_download_url.clone()); - self.file_size = Some(f_.size); - - break; - } - } - } - } - self.path = self.to.path_with_subs(&[tag])?.join(file_name); - } - } - _ => anyhow::bail!( - "Download failed due to invalid format. Expected: /, got: {}", - s - ), - } - } - } - - Ok(self) - } - - /// Fetch releases from GitHub and cache them - fn fetch_and_cache_releases(url: &str, cache_path: &Path) -> Result { - let response = ureq::get(url) - .set("User-Agent", "my-app") - .call() - .context("Failed to fetch releases from remote")?; - - if response.status() != 200 { - anyhow::bail!( - "Failed to fetch releases from remote ({}): status {} - {}", - url, - response.status(), - response.status_text() - ); - } - - let body = response - .into_string() - .context("Failed to read response body")?; - - // Ensure cache directory exists - let parent_dir = cache_path - .parent() - .context("Invalid cache path; no parent directory found")?; - std::fs::create_dir_all(parent_dir) - .with_context(|| format!("Failed to create cache directory: {:?}", parent_dir))?; - - // Create temporary file - let mut temp_file = tempfile::NamedTempFile::new_in(parent_dir) - .context("Failed to create temporary cache file")?; - - // Write data to temporary file - temp_file - .write_all(body.as_bytes()) - .context("Failed to write to temporary cache file")?; - - // Persist temporary file as the cache - temp_file.persist(cache_path).with_context(|| { - format!("Failed to persist temporary cache file to {:?}", cache_path) - })?; - - Ok(body) - } - - pub fn tags(&mut self) -> Option> { - if self.releases.is_none() { - self.releases = self.connect_remote().ok(); - } - - self.releases - .as_ref() - .map(|releases| releases.iter().map(|x| x.tag_name.as_str()).collect()) - } - - pub fn files(&mut self, tag: &str) -> Option> { - if 
self.releases.is_none() { - self.releases = self.connect_remote().ok(); - } - - self.releases.as_ref().map(|releases| { - releases - .iter() - .find(|r| r.tag_name == tag) - .map(|a| a.assets.iter().map(|x| x.name.as_str()).collect()) - })? - } - - pub fn connect_remote(&mut self) -> Result> { - let span = tracing::span!(tracing::Level::INFO, "Hub-connect_remote"); - let _guard = span.enter(); - - let should_download = if !self.cache.exists() { - tracing::info!("No cache found, fetching data from GitHub"); - true - } else { - match std::fs::metadata(&self.cache)?.modified() { - Err(_) => { - tracing::info!("Cannot get file modified time, fetching new data from GitHub"); - true - } - Ok(modified_time) => { - if std::time::SystemTime::now().duration_since(modified_time)? < self.ttl { - tracing::info!("Using cached data"); - false - } else { - tracing::info!("Cache expired, fetching new data from GitHub"); - true - } - } - } - }; - - let body = if should_download { - Self::fetch_and_cache_releases(&self._gh_api_release, &self.cache)? - } else { - std::fs::read_to_string(&self.cache)? 
- }; - let releases: Vec = serde_json::from_str(&body)?; - Ok(releases) - } - - /// Commit the downloaded file, downloading if necessary - pub fn commit(&self) -> Result { - if let Some(url) = &self.url { - // Download if the file does not exist or if the size of file does not match - if !self.path.is_file() - || self.path.is_file() - && Some(std::fs::metadata(&self.path)?.len()) != self.file_size - { - let name = format!( - "{}/{}", - self.tag.as_ref().unwrap(), - self.file_name.as_ref().unwrap() - ); - Self::download( - url.as_str(), - &self.path, - Some(&name), - Some(self.timeout), - Some(self.max_attempts), - )?; - } - } - self.path - .to_str() - .map(|s| s.to_string()) - .with_context(|| format!("Failed to convert PathBuf: {:?} to String", self.path)) - } - - /// Download a file from a github release to a specified path with a progress bar - pub fn download + std::fmt::Debug>( - src: &str, - dst: P, - prompt: Option<&str>, - timeout: Option, - max_attempts: Option, - ) -> Result<()> { - // TODO: other url, not just github release page - - let max_attempts = max_attempts.unwrap_or(2); - let timeout_duration = std::time::Duration::from_secs(timeout.unwrap_or(2000)); - let agent = ureq::AgentBuilder::new().try_proxy_from_env(true).build(); - - for i_try in 0..max_attempts { - let resp = agent - .get(src) - .timeout(timeout_duration) - .call() - .with_context(|| { - format!( - "Failed to download file from {}, timeout: {:?}", - src, timeout_duration - ) - })?; - let ntotal = resp - .header("Content-Length") - .and_then(|s| s.parse::().ok()) - .context("Content-Length header is missing or invalid")?; - - let pb = ProgressBar::new(ntotal); - pb.set_style( - ProgressStyle::with_template( - "{prefix:.cyan.bold} {msg} |{bar}| ({percent_precise}%, {binary_bytes}/{binary_total_bytes}, {binary_bytes_per_sec})", - )? 
- .progress_chars("██ "), - ); - pb.set_prefix(if i_try == 0 { - " Fetching" - } else { - " Re-Fetching" - }); - pb.set_message(prompt.unwrap_or_default().to_string()); - - let mut reader = resp.into_reader(); - let mut buffer = [0; 256]; - let mut downloaded_bytes = 0usize; - let mut file = std::fs::File::create(&dst) - .with_context(|| format!("Failed to create destination file: {:?}", dst))?; - - loop { - let bytes_read = reader.read(&mut buffer)?; - if bytes_read == 0 { - break; - } - file.write_all(&buffer[..bytes_read]) - .context("Failed to write to file")?; - downloaded_bytes += bytes_read; - pb.inc(bytes_read as u64); - } - - // check size - if downloaded_bytes as u64 != ntotal { - continue; - } - - // update - pb.set_prefix(" Downloaded"); - pb.set_style(ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_FINISH_3, - )?); - pb.finish(); - - if i_try != max_attempts { - break; - } else { - anyhow::bail!("Exceeded the maximum number of download attempts"); - } - } - - Ok(()) - } -} diff --git a/src/core/metric.rs b/src/core/metric.rs deleted file mode 100644 index af0a5ed..0000000 --- a/src/core/metric.rs +++ /dev/null @@ -1,6 +0,0 @@ -#[derive(Debug)] -pub enum Metric { - IP, - Cos, - L2, -} diff --git a/src/core/mod.rs b/src/core/mod.rs deleted file mode 100644 index 0b0c2f1..0000000 --- a/src/core/mod.rs +++ /dev/null @@ -1,45 +0,0 @@ -mod annotator; -mod dataloader; -mod device; -mod dir; -mod dynconf; -mod hub; -mod logits_sampler; -mod media; -mod metric; -mod min_opt_max; -pub mod onnx; -pub mod ops; -mod options; -mod ort_engine; -mod task; -mod tokenizer_stream; -mod ts; -mod viewer; -mod vision; -mod x; -mod xs; - -pub use annotator::Annotator; -pub use dataloader::DataLoader; -pub use device::Device; -pub use dir::Dir; -pub use dynconf::DynConf; -pub use hub::Hub; -pub use logits_sampler::LogitsSampler; -pub use media::*; -pub use metric::Metric; -pub use min_opt_max::MinOptMax; -pub use ops::Ops; -pub use options::Options; -pub use 
ort_engine::*; -pub use task::Task; -pub use tokenizer_stream::TokenizerStream; -pub use ts::Ts; -pub use viewer::Viewer; -pub use vision::Vision; -pub use x::X; -pub use xs::Xs; - -// re-export -pub use minifb::Key; diff --git a/src/core/options.rs b/src/core/options.rs deleted file mode 100644 index 4e906b5..0000000 --- a/src/core/options.rs +++ /dev/null @@ -1,295 +0,0 @@ -//! Options for build models. - -use anyhow::Result; - -use crate::{ - models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion}, - Device, Hub, Iiix, MinOptMax, Task, -}; - -/// Options for building models -#[derive(Debug, Clone)] -pub struct Options { - pub onnx_path: String, - pub task: Task, - pub device: Device, - pub batch_size: usize, - pub iiixs: Vec, - pub profile: bool, - pub num_dry_run: usize, - - // trt related - pub trt_engine_cache_enable: bool, - pub trt_int8_enable: bool, - pub trt_fp16_enable: bool, - - // options for Vision and Language models - pub nc: Option, - pub nk: Option, - pub nm: Option, - pub confs: Vec, - pub confs2: Vec, - pub confs3: Vec, - pub kconfs: Vec, - pub iou: Option, - pub tokenizer: Option, - pub vocab: Option, - pub context_length: Option, - pub names: Option>, // names - pub names2: Option>, // names2 - pub names3: Option>, // names3 - pub min_width: Option, - pub min_height: Option, - pub unclip_ratio: f32, // DB - pub yolo_task: Option, - pub yolo_version: Option, - pub yolo_preds: Option, - pub find_contours: bool, - pub sam_kind: Option, - pub use_low_res_mask: Option, - pub sapiens_task: Option, - pub classes_excluded: Vec, - pub classes_retained: Vec, -} - -impl Default for Options { - fn default() -> Self { - Self { - onnx_path: String::new(), - device: Device::Cuda(0), - profile: false, - batch_size: 1, - iiixs: vec![], - num_dry_run: 3, - - trt_engine_cache_enable: true, - trt_int8_enable: false, - trt_fp16_enable: false, - nc: None, - nk: None, - nm: None, - confs: vec![0.3f32], - confs2: vec![0.3f32], - confs3: vec![0.3f32], - 
kconfs: vec![0.5f32], - iou: None, - tokenizer: None, - vocab: None, - context_length: None, - names: None, - names2: None, - names3: None, - min_width: None, - min_height: None, - unclip_ratio: 1.5, - yolo_task: None, - yolo_version: None, - yolo_preds: None, - find_contours: false, - sam_kind: None, - use_low_res_mask: None, - sapiens_task: None, - task: Task::Untitled, - classes_excluded: vec![], - classes_retained: vec![], - } - } -} - -impl Options { - pub fn new() -> Self { - Default::default() - } - - pub fn with_task(mut self, task: Task) -> Self { - self.task = task; - self - } - - pub fn with_model(mut self, onnx_path: &str) -> Result { - self.onnx_path = Hub::new()?.fetch(onnx_path)?.commit()?; - Ok(self) - } - - pub fn with_batch_size(mut self, n: usize) -> Self { - self.batch_size = n; - self - } - - pub fn with_batch(mut self, n: usize) -> Self { - self.batch_size = n; - self - } - - pub fn with_dry_run(mut self, n: usize) -> Self { - self.num_dry_run = n; - self - } - - pub fn with_device(mut self, device: Device) -> Self { - self.device = device; - self - } - - pub fn with_cuda(mut self, id: usize) -> Self { - self.device = Device::Cuda(id); - self - } - - pub fn with_trt(mut self, id: usize) -> Self { - self.device = Device::Trt(id); - self - } - - pub fn with_cpu(mut self) -> Self { - self.device = Device::Cpu(0); - self - } - - pub fn with_coreml(mut self, id: usize) -> Self { - self.device = Device::CoreML(id); - self - } - - pub fn with_trt_fp16(mut self, x: bool) -> Self { - self.trt_fp16_enable = x; - self - } - - pub fn with_yolo_task(mut self, x: YOLOTask) -> Self { - self.yolo_task = Some(x); - self - } - - pub fn with_sapiens_task(mut self, x: SapiensTask) -> Self { - self.sapiens_task = Some(x); - self - } - - pub fn with_yolo_version(mut self, x: YOLOVersion) -> Self { - self.yolo_version = Some(x); - self - } - - pub fn with_profile(mut self, profile: bool) -> Self { - self.profile = profile; - self - } - - pub fn 
with_find_contours(mut self, x: bool) -> Self { - self.find_contours = x; - self - } - - pub fn with_sam_kind(mut self, x: SamKind) -> Self { - self.sam_kind = Some(x); - self - } - - pub fn use_low_res_mask(mut self, x: bool) -> Self { - self.use_low_res_mask = Some(x); - self - } - - pub fn with_names(mut self, names: &[&str]) -> Self { - self.names = Some(names.iter().map(|x| x.to_string()).collect::>()); - self - } - - pub fn with_names2(mut self, names: &[&str]) -> Self { - self.names2 = Some(names.iter().map(|x| x.to_string()).collect::>()); - self - } - - pub fn with_names3(mut self, names: &[&str]) -> Self { - self.names3 = Some(names.iter().map(|x| x.to_string()).collect::>()); - self - } - - pub fn with_vocab(mut self, vocab: &str) -> Result { - self.vocab = Some(Hub::new()?.fetch(vocab)?.commit()?); - Ok(self) - } - - pub fn with_context_length(mut self, n: usize) -> Self { - self.context_length = Some(n); - self - } - - pub fn with_tokenizer(mut self, tokenizer: &str) -> Result { - self.tokenizer = Some(Hub::new()?.fetch(tokenizer)?.commit()?); - Ok(self) - } - - pub fn with_unclip_ratio(mut self, x: f32) -> Self { - self.unclip_ratio = x; - self - } - - pub fn with_min_width(mut self, x: f32) -> Self { - self.min_width = Some(x); - self - } - - pub fn with_min_height(mut self, x: f32) -> Self { - self.min_height = Some(x); - self - } - - pub fn with_yolo_preds(mut self, x: YOLOPreds) -> Self { - self.yolo_preds = Some(x); - self - } - - pub fn with_nc(mut self, nc: usize) -> Self { - self.nc = Some(nc); - self - } - - pub fn with_nk(mut self, nk: usize) -> Self { - self.nk = Some(nk); - self - } - - pub fn with_iou(mut self, x: f32) -> Self { - self.iou = Some(x); - self - } - - pub fn with_confs(mut self, x: &[f32]) -> Self { - self.confs = x.to_vec(); - self - } - - pub fn with_confs2(mut self, x: &[f32]) -> Self { - self.confs2 = x.to_vec(); - self - } - - pub fn with_confs3(mut self, x: &[f32]) -> Self { - self.confs3 = x.to_vec(); - self - } - - 
pub fn with_kconfs(mut self, kconfs: &[f32]) -> Self { - self.kconfs = kconfs.to_vec(); - self - } - - pub fn with_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { - self.iiixs.push(Iiix::from((i, ii, x))); - self - } - - pub fn exclude_classes(mut self, xs: &[isize]) -> Self { - self.classes_retained.clear(); - self.classes_excluded.extend_from_slice(xs); - self - } - - pub fn retain_classes(mut self, xs: &[isize]) -> Self { - self.classes_excluded.clear(); - self.classes_retained.extend_from_slice(xs); - self - } -} diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs deleted file mode 100644 index d0b915d..0000000 --- a/src/core/ort_engine.rs +++ /dev/null @@ -1,669 +0,0 @@ -use anyhow::Result; -use half::f16; -use ndarray::{Array, IxDyn}; -use ort::{ - execution_providers::{ExecutionProvider, TensorRTExecutionProvider}, - session::{builder::SessionBuilder, Session}, - tensor::TensorElementType, -}; -use prost::Message; -use std::collections::HashSet; - -use crate::{ - build_progress_bar, human_bytes, onnx, Device, Dir, MinOptMax, Ops, Options, Ts, Xs, - CHECK_MARK, CROSS_MARK, X, -}; - -/// A struct for input composed of the i-th input, the ii-th dimension, and the value. -#[derive(Clone, Debug, Default)] -pub struct Iiix { - pub i: usize, - pub ii: usize, - pub x: MinOptMax, -} - -impl From<(usize, usize, MinOptMax)> for Iiix { - fn from((i, ii, x): (usize, usize, MinOptMax)) -> Self { - Self { i, ii, x } - } -} - -/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions. 
-#[derive(Debug)] -pub struct OrtTensorAttr { - pub names: Vec, - pub dtypes: Vec, - pub dimss: Vec>, -} - -/// ONNXRuntime Backend -#[derive(Debug)] -pub struct OrtEngine { - name: String, - session: Session, - device: Device, - inputs_minoptmax: Vec>, - inputs_attrs: OrtTensorAttr, - outputs_attrs: OrtTensorAttr, - profile: bool, - num_dry_run: usize, - model_proto: onnx::ModelProto, - params: usize, - wbmems: usize, - ts: Ts, -} - -impl OrtEngine { - pub fn new(config: &Options) -> Result { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-new"); - let _guard = span.enter(); - - // onnx graph - let model_proto = Self::load_onnx(&config.onnx_path)?; - let graph = match &model_proto.graph { - Some(graph) => graph, - None => anyhow::bail!("No graph found in this proto. Failed to parse ONNX model."), - }; - - // model params & mems - let byte_alignment = 16; // 16 for simd; 8 for most - let mut params: usize = 0; - let mut wbmems: usize = 0; - let mut initializer_names: HashSet<&str> = HashSet::new(); - for tensor_proto in graph.initializer.iter() { - initializer_names.insert(&tensor_proto.name); - let param = tensor_proto.dims.iter().product::() as usize; - params += param; - - // mems - let param = Ops::make_divisible(param, byte_alignment); - let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize); - let wbmem = param * n; - wbmems += wbmem; - } - - // inputs & outputs - let inputs_attrs = Self::io_from_onnx_value_info(&initializer_names, &graph.input)?; - let outputs_attrs = Self::io_from_onnx_value_info(&initializer_names, &graph.output)?; - let inputs_minoptmax = - Self::build_inputs_minoptmax(&inputs_attrs, &config.iiixs, config.batch_size)?; - - // build - ort::init().commit()?; - let mut builder = Session::builder()?; - let mut device = config.device.to_owned(); - match device { - Device::Trt(device_id) => { - Self::build_trt( - &inputs_attrs.names, - &inputs_minoptmax, - &mut builder, - device_id, - config.trt_int8_enable, - 
config.trt_fp16_enable, - config.trt_engine_cache_enable, - )?; - } - Device::Cuda(device_id) => { - Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| { - tracing::warn!("{err}, Using cpu"); - device = Device::Cpu(0); - }) - } - Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| { - tracing::warn!("{err}, Using cpu"); - device = Device::Cpu(0); - }), - Device::Cpu(_) => { - Self::build_cpu(&mut builder)?; - } - _ => todo!(), - } - - let session = builder - .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)? - .commit_from_file(&config.onnx_path)?; - - // summary - tracing::info!( - "{CHECK_MARK} Backend: ONNXRuntime | Opset: {} | Device: {:?} | Params: {}", - model_proto.opset_import[0].version, - device, - human_bytes(params as f64), - ); - - Ok(Self { - name: config.onnx_path.to_owned(), - session, - device, - inputs_minoptmax, - inputs_attrs, - outputs_attrs, - profile: config.profile, - num_dry_run: config.num_dry_run, - model_proto, - params, - wbmems, - ts: Ts::default(), - }) - } - - fn build_trt( - names: &[String], - inputs_minoptmax: &[Vec], - builder: &mut SessionBuilder, - device_id: usize, - int8_enable: bool, - fp16_enable: bool, - engine_cache_enable: bool, - ) -> Result<()> { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-build_trt"); - let _guard = span.enter(); - - // auto generate shapes - let mut spec_min = String::new(); - let mut spec_opt = String::new(); - let mut spec_max = String::new(); - for (i, name) in names.iter().enumerate() { - if i != 0 { - spec_min.push(','); - spec_opt.push(','); - spec_max.push(','); - } - let mut s_min = format!("{}:", name); - let mut s_opt = format!("{}:", name); - let mut s_max = format!("{}:", name); - for d in inputs_minoptmax[i].iter() { - let min_ = &format!("{}x", d.min()); - let opt_ = &format!("{}x", d.opt()); - let max_ = &format!("{}x", d.max()); - s_min += min_; - s_opt += opt_; - s_max += max_; - } - s_min.pop(); - 
s_opt.pop(); - s_max.pop(); - spec_min += &s_min; - spec_opt += &s_opt; - spec_max += &s_max; - } - let p = Dir::Cache.path_with_subs(&["trt-cache"])?; - let trt = TensorRTExecutionProvider::default() - .with_device_id(device_id as i32) - .with_int8(int8_enable) - .with_fp16(fp16_enable) - .with_engine_cache(engine_cache_enable) - .with_engine_cache_path(p.to_str().unwrap()) - .with_timing_cache(false) - .with_profile_min_shapes(spec_min) - .with_profile_opt_shapes(spec_opt) - .with_profile_max_shapes(spec_max); - if trt.is_available()? && trt.register(builder).is_ok() { - tracing::info!("🐢 Initial model serialization with TensorRT may require a wait...\n"); - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} TensorRT initialization failed") - } - } - - fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> { - let ep = ort::execution_providers::CUDAExecutionProvider::default() - .with_device_id(device_id as i32); - if ep.is_available()? && ep.register(builder).is_ok() { - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} CUDA initialization failed") - } - } - - fn build_coreml(builder: &mut SessionBuilder) -> Result<()> { - let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only(); - if ep.is_available()? && ep.register(builder).is_ok() { - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} CoreML initialization failed") - } - } - - fn build_cpu(builder: &mut SessionBuilder) -> Result<()> { - let ep = ort::execution_providers::CPUExecutionProvider::default(); - if ep.is_available()? 
&& ep.register(builder).is_ok() { - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} CPU initialization failed") - } - } - - pub fn dry_run(&mut self) -> Result<()> { - if self.num_dry_run > 0 { - // pb - let name = std::path::Path::new(&self.name); - let pb = build_progress_bar( - self.num_dry_run as u64, - " DryRun", - Some( - name.file_name() - .and_then(|x| x.to_str()) - .unwrap_or_default(), - ), - crate::PROGRESS_BAR_STYLE_CYAN_2, - )?; - - // dummy inputs - let mut xs = Vec::new(); - for i in self.inputs_minoptmax.iter() { - let mut x: Vec = Vec::new(); - for i_ in i.iter() { - x.push(i_.opt()); - } - let x: Array = Array::ones(x).into_dyn(); - xs.push(X::from(x)); - } - let xs = Xs::from(xs); - - // run - for _ in 0..self.num_dry_run { - pb.inc(1); - self.run(xs.clone())?; - } - self.ts.clear(); - - // update - let name = std::path::Path::new(&self.name); - pb.set_message(format!( - "{} on {:?}", - name.file_name() - .and_then(|x| x.to_str()) - .unwrap_or_default(), - self.device, - )); - pb.set_style(indicatif::ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_FINISH, - )?); - pb.finish(); - } - Ok(()) - } - - pub fn run(&mut self, xs: Xs) -> Result { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-run"); - let _guard = span.enter(); - - // inputs dtype alignment - let mut xs_ = Vec::new(); - let t_pre = std::time::Instant::now(); - for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) { - let x_ = match &idtype { - TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(), - TensorElementType::Float16 => { - ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() - } - TensorElementType::Int32 => { - ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() - } - TensorElementType::Int64 => { - ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() - } - TensorElementType::Uint8 => { - ort::value::Value::from_array(x.mapv(|x_| x_ as 
u8).view())?.into_dyn() - } - TensorElementType::Int8 => { - ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn() - } - TensorElementType::Bool => { - ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn() - } - _ => todo!(), - }; - xs_.push(Into::>::into(x_)); - } - let t_pre = t_pre.elapsed(); - self.ts.add_or_push(0, t_pre); - - // inference - let t_run = std::time::Instant::now(); - let outputs = self.session.run(&xs_[..])?; - - let t_run = t_run.elapsed(); - self.ts.add_or_push(1, t_run); - - // oputput - let mut ys = Xs::new(); - let t_post = std::time::Instant::now(); - for (dtype, name) in self - .outputs_attrs - .dtypes - .iter() - .zip(self.outputs_attrs.names.iter()) - { - let y = &outputs[name.as_str()]; - - let y_ = match &dtype { - TensorElementType::Float32 => match y.try_extract_tensor::() { - Err(err) => { - tracing::error!("Error: {:?}. Output name: {:?}", err, name); - Array::zeros(0).into_dyn() - } - Ok(x) => x.view().into_owned(), - }, - TensorElementType::Float16 => match y.try_extract_tensor::() { - Err(err) => { - tracing::error!("Error: {:?}. Output name: {:?}", err, name); - Array::zeros(0).into_dyn() - } - Ok(x) => x.view().mapv(f16::to_f32).into_owned(), - }, - TensorElementType::Int64 => match y.try_extract_tensor::() { - Err(err) => { - tracing::error!("Error: {:?}. 
Output name: {:?}", err, name); - Array::zeros(0).into_dyn() - } - Ok(x) => x.view().to_owned().mapv(|x| x as f32).into_owned(), - }, - _ => todo!(), - }; - - ys.push_kv(name.as_str(), X::from(y_))?; - } - let t_post = t_post.elapsed(); - self.ts.add_or_push(2, t_post); - - if self.profile { - let len = 10usize; - let n = 4usize; - tracing::info!( - "[Profile] {:>len$.n$?} ({:>len$.n$?} avg) [alignment: {:>len$.n$?} ({:>len$.n$?} avg) | inference: {:>len$.n$?} ({:>len$.n$?} avg) | to_f32: {:>len$.n$?} ({:>len$.n$?} avg)]", - t_pre + t_run + t_post, - self.ts.avg(), - t_pre, - self.ts.avgi(0), - t_run, - self.ts.avgi(1), - t_post, - self.ts.avgi(2), - ); - } - Ok(ys) - } - - fn build_inputs_minoptmax( - inputs_attrs: &OrtTensorAttr, - iiixs: &[Iiix], - batch_size: usize, - ) -> Result>> { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-build_inputs_minoptmax"); - let _guard = span.enter(); - - // init - let mut ys: Vec> = inputs_attrs - .dimss - .iter() - .map(|dims| dims.iter().map(|&x| MinOptMax::from(x)).collect()) - .collect(); - - // update from customized - for iiix in iiixs.iter() { - if let Some(x) = inputs_attrs - .dimss - .get(iiix.i) - .and_then(|dims| dims.get(iiix.ii)) - { - // dynamic - if *x == 0 { - ys[iiix.i][iiix.ii] = iiix.x.clone(); - } - } else { - anyhow::bail!( - "Cannot retrieve the {}-th dimension of the {}-th input.", - iiix.ii, - iiix.i, - ); - } - } - - // deal with the dynamic axis - ys.iter_mut().enumerate().for_each(|(i, xs)| { - xs.iter_mut().enumerate().for_each(|(ii, x)| { - if x.is_dyn() { - let n = if ii == 0 { batch_size } else { 1 }; - let y = MinOptMax::from(n); - tracing::warn!( - "Using dynamic shapes in inputs without specifying it: the {}-th input, the {}-th dimension. \ - Using {:?} by default. 
You should make it clear when using TensorRT.", - i + 1, ii + 1, y - ); - *x = y; - } - }); - }); - - Ok(ys) - } - - #[allow(dead_code)] - fn nbytes_from_onnx_dtype_id(x: usize) -> usize { - match x { - 7 | 11 | 13 => 8, // i64, f64, u64 - 1 | 6 | 12 => 4, // f32, i32, u32 - 10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16 - 2 | 3 | 9 => 1, // u8, i8, bool - 8 => 4, // string(1~4) - _ => todo!(), - } - } - - #[allow(dead_code)] - fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize { - match x { - ort::tensor::TensorElementType::Float64 - | ort::tensor::TensorElementType::Uint64 - | ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64 - ort::tensor::TensorElementType::Float32 - | ort::tensor::TensorElementType::Uint32 - | ort::tensor::TensorElementType::Int32 - | ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4) - ort::tensor::TensorElementType::Float16 - | ort::tensor::TensorElementType::Bfloat16 - | ort::tensor::TensorElementType::Int16 - | ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 - ort::tensor::TensorElementType::Uint8 - | ort::tensor::TensorElementType::Int8 - | ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool - } - } - - #[allow(dead_code)] - fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option { - match value { - 0 => None, - 1 => Some(ort::tensor::TensorElementType::Float32), - 2 => Some(ort::tensor::TensorElementType::Uint8), - 3 => Some(ort::tensor::TensorElementType::Int8), - 4 => Some(ort::tensor::TensorElementType::Uint16), - 5 => Some(ort::tensor::TensorElementType::Int16), - 6 => Some(ort::tensor::TensorElementType::Int32), - 7 => Some(ort::tensor::TensorElementType::Int64), - 8 => Some(ort::tensor::TensorElementType::String), - 9 => Some(ort::tensor::TensorElementType::Bool), - 10 => Some(ort::tensor::TensorElementType::Float16), - 11 => Some(ort::tensor::TensorElementType::Float64), - 12 => Some(ort::tensor::TensorElementType::Uint32), - 13 => 
Some(ort::tensor::TensorElementType::Uint64), - 14 => None, // COMPLEX64 - 15 => None, // COMPLEX128 - 16 => Some(ort::tensor::TensorElementType::Bfloat16), - _ => None, - } - } - - fn io_from_onnx_value_info( - initializer_names: &HashSet<&str>, - value_info: &[onnx::ValueInfoProto], - ) -> Result { - let mut dimss: Vec> = Vec::new(); - let mut dtypes: Vec = Vec::new(); - let mut names: Vec = Vec::new(); - for v in value_info.iter() { - if initializer_names.contains(v.name.as_str()) { - continue; - } - names.push(v.name.to_string()); - let dtype = match &v.r#type { - Some(dtype) => dtype, - None => continue, - }; - let dtype = match &dtype.value { - Some(dtype) => dtype, - None => continue, - }; - let tensor = match dtype { - onnx::type_proto::Value::TensorType(tensor) => tensor, - _ => continue, - }; - let tensor_type = tensor.elem_type; - let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) { - Some(dtype) => dtype, - None => continue, - }; - dtypes.push(tensor_type); - - let shapes = match &tensor.shape { - Some(shapes) => shapes, - None => continue, - }; - let mut shape_: Vec = Vec::new(); - for shape in shapes.dim.iter() { - match &shape.value { - None => continue, - Some(value) => match value { - onnx::tensor_shape_proto::dimension::Value::DimValue(x) => { - shape_.push(*x as _); - } - onnx::tensor_shape_proto::dimension::Value::DimParam(_) => { - shape_.push(0); - } - }, - } - } - dimss.push(shape_); - } - Ok(OrtTensorAttr { - dimss, - dtypes, - names, - }) - } - - pub fn load_onnx>(p: P) -> Result { - let f = std::fs::read(p)?; - Ok(onnx::ModelProto::decode(f.as_slice())?) 
- } - - pub fn oshapes(&self) -> &Vec> { - &self.outputs_attrs.dimss - } - - pub fn odimss(&self) -> &Vec> { - &self.outputs_attrs.dimss - } - - pub fn onames(&self) -> &Vec { - &self.outputs_attrs.names - } - - pub fn odtypes(&self) -> &Vec { - &self.outputs_attrs.dtypes - } - - pub fn ishapes(&self) -> &Vec> { - &self.inputs_attrs.dimss - } - - pub fn idimss(&self) -> &Vec> { - &self.inputs_attrs.dimss - } - - pub fn inames(&self) -> &Vec { - &self.inputs_attrs.names - } - - pub fn idtypes(&self) -> &Vec { - &self.inputs_attrs.dtypes - } - - pub fn device(&self) -> &Device { - &self.device - } - - pub fn inputs_minoptmax(&self) -> &Vec> { - &self.inputs_minoptmax - } - - pub fn batch(&self) -> &MinOptMax { - &self.inputs_minoptmax[0][0] - } - - pub fn try_height(&self) -> Option<&MinOptMax> { - self.inputs_minoptmax.first().and_then(|x| x.get(2)) - } - - pub fn try_width(&self) -> Option<&MinOptMax> { - self.inputs_minoptmax.first().and_then(|x| x.get(3)) - } - - pub fn height(&self) -> &MinOptMax { - &self.inputs_minoptmax[0][2] - } - - pub fn width(&self) -> &MinOptMax { - &self.inputs_minoptmax[0][3] - } - - pub fn is_batch_dyn(&self) -> bool { - self.ishapes()[0][0] == 0 - } - - pub fn try_fetch(&self, key: &str) -> Option { - match self.session.metadata() { - Err(_) => None, - Ok(metadata) => metadata.custom(key).unwrap_or_default(), - } - } - - pub fn session(&self) -> &Session { - &self.session - } - - pub fn ir_version(&self) -> usize { - self.model_proto.ir_version as usize - } - - pub fn opset_version(&self) -> usize { - self.model_proto.opset_import[0].version as usize - } - - pub fn producer_name(&self) -> String { - self.model_proto.producer_name.to_string() - } - - pub fn producer_version(&self) -> String { - self.model_proto.producer_version.to_string() - } - - pub fn model_version(&self) -> usize { - self.model_proto.model_version as usize - } - - pub fn parameters(&self) -> usize { - self.params - } - - pub fn memory_weights(&self) -> usize { - 
self.wbmems - } - - pub fn ts(&self) -> &Ts { - &self.ts - } -} diff --git a/src/core/tokenizer_stream.rs b/src/core/tokenizer_stream.rs deleted file mode 100644 index 495d69a..0000000 --- a/src/core/tokenizer_stream.rs +++ /dev/null @@ -1,87 +0,0 @@ -// TODO: refactor -use anyhow::Result; - -/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a -/// streaming way rather than having to wait for the full decoding. -#[derive(Debug)] -pub struct TokenizerStream { - tokenizer: tokenizers::Tokenizer, - tokens: Vec, - prev_index: usize, - current_index: usize, -} - -impl TokenizerStream { - pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { - Self { - tokenizer, - tokens: Vec::new(), - prev_index: 0, - current_index: 0, - } - } - - pub fn into_inner(self) -> tokenizers::Tokenizer { - self.tokenizer - } - - fn decode(&self, tokens: &[u32]) -> Result { - match self.tokenizer.decode(tokens, true) { - Ok(str) => Ok(str), - Err(err) => anyhow::bail!("cannot decode: {err}"), - } - } - - pub fn next_token(&mut self, token: u32) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? - }; - self.tokens.push(token); - let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() { - let text = text.split_at(prev_text.len()); - self.prev_index = self.current_index; - self.current_index = self.tokens.len(); - Ok(Some(text.1.to_string())) - } else { - Ok(None) - } - } - - pub fn decode_rest(&self) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? 
- }; - let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() { - let text = text.split_at(prev_text.len()); - Ok(Some(text.1.to_string())) - } else { - Ok(None) - } - } - - pub fn decode_all(&self) -> Result { - self.decode(&self.tokens) - } - - pub fn get_token(&self, token_s: &str) -> Option { - self.tokenizer.get_vocab(true).get(token_s).copied() - } - - pub fn tokenizer(&self) -> &tokenizers::Tokenizer { - &self.tokenizer - } - - pub fn clear(&mut self) { - self.tokens.clear(); - self.prev_index = 0; - self.current_index = 0; - } -} diff --git a/src/core/ts.rs b/src/core/ts.rs deleted file mode 100644 index dc65ae1..0000000 --- a/src/core/ts.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::time::Duration; - -#[derive(Debug, Default)] -pub struct Ts { - n: usize, - ts: Vec, -} - -impl Ts { - pub fn total(&self) -> Duration { - self.ts.iter().sum::() - } - - pub fn n(&self) -> usize { - self.n / self.ts.len() - } - - pub fn avg(&self) -> Duration { - self.total() / self.n() as u32 - } - - pub fn avgi(&self, i: usize) -> Duration { - if i >= self.ts.len() { - panic!("Index out of bound"); - } - self.ts[i] / self.n() as u32 - } - - pub fn ts(&self) -> &Vec { - &self.ts - } - - pub fn add_or_push(&mut self, i: usize, x: Duration) { - match self.ts.get_mut(i) { - Some(elem) => *elem += x, - None => { - if i >= self.ts.len() { - self.ts.push(x) - } - } - } - self.n += 1; - } - - pub fn clear(&mut self) { - self.n = Default::default(); - self.ts = Default::default(); - } -} diff --git a/src/core/vision.rs b/src/core/vision.rs deleted file mode 100644 index f78bc18..0000000 --- a/src/core/vision.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::{Options, Xs, Y}; - -pub trait Vision: Sized { - type Input; // DynamicImage - - /// Creates a new instance of the model with the given options. - fn new(options: Options) -> anyhow::Result; - - /// Preprocesses the input data. 
- fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result; - - /// Executes the model on the preprocessed data. - fn inference(&mut self, xs: Xs) -> anyhow::Result; - - /// Postprocesses the model's output. - fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> anyhow::Result>; - - /// Executes the full pipeline. - fn run(&mut self, xs: &[Self::Input]) -> anyhow::Result> { - let ys = self.preprocess(xs)?; - let ys = self.inference(ys)?; - let ys = self.postprocess(ys, xs)?; - Ok(ys) - } - - /// Executes the full pipeline. - fn forward(&mut self, xs: &[Self::Input], profile: bool) -> anyhow::Result> { - let span = tracing::span!(tracing::Level::INFO, "Vision-forward"); - let _guard = span.enter(); - - let t_pre = std::time::Instant::now(); - let ys = self.preprocess(xs)?; - let t_pre = t_pre.elapsed(); - - let t_exe = std::time::Instant::now(); - let ys = self.inference(ys)?; - let t_exe = t_exe.elapsed(); - - let t_post = std::time::Instant::now(); - let ys = self.postprocess(ys, xs)?; - let t_post = t_post.elapsed(); - - if profile { - tracing::info!( - "> Preprocess: {t_pre:?} | Execution: {t_exe:?} | Postprocess: {t_post:?}" - ); - } - - Ok(ys) - } -} diff --git a/src/lib.rs b/src/lib.rs index ce9d586..762f800 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,258 +1,49 @@ -//! **usls** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including: +//! **usls** is a Rust library integrated with **ONNXRuntime**, offering a suite of advanced models for **Computer Vision** and **Vision-Language** tasks, including: //! -//! - **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10) +//! 
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLO11](https://github.com/ultralytics/ultralytics) //! - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) -//! - **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569) -//! - **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242) +//! - **Vision Models**: [RT-DETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro), [FastViT](https://github.com/apple/ml-fastvit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MobileOne](https://github.com/apple/ml-mobileone) +//! 
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242) +//! - **OCR Models**: [DB](https://arxiv.org/abs/1911.08947), [FAST](https://github.com/czczup/FAST), [SVTR](https://arxiv.org/abs/2205.00159), [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html), [TrOCR](https://huggingface.co/microsoft/trocr-base-printed), [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) +//! - **And more...** //! -//! # Examples +//! ## ⛳️ Cargo Features //! -//! Refer to [All Runnable Demos](https://github.com/jamjamjon/usls/tree/main/examples) +//! By default, **none of the following features are enabled**. You can enable them as needed: //! -//! # Quick Start +//! - **`auto`**: Automatically downloads prebuilt ONNXRuntime binaries from Pyke’s CDN for supported platforms. //! -//! The following demo shows how to build a `YOLO` with [`Options`], load `image(s)`, `video` and `stream` with [`DataLoader`], and annotate the model's inference results with [`Annotator`]. +//! - If disabled, you'll need to [compile `ONNXRuntime` from source](https://github.com/microsoft/onnxruntime) or [download a precompiled package](https://github.com/microsoft/onnxruntime/releases), and then [link it manually](https://ort.pyke.io/setup/linking). //! -//! ```ignore -//! use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; +//!
+//! 👉 For Linux or macOS Users //! -//! fn main() -> anyhow::Result<()> { -//! // Build model with Options -//! let options = Options::new() -//! .with_trt(0) -//! .with_model("yolo/v8-m-dyn.onnx")? -//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR -//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb -//! .with_i00((1, 1, 4).into()) -//! .with_i02((0, 640, 640).into()) -//! .with_i03((0, 640, 640).into()) -//! .with_confs(&[0.2]); -//! let mut model = YOLO::new(options)?; +//! - Download from the [Releases page](https://github.com/microsoft/onnxruntime/releases). +//! - Set up the library path by exporting the `ORT_DYLIB_PATH` environment variable: +//! ```shell +//! export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so.1.20.1 +//! ``` //! -//! // Build DataLoader to load image(s), video, stream -//! let dl = DataLoader::new( -//! "./assets/bus.jpg", // local image -//! // "images/bus.jpg", // remote image -//! // "../set-negs", // local images (from folder) -//! // "../hall.mp4", // local video -//! // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video -//! // "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream -//! )? -//! .with_batch(3) // iterate with batch_size = 3 -//! .build()?; +//!
+//! - **`ffmpeg`**: Adds support for video streams, real-time frame visualization, and video export. //! -//! // Build annotator -//! let annotator = Annotator::new().with_saveout("YOLO-Demo"); +//! - Powered by [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb). For any issues related to `ffmpeg` features, please refer to the issues of these two crates. +//! - **`cuda`**: Enables the NVIDIA CUDA provider. +//! - **`trt`**: Enables the NVIDIA TensorRT provider. +//! - **`mps`**: Enables the Apple CoreML provider. //! -//! // Run and Annotate images -//! for (xs, _) in dl { -//! let ys = model.forward(&xs, false)?; -//! annotator.annotate(&xs, &ys); -//! } +//! ## 🎈 Example //! -//! Ok(()) -//! } +//! ```Shell +//! cargo run -r -F cuda --example svtr -- --device cuda //! ``` //! +//! All examples are located in the [examples](https://github.com/jamjamjon/usls/tree/main/examples) directory. -//! # What's More -//! -//! This guide covers the process of using provided models for inference, including how to build a model, load data, annotate results, and retrieve the outputs. Click the sections below to expand for detailed instructions. -//! -//!
-//! Build the Model -//! -//! To build a model, you can use the provided [models] with [Options]: -//! -//! ```ignore -//! use usls::{models::YOLO, Annotator, DataLoader, Options, Vision}; -//! -//! let options = Options::default() -//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR -//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb -//! .with_model("xxxx.onnx")?; -//! let mut model = YOLO::new(options)?; -//! ``` -//! -//! **And there're many options provided by [Options]** -//! -//! - **Choose Execution Provider:** -//! Select `CUDA` (default), `TensorRT`, or `CoreML`: -//! -//! ```ignore -//! let options = Options::default() -//! .with_cuda(0) -//! // .with_trt(0) -//! // .with_coreml(0) -//! // .with_cpu(); -//! ``` -//! -//! - **Dynamic Input Shapes:** -//! Specify dynamic shapes with [MinOptMax]: -//! -//! ```ignore -//! let options = Options::default() -//! .with_i00((1, 2, 4).into()) // batch(min=1, opt=2, max=4) -//! .with_i02((416, 640, 800).into()) // height(min=416, opt=640, max=800) -//! .with_i03((416, 640, 800).into()); // width(min=416, opt=640, max=800) -//! ``` -//! -//! - **Set Confidence Thresholds:** -//! Adjust thresholds for each category: -//! -//! ```ignore -//! let options = Options::default() -//! .with_confs(&[0.4, 0.15]); // class_0: 0.4, others: 0.15 -//! ``` -//! -//! - **Set Class Names:** -//! Provide class names if needed: -//! -//! ```ignore -//! let options = Options::default() -//! .with_names(&COCO_CLASS_NAMES_80); -//! ``` -//! -//! **More options are detailed in the [Options] documentation.** -//! -//! -//!
-//! -//!
-//! Load Images, Video and Stream -//! -//! - **Load a Single Image** -//! Use [DataLoader::try_read] to load an image from a local file or remote source: -//! -//! ```ignore -//! let x = DataLoader::try_read("./assets/bus.jpg")?; // from local -//! let x = DataLoader::try_read("images/bus.jpg")?; // from remote -//! ``` -//! -//! Alternatively, use [image::ImageReader] directly: -//! -//! ```ignore -//! let x = image::ImageReader::open("myimage.png")?.decode()?; -//! ``` -//! -//! - **Load Multiple Images, Videos, or Streams** -//! Create a [DataLoader] instance for batch processing: -//! -//! ```ignore -//! let dl = DataLoader::new( -//! "./assets/bus.jpg", // local image -//! // "images/bus.jpg", // remote image -//! // "../set-negs", // local images (from folder) -//! // "../hall.mp4", // local video -//! // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video -//! // "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream -//! )? -//! .with_batch(3) // iterate with batch_size = 3 -//! .build()?; -//! -//! // Iterate through the data -//! for (xs, _) in dl {} -//! ``` -//! -//! - **Convert Images to Video** -//! Use [DataLoader::is2v] to create a video from a sequence of images: -//! -//! ```ignore -//! let fps = 24; -//! let image_folder = "runs/YOLO-DataLoader"; -//! let saveout = ["runs", "is2v"]; -//! DataLoader::is2v(image_folder, &saveout, fps)?; -//! ``` -//! -//!
-//! -//!
-//! Annotate Inference Results -//! -//! - **Create an Annotator Instance** -//! -//! ```ignore -//! let annotator = Annotator::default(); -//! ``` -//! -//! - **Set Saveout Name:** -//! -//! ```ignore -//! let annotator = Annotator::default() -//! .with_saveout("YOLOs"); -//! ``` -//! -//! - **Set Bounding Box Line Width:** -//! -//! ```ignore -//! let annotator = Annotator::default() -//! .with_bboxes_thickness(4); -//! ``` -//! -//! - **Disable Mask Plotting** -//! -//! ```ignore -//! let annotator = Annotator::default() -//! .without_masks(true); -//! ``` -//! -//! - **Perform Inference and nnotate the results** -//! -//! ```ignore -//! for (xs, _paths) in dl { -//! let ys = model.run(&xs)?; -//! annotator.annotate(&xs, &ys); -//! } -//! ``` -//! -//! More options are detailed in the [Annotator] documentation. -//! -//!
-//! -//!
-//! Retrieve Model's Inference Results -//! -//! Retrieve the inference outputs, which are saved in a [`Vec`]: -//! -//! - **Get Detection Bounding Boxes** -//! -//! ```ignore -//! let ys = model.run(&xs)?; -//! for y in ys { -//! // bboxes -//! if let Some(bboxes) = y.bboxes() { -//! for bbox in bboxes { -//! println!( -//! "Bbox: {}, {}, {}, {}, {}, {}", -//! bbox.xmin(), -//! bbox.ymin(), -//! bbox.xmax(), -//! bbox.ymax(), -//! bbox.confidence(), -//! bbox.id(), -//! ); -//! } -//! } -//! } -//! ``` -//! -//!
-//! -//!
-//! Custom Model Implementation -//! -//! You can also implement your own model using [OrtEngine] and [Options]. [OrtEngine] supports ONNX model loading, metadata parsing, dry_run, inference, and other functions, with execution providers such as CUDA, TensorRT, CoreML, etc. -//! -//! For more details, refer to the [Demo: Depth-Anything](https://github.com/jamjamjon/usls/blob/main/src/models/depth_anything.rs). -//! -//!
- -mod core; +mod misc; pub mod models; -mod utils; -mod ys; +mod xy; -pub use core::*; +pub use misc::*; pub use models::*; -pub use utils::*; -pub use ys::*; +pub use xy::*; diff --git a/src/core/annotator.rs b/src/misc/annotator.rs similarity index 89% rename from src/core/annotator.rs rename to src/misc/annotator.rs index 694cb7e..7d34749 100644 --- a/src/core/annotator.rs +++ b/src/misc/annotator.rs @@ -1,16 +1,15 @@ -use crate::{ - colormap256, string_now, Bbox, Dir, Hub, Keypoint, Mask, Mbr, Polygon, Prob, CHECK_MARK, - CROSS_MARK, Y, -}; use ab_glyph::{FontArc, PxScale}; use anyhow::Result; use image::{DynamicImage, GenericImage, Rgba, RgbaImage}; use imageproc::map::map_colors; +use crate::{ + string_now, Bbox, Color, ColorMap256, Dir, Hub, Keypoint, Mask, Mbr, Polygon, Prob, Y, +}; + /// Annotator for struct `Y` #[derive(Clone)] pub struct Annotator { - // TODO: Add lifetime font: FontArc, _scale: f32, // Cope with ab_glyph & imageproc=0.24.0 scale_dy: f32, @@ -18,6 +17,7 @@ pub struct Annotator { saveout: Option, saveout_subs: Vec, decimal_places: usize, + palette: Vec, // About mbrs without_mbrs: bool, @@ -57,7 +57,7 @@ pub struct Annotator { // About masks without_masks: bool, - colormap: Option<[[u8; 3]; 256]>, + colormap: Option<[Color; 256]>, // About probs probs_topk: usize, @@ -73,6 +73,7 @@ impl Default for Annotator { _scale: 6.666667, scale_dy: 28., polygons_alpha: 179, + palette: Color::palette_base_20(), saveout: None, saveout_subs: vec![], saveout_base: String::from("runs"), @@ -272,22 +273,8 @@ impl Annotator { } pub fn with_colormap(mut self, x: &str) -> Self { - let x = match x { - "turbo" | "Turbo" | "TURBO" => colormap256::TURBO, - "inferno" | "Inferno" | "INFERNO" => colormap256::INFERNO, - "plasma" | "Plasma" | "PLASMA" => colormap256::PLASMA, - "viridis" | "Viridis" | "VIRIDIS" => colormap256::VIRIDIS, - "magma" | "Magma" | "MAGMA" => colormap256::MAGMA, - "bentcoolwarm" | "BentCoolWarm" | "BENTCOOLWARM" => colormap256::BENTCOOLWARM, 
- "blackbody" | "BlackBody" | "BLACKBODY" => colormap256::BLACKBODY, - "extendedkindLmann" | "ExtendedKindLmann" | "EXTENDEDKINDLMANN" => { - colormap256::EXTENDEDKINDLMANN - } - "kindlmann" | "KindLmann" | "KINDLMANN" => colormap256::KINDLMANN, - "smoothcoolwarm" | "SmoothCoolWarm" | "SMOOTHCOOLWARM" => colormap256::SMOOTHCOOLWARM, - _ => todo!(), - }; - self.colormap = Some(x); + let x = ColorMap256::from(x); + self.colormap = Some(x.data()); self } @@ -355,7 +342,7 @@ impl Annotator { } // mkdir even no filename specified - Dir::Currnet.raw_path_with_subs(&subs) + Dir::Current.raw_path_with_subs(&subs) } /// Annotate images, save, and no return @@ -365,9 +352,6 @@ impl Annotator { /// Plot images and return plotted images pub fn plot(&self, imgs: &[DynamicImage], ys: &[Y], save: bool) -> Result> { - let span = tracing::span!(tracing::Level::INFO, "Annotator-plot"); - let _guard = span.enter(); - let mut vs: Vec = Vec::new(); // annotate @@ -418,9 +402,9 @@ impl Annotator { if save { let saveout = self.saveout()?.join(format!("{}.png", string_now("-"))); match img_rgba.save(&saveout) { - Err(err) => tracing::error!("{} Saving failed: {:?}", CROSS_MARK, err), + Err(err) => anyhow::bail!("Failed to save annotated image: {:?}", err), Ok(_) => { - tracing::info!("{} Annotated image saved to: {:?}", CHECK_MARK, saveout); + println!("Annotated image saved to: {:?}", saveout); } } } @@ -428,6 +412,7 @@ impl Annotator { // RgbaImage -> DynamicImage vs.push(image::DynamicImage::from(img_rgba)); } + Ok(vs) } @@ -673,7 +658,7 @@ impl Annotator { let luma = if let Some(colormap) = self.colormap { let luma = map_colors(mask.mask(), |p| { let x = p[0]; - image::Rgb(colormap[x as usize]) + image::Rgb(colormap[x as usize].rgb().into()) }); image::DynamicImage::from(luma) } else { @@ -774,42 +759,15 @@ impl Annotator { /// Load custom font fn load_font(path: Option<&str>) -> Result { let path_font = match path { - None => Hub::new()?.fetch("fonts/Arial.ttf")?.commit()?, + None => 
Hub::default().try_fetch("fonts/Arial.ttf")?, Some(p) => p.into(), }; - let buffer = std::fs::read(path_font)?; - Ok(FontArc::try_from_vec(buffer.to_owned())?) + let buf = std::fs::read(path_font)?; + Ok(FontArc::try_from_vec(buf.to_owned())?) } - /// Pick color from pallette + /// Color palette pub fn get_color(&self, n: usize) -> (u8, u8, u8, u8) { - Self::color_palette()[n % Self::color_palette().len()] - } - - /// Color pallette - fn color_palette() -> [(u8, u8, u8, u8); 20] { - // TODO: more colors - [ - (0, 255, 127, 255), // spring green - (255, 105, 180, 255), // hot pink - (255, 99, 71, 255), // tomato - (255, 215, 0, 255), // glod - (188, 143, 143, 255), // rosy brown - (0, 191, 255, 255), // deep sky blue - (143, 188, 143, 255), // dark sea green - (238, 130, 238, 255), // violet - (154, 205, 50, 255), // yellow green - (205, 133, 63, 255), // peru - (30, 144, 255, 255), // dodger blue - (112, 128, 144, 255), // slate gray - (127, 255, 212, 255), // aqua marine - (51, 153, 255, 255), // blue - (0, 255, 255, 255), // cyan - (138, 43, 226, 255), // blue violet - (165, 42, 42, 255), // brown - (216, 191, 216, 255), // thistle - (240, 255, 255, 255), // azure - (95, 158, 160, 255), // cadet blue - ] + self.palette[n % self.palette.len()].rgba() } } diff --git a/src/misc/color.rs b/src/misc/color.rs new file mode 100644 index 0000000..aef39d8 --- /dev/null +++ b/src/misc/color.rs @@ -0,0 +1,171 @@ +use anyhow::Result; +use rand::Rng; + +/// Color: 0xRRGGBBAA +#[derive(Copy, Clone)] +pub struct Color(u32); + +impl std::fmt::Debug for Color { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Color") + .field("RGBA", &self.rgba()) + .field("HEX", &self.hex()) + .finish() + } +} + +impl std::fmt::Display for Color { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.hex()) + } +} + +impl From for Color { + fn from(x: u32) -> Self { + Self(x) + } +} + +impl From<(u8, u8, u8)> for 
Color { + fn from((r, g, b): (u8, u8, u8)) -> Self { + Self::from_rgba(r, g, b, 0xff) + } +} + +impl From<[u8; 3]> for Color { + fn from(c: [u8; 3]) -> Self { + Self::from((c[0], c[1], c[2])) + } +} + +impl From<(u8, u8, u8, u8)> for Color { + fn from((r, g, b, a): (u8, u8, u8, u8)) -> Self { + Self::from_rgba(r, g, b, a) + } +} + +impl From<[u8; 4]> for Color { + fn from(c: [u8; 4]) -> Self { + Self::from((c[0], c[1], c[2], c[3])) + } +} + +impl TryFrom<&str> for Color { + type Error = &'static str; + + fn try_from(x: &str) -> Result { + let hex = x.trim_start_matches('#'); + let hex = match hex.len() { + 6 => format!("{}ff", hex), + 8 => hex.to_string(), + _ => return Err("Failed to convert `Color` from str: invalid length"), + }; + + u32::from_str_radix(&hex, 16) + .map(Self) + .map_err(|_| "Failed to convert `Color` from str: invalid hex") + } +} + +impl Color { + const fn from_rgba(r: u8, g: u8, b: u8, a: u8) -> Self { + Self(((r as u32) << 24) | ((g as u32) << 16) | ((b as u32) << 8) | (a as u32)) + } + + pub fn rgba(&self) -> (u8, u8, u8, u8) { + let r = ((self.0 >> 24) & 0xff) as u8; + let g = ((self.0 >> 16) & 0xff) as u8; + let b = ((self.0 >> 8) & 0xff) as u8; + let a = (self.0 & 0xff) as u8; + (r, g, b, a) + } + + pub fn rgb(&self) -> (u8, u8, u8) { + let (r, g, b, _) = self.rgba(); + (r, g, b) + } + + pub fn bgr(&self) -> (u8, u8, u8) { + let (r, g, b) = self.rgb(); + (b, g, r) + } + + pub fn hex(&self) -> String { + format!("#{:08x}", self.0) + } + + pub fn create_palette + Copy>(xs: &[A]) -> Vec { + xs.iter().copied().map(Into::into).collect() + } + + pub fn try_create_palette + Copy>(xs: &[A]) -> Result> + where + >::Error: std::fmt::Debug, + { + xs.iter() + .copied() + .map(|x| { + x.try_into() + .map_err(|e| anyhow::anyhow!("Failed to convert: {:?}", e)) + }) + .collect() + } + + pub fn palette_rand(n: usize) -> Vec { + let mut rng = rand::thread_rng(); + let xs: Vec<(u8, u8, u8)> = (0..n) + .map(|_| { + ( + rng.gen_range(0..=255), + 
rng.gen_range(0..=255), + rng.gen_range(0..=255), + ) + }) + .collect(); + + Self::create_palette(&xs) + } + + pub fn palette_base_20() -> Vec { + Self::create_palette(&[ + 0x00ff7fff, // SpringGreen + 0xff69b4ff, // HotPink + 0xff6347ff, // Tomato + 0xffd700ff, // Gold + 0xbc8f8fff, // RosyBrown + 0x00bfffff, // DeepSkyBlue + 0x8fb88fff, // DarkSeaGreen + 0xee82eeff, // Violet + 0x9acd32ff, // YellowGreen + 0xcd853fff, // Peru + 0x1e90ffff, // DodgerBlue + 0xd74a49ff, // ? + 0x7fffd4ff, // AquaMarine + 0x3399ffff, // Blue2 + 0x00ffffff, // Cyan + 0x8a2befff, // BlueViolet + 0xa52a2aff, // Brown + 0xd8bfd8ff, // Thistle + 0xf0ffffff, // Azure + 0x609ea0ff, // CadetBlue + ]) + } + + pub fn palette_cotton_candy_5() -> Vec { + Self::try_create_palette(&["#ff595e", "#ffca3a", "#8ac926", "#1982c4", "#6a4c93"]) + .expect("Faild to create palette: Cotton Candy") + } + + pub fn palette_tropical_sunrise_5() -> Vec { + // https://colorkit.co/palette/e12729-f37324-f8cc1b-72b043-007f4e/ + Self::try_create_palette(&["#e12729", "#f37324", "#f8cc1b", "#72b043", "#007f4e"]) + .expect("Faild to create palette: Tropical Sunrise") + } + + pub fn palette_rainbow_10() -> Vec { + Self::create_palette(&[ + 0xff595eff, 0xff924cff, 0xffca3aff, 0xc5ca30ff, 0x8ac926ff, 0x52a675ff, 0x1982c4ff, + 0x4267acff, 0x6a4c93ff, 0xb5a6c9ff, + ]) + } +} diff --git a/src/misc/colormap256.rs b/src/misc/colormap256.rs new file mode 100644 index 0000000..6a64a1d --- /dev/null +++ b/src/misc/colormap256.rs @@ -0,0 +1,435 @@ +use crate::Color; + +pub enum ColorMap256 { + Turbo, + Inferno, + Plasma, + Viridis, + Magma, + BentCoolWarm, + BlackBody, + ExtendedKindLmann, + KindLmann, + SmoothCoolWarm, +} + +impl From<&str> for ColorMap256 { + fn from(s: &str) -> Self { + match s { + "turbo" | "Turbo" | "TURBO" => Self::Turbo, + "inferno" | "Inferno" | "INFERNO" => Self::Inferno, + "plasma" | "Plasma" | "PLASMA" => Self::Plasma, + "viridis" | "Viridis" | "VIRIDIS" => Self::Viridis, + "magma" | "Magma" | "MAGMA" => 
Self::Magma, + "bentcoolwarm" | "BentCoolWarm" | "BENTCOOLWARM" => Self::BentCoolWarm, + "blackbody" | "BlackBody" | "BLACKBODY" => Self::BlackBody, + "extendedkindLmann" | "ExtendedKindLmann" | "EXTENDEDKINDLMANN" => { + Self::ExtendedKindLmann + } + "kindlmann" | "KindLmann" | "KINDLMANN" => Self::KindLmann, + "smoothcoolwarm" | "SmoothCoolWarm" | "SMOOTHCOOLWARM" => Self::SmoothCoolWarm, + _ => todo!(), + } + } +} + +impl ColorMap256 { + pub fn data(&self) -> [Color; 256] { + let xs = match self { + Self::Turbo => [ + 0x30123bff, 0x321543ff, 0x33184aff, 0x341b51ff, 0x351e58ff, 0x36215fff, 0x372466ff, + 0x38276dff, 0x392a73ff, 0x3a2d79ff, 0x3b2f80ff, 0x3c3286ff, 0x3d358bff, 0x3e3891ff, + 0x3f3b97ff, 0x3f3e9cff, 0x4040a2ff, 0x4143a7ff, 0x4146acff, 0x4249b1ff, 0x424bb5ff, + 0x434ebaff, 0x4451bfff, 0x4454c3ff, 0x4456c7ff, 0x4559cbff, 0x455ccfff, 0x455ed3ff, + 0x4661d6ff, 0x4664daff, 0x4666ddff, 0x4669e0ff, 0x466be3ff, 0x476ee6ff, 0x4771e9ff, + 0x4773ebff, 0x4776eeff, 0x4778f0ff, 0x477bf2ff, 0x467df4ff, 0x4680f6ff, 0x4682f8ff, + 0x4685faff, 0x4687fbff, 0x458afcff, 0x458cfdff, 0x448ffeff, 0x4391feff, 0x4294ffff, + 0x4196ffff, 0x4099ffff, 0x3e9bfeff, 0x3d9efeff, 0x3ba0fdff, 0x3aa3fcff, 0x38a5fbff, + 0x37a8faff, 0x35abf8ff, 0x33adf7ff, 0x31aff5ff, 0x2fb2f4ff, 0x2eb4f2ff, 0x2cb7f0ff, + 0x2ab9eeff, 0x28bcebff, 0x27bee9ff, 0x25c0e7ff, 0x23c3e4ff, 0x22c5e2ff, 0x20c7dfff, + 0x1fc9ddff, 0x1ecbdaff, 0x1ccdd8ff, 0x1bd0d5ff, 0x1ad2d2ff, 0x1ad4d0ff, 0x19d5cdff, + 0x18d7caff, 0x18d9c8ff, 0x18dbc5ff, 0x18ddc2ff, 0x18dec0ff, 0x18e0bdff, 0x19e2bbff, + 0x19e3b9ff, 0x1ae4b6ff, 0x1ce6b4ff, 0x1de7b2ff, 0x1fe9afff, 0x20eaacff, 0x22ebaaff, + 0x25eca7ff, 0x27eea4ff, 0x2aefa1ff, 0x2cf09eff, 0x2ff19bff, 0x32f298ff, 0x35f394ff, + 0x38f491ff, 0x3cf58eff, 0x3ff68aff, 0x43f787ff, 0x46f884ff, 0x4af880ff, 0x4ef97dff, + 0x52fa7aff, 0x55fa76ff, 0x59fb73ff, 0x5dfc6fff, 0x61fc6cff, 0x65fd69ff, 0x69fd66ff, + 0x6dfe62ff, 0x71fe5fff, 0x75fe5cff, 0x79fe59ff, 0x7dff56ff, 0x80ff53ff, 0x84ff51ff, + 
0x88ff4eff, 0x8bff4bff, 0x8fff49ff, 0x92ff47ff, 0x96fe44ff, 0x99fe42ff, 0x9cfe40ff, + 0x9ffd3fff, 0xa1fd3dff, 0xa4fc3cff, 0xa7fc3aff, 0xa9fb39ff, 0xacfb38ff, 0xaffa37ff, + 0xb1f936ff, 0xb4f836ff, 0xb7f735ff, 0xb9f635ff, 0xbcf534ff, 0xbef434ff, 0xc1f334ff, + 0xc3f134ff, 0xc6f034ff, 0xc8ef34ff, 0xcbed34ff, 0xcdec34ff, 0xd0ea34ff, 0xd2e935ff, + 0xd4e735ff, 0xd7e535ff, 0xd9e436ff, 0xdbe236ff, 0xdde037ff, 0xdfdf37ff, 0xe1dd37ff, + 0xe3db38ff, 0xe5d938ff, 0xe7d739ff, 0xe9d539ff, 0xebd339ff, 0xecd13aff, 0xeecf3aff, + 0xefcd3aff, 0xf1cb3aff, 0xf2c93aff, 0xf4c73aff, 0xf5c53aff, 0xf6c33aff, 0xf7c13aff, + 0xf8be39ff, 0xf9bc39ff, 0xfaba39ff, 0xfbb838ff, 0xfbb637ff, 0xfcb336ff, 0xfcb136ff, + 0xfdae35ff, 0xfdac34ff, 0xfea933ff, 0xfea732ff, 0xfea431ff, 0xfea130ff, 0xfe9e2fff, + 0xfe9b2dff, 0xfe992cff, 0xfe962bff, 0xfe932aff, 0xfe9029ff, 0xfd8d27ff, 0xfd8a26ff, + 0xfc8725ff, 0xfc8423ff, 0xfb8122ff, 0xfb7e21ff, 0xfa7b1fff, 0xf9781eff, 0xf9751dff, + 0xf8721cff, 0xf76f1aff, 0xf66c19ff, 0xf56918ff, 0xf46617ff, 0xf36315ff, 0xf26014ff, + 0xf15d13ff, 0xf05b12ff, 0xef5811ff, 0xed5510ff, 0xec530fff, 0xeb500eff, 0xea4e0dff, + 0xe84b0cff, 0xe7490cff, 0xe5470bff, 0xe4450aff, 0xe2430aff, 0xe14109ff, 0xdf3f08ff, + 0xdd3d08ff, 0xdc3b07ff, 0xda3907ff, 0xd83706ff, 0xd63506ff, 0xd43305ff, 0xd23105ff, + 0xd02f05ff, 0xce2d04ff, 0xcc2b04ff, 0xca2a04ff, 0xc82803ff, 0xc52603ff, 0xc32503ff, + 0xc12302ff, 0xbe2102ff, 0xbc2002ff, 0xb91e02ff, 0xb71d02ff, 0xb41b01ff, 0xb21a01ff, + 0xaf1801ff, 0xac1701ff, 0xa91601ff, 0xa71401ff, 0xa41301ff, 0xa11201ff, 0x9e1001ff, + 0x9b0f01ff, 0x980e01ff, 0x950d01ff, 0x920b01ff, 0x8e0a01ff, 0x8b0902ff, 0x880802ff, + 0x850702ff, 0x810602ff, 0x7e0502ff, 0x7a0403ff, + ], + Self::Inferno => [ + 0x000004ff, 0x010005ff, 0x010106ff, 0x010108ff, 0x02010aff, 0x02020cff, 0x02020eff, + 0x030210ff, 0x040312ff, 0x040314ff, 0x050417ff, 0x060419ff, 0x07051bff, 0x08051dff, + 0x09061fff, 0x0a0722ff, 0x0b0724ff, 0x0c0826ff, 0x0d0829ff, 0x0e092bff, 0x10092dff, + 0x110a30ff, 0x120a32ff, 
0x140b34ff, 0x150b37ff, 0x160b39ff, 0x180c3cff, 0x190c3eff, + 0x1b0c41ff, 0x1c0c43ff, 0x1e0c45ff, 0x1f0c48ff, 0x210c4aff, 0x230c4cff, 0x240c4fff, + 0x260c51ff, 0x280b53ff, 0x290b55ff, 0x2b0b57ff, 0x2d0b59ff, 0x2f0a5bff, 0x310a5cff, + 0x320a5eff, 0x340a5fff, 0x360961ff, 0x380962ff, 0x390963ff, 0x3b0964ff, 0x3d0965ff, + 0x3e0966ff, 0x400a67ff, 0x420a68ff, 0x440a68ff, 0x450a69ff, 0x470b6aff, 0x490b6aff, + 0x4a0c6bff, 0x4c0c6bff, 0x4d0d6cff, 0x4f0d6cff, 0x510e6cff, 0x520e6dff, 0x540f6dff, + 0x550f6dff, 0x57106eff, 0x59106eff, 0x5a116eff, 0x5c126eff, 0x5d126eff, 0x5f136eff, + 0x61136eff, 0x62146eff, 0x64156eff, 0x65156eff, 0x67166eff, 0x69166eff, 0x6a176eff, + 0x6c186eff, 0x6d186eff, 0x6f196eff, 0x71196eff, 0x721a6eff, 0x741a6eff, 0x751b6eff, + 0x771c6dff, 0x781c6dff, 0x7a1d6dff, 0x7c1d6dff, 0x7d1e6dff, 0x7f1e6cff, 0x801f6cff, + 0x82206cff, 0x84206bff, 0x85216bff, 0x87216bff, 0x88226aff, 0x8a226aff, 0x8c2369ff, + 0x8d2369ff, 0x8f2469ff, 0x902568ff, 0x922568ff, 0x932667ff, 0x952667ff, 0x972766ff, + 0x982766ff, 0x9a2865ff, 0x9b2964ff, 0x9d2964ff, 0x9f2a63ff, 0xa02a63ff, 0xa22b62ff, + 0xa32c61ff, 0xa52c60ff, 0xa62d60ff, 0xa82e5fff, 0xa92e5eff, 0xab2f5eff, 0xad305dff, + 0xae305cff, 0xb0315bff, 0xb1325aff, 0xb3325aff, 0xb43359ff, 0xb63458ff, 0xb73557ff, + 0xb93556ff, 0xba3655ff, 0xbc3754ff, 0xbd3853ff, 0xbf3952ff, 0xc03a51ff, 0xc13a50ff, + 0xc33b4fff, 0xc43c4eff, 0xc63d4dff, 0xc73e4cff, 0xc83f4bff, 0xca404aff, 0xcb4149ff, + 0xcc4248ff, 0xce4347ff, 0xcf4446ff, 0xd04545ff, 0xd24644ff, 0xd34743ff, 0xd44842ff, + 0xd54a41ff, 0xd74b3fff, 0xd84c3eff, 0xd94d3dff, 0xda4e3cff, 0xdb503bff, 0xdd513aff, + 0xde5238ff, 0xdf5337ff, 0xe05536ff, 0xe15635ff, 0xe25734ff, 0xe35933ff, 0xe45a31ff, + 0xe55c30ff, 0xe65d2fff, 0xe75e2eff, 0xe8602dff, 0xe9612bff, 0xea632aff, 0xeb6429ff, + 0xeb6628ff, 0xec6726ff, 0xed6925ff, 0xee6a24ff, 0xef6c23ff, 0xef6e21ff, 0xf06f20ff, + 0xf1711fff, 0xf1731dff, 0xf2741cff, 0xf3761bff, 0xf37819ff, 0xf47918ff, 0xf57b17ff, + 0xf57d15ff, 0xf67e14ff, 0xf68013ff, 
0xf78212ff, 0xf78410ff, 0xf8850fff, 0xf8870eff, + 0xf8890cff, 0xf98b0bff, 0xf98c0aff, 0xf98e09ff, 0xfa9008ff, 0xfa9207ff, 0xfa9407ff, + 0xfb9606ff, 0xfb9706ff, 0xfb9906ff, 0xfb9b06ff, 0xfb9d07ff, 0xfc9f07ff, 0xfca108ff, + 0xfca309ff, 0xfca50aff, 0xfca60cff, 0xfca80dff, 0xfcaa0fff, 0xfcac11ff, 0xfcae12ff, + 0xfcb014ff, 0xfcb216ff, 0xfcb418ff, 0xfbb61aff, 0xfbb81dff, 0xfbba1fff, 0xfbbc21ff, + 0xfbbe23ff, 0xfac026ff, 0xfac228ff, 0xfac42aff, 0xfac62dff, 0xf9c72fff, 0xf9c932ff, + 0xf9cb35ff, 0xf8cd37ff, 0xf8cf3aff, 0xf7d13dff, 0xf7d340ff, 0xf6d543ff, 0xf6d746ff, + 0xf5d949ff, 0xf5db4cff, 0xf4dd4fff, 0xf4df53ff, 0xf4e156ff, 0xf3e35aff, 0xf3e55dff, + 0xf2e661ff, 0xf2e865ff, 0xf2ea69ff, 0xf1ec6dff, 0xf1ed71ff, 0xf1ef75ff, 0xf1f179ff, + 0xf2f27dff, 0xf2f482ff, 0xf3f586ff, 0xf3f68aff, 0xf4f88eff, 0xf5f992ff, 0xf6fa96ff, + 0xf8fb9aff, 0xf9fc9dff, 0xfafda1ff, 0xfcffa4ff, + ], + Self::Plasma => [ + 0x0d0887ff, 0x100788ff, 0x130789ff, 0x16078aff, 0x19068cff, 0x1b068dff, 0x1d068eff, + 0x20068fff, 0x220690ff, 0x240691ff, 0x260591ff, 0x280592ff, 0x2a0593ff, 0x2c0594ff, + 0x2e0595ff, 0x2f0596ff, 0x310597ff, 0x330597ff, 0x350498ff, 0x370499ff, 0x38049aff, + 0x3a049aff, 0x3c049bff, 0x3e049cff, 0x3f049cff, 0x41049dff, 0x43039eff, 0x44039eff, + 0x46039fff, 0x48039fff, 0x4903a0ff, 0x4b03a1ff, 0x4c02a1ff, 0x4e02a2ff, 0x5002a2ff, + 0x5102a3ff, 0x5302a3ff, 0x5502a4ff, 0x5601a4ff, 0x5801a4ff, 0x5901a5ff, 0x5b01a5ff, + 0x5c01a6ff, 0x5e01a6ff, 0x6001a6ff, 0x6100a7ff, 0x6300a7ff, 0x6400a7ff, 0x6600a7ff, + 0x6700a8ff, 0x6900a8ff, 0x6a00a8ff, 0x6c00a8ff, 0x6e00a8ff, 0x6f00a8ff, 0x7100a8ff, + 0x7201a8ff, 0x7401a8ff, 0x7501a8ff, 0x7701a8ff, 0x7801a8ff, 0x7a02a8ff, 0x7b02a8ff, + 0x7d03a8ff, 0x7e03a8ff, 0x8004a8ff, 0x8104a7ff, 0x8305a7ff, 0x8405a7ff, 0x8606a6ff, + 0x8707a6ff, 0x8808a6ff, 0x8a09a5ff, 0x8b0aa5ff, 0x8d0ba5ff, 0x8e0ca4ff, 0x8f0da4ff, + 0x910ea3ff, 0x920fa3ff, 0x9410a2ff, 0x9511a1ff, 0x9613a1ff, 0x9814a0ff, 0x99159fff, + 0x9a169fff, 0x9c179eff, 0x9d189dff, 0x9e199dff, 0xa01a9cff, 
0xa11b9bff, 0xa21d9aff, + 0xa31e9aff, 0xa51f99ff, 0xa62098ff, 0xa72197ff, 0xa82296ff, 0xaa2395ff, 0xab2494ff, + 0xac2694ff, 0xad2793ff, 0xae2892ff, 0xb02991ff, 0xb12a90ff, 0xb22b8fff, 0xb32c8eff, + 0xb42e8dff, 0xb52f8cff, 0xb6308bff, 0xb7318aff, 0xb83289ff, 0xba3388ff, 0xbb3488ff, + 0xbc3587ff, 0xbd3786ff, 0xbe3885ff, 0xbf3984ff, 0xc03a83ff, 0xc13b82ff, 0xc23c81ff, + 0xc33d80ff, 0xc43e7fff, 0xc5407eff, 0xc6417dff, 0xc7427cff, 0xc8437bff, 0xc9447aff, + 0xca457aff, 0xcb4679ff, 0xcc4778ff, 0xcc4977ff, 0xcd4a76ff, 0xce4b75ff, 0xcf4c74ff, + 0xd04d73ff, 0xd14e72ff, 0xd24f71ff, 0xd35171ff, 0xd45270ff, 0xd5536fff, 0xd5546eff, + 0xd6556dff, 0xd7566cff, 0xd8576bff, 0xd9586aff, 0xda5a6aff, 0xda5b69ff, 0xdb5c68ff, + 0xdc5d67ff, 0xdd5e66ff, 0xde5f65ff, 0xde6164ff, 0xdf6263ff, 0xe06363ff, 0xe16462ff, + 0xe26561ff, 0xe26660ff, 0xe3685fff, 0xe4695eff, 0xe56a5dff, 0xe56b5dff, 0xe66c5cff, + 0xe76e5bff, 0xe76f5aff, 0xe87059ff, 0xe97158ff, 0xe97257ff, 0xea7457ff, 0xeb7556ff, + 0xeb7655ff, 0xec7754ff, 0xed7953ff, 0xed7a52ff, 0xee7b51ff, 0xef7c51ff, 0xef7e50ff, + 0xf07f4fff, 0xf0804eff, 0xf1814dff, 0xf1834cff, 0xf2844bff, 0xf3854bff, 0xf3874aff, + 0xf48849ff, 0xf48948ff, 0xf58b47ff, 0xf58c46ff, 0xf68d45ff, 0xf68f44ff, 0xf79044ff, + 0xf79143ff, 0xf79342ff, 0xf89441ff, 0xf89540ff, 0xf9973fff, 0xf9983eff, 0xf99a3eff, + 0xfa9b3dff, 0xfa9c3cff, 0xfa9e3bff, 0xfb9f3aff, 0xfba139ff, 0xfba238ff, 0xfca338ff, + 0xfca537ff, 0xfca636ff, 0xfca835ff, 0xfca934ff, 0xfdab33ff, 0xfdac33ff, 0xfdae32ff, + 0xfdaf31ff, 0xfdb130ff, 0xfdb22fff, 0xfdb42fff, 0xfdb52eff, 0xfeb72dff, 0xfeb82cff, + 0xfeba2cff, 0xfebb2bff, 0xfebd2aff, 0xfebe2aff, 0xfec029ff, 0xfdc229ff, 0xfdc328ff, + 0xfdc527ff, 0xfdc627ff, 0xfdc827ff, 0xfdca26ff, 0xfdcb26ff, 0xfccd25ff, 0xfcce25ff, + 0xfcd025ff, 0xfcd225ff, 0xfbd324ff, 0xfbd524ff, 0xfbd724ff, 0xfad824ff, 0xfada24ff, + 0xf9dc24ff, 0xf9dd25ff, 0xf8df25ff, 0xf8e125ff, 0xf7e225ff, 0xf7e425ff, 0xf6e626ff, + 0xf6e826ff, 0xf5e926ff, 0xf5eb27ff, 0xf4ed27ff, 0xf3ee27ff, 0xf3f027ff, 
0xf2f227ff, + 0xf1f426ff, 0xf1f525ff, 0xf0f724ff, 0xf0f921ff, + ], + Self::Viridis => [ + 0x440154ff, 0x440256ff, 0x450457ff, 0x450559ff, 0x46075aff, 0x46085cff, 0x460a5dff, + 0x460b5eff, 0x470d60ff, 0x470e61ff, 0x471063ff, 0x471164ff, 0x471365ff, 0x481467ff, + 0x481668ff, 0x481769ff, 0x48186aff, 0x481a6cff, 0x481b6dff, 0x481c6eff, 0x481d6fff, + 0x481f70ff, 0x482071ff, 0x482173ff, 0x482374ff, 0x482475ff, 0x482576ff, 0x482677ff, + 0x482878ff, 0x482979ff, 0x472a7aff, 0x472c7aff, 0x472d7bff, 0x472e7cff, 0x472f7dff, + 0x46307eff, 0x46327eff, 0x46337fff, 0x463480ff, 0x453581ff, 0x453781ff, 0x453882ff, + 0x443983ff, 0x443a83ff, 0x443b84ff, 0x433d84ff, 0x433e85ff, 0x423f85ff, 0x424086ff, + 0x424186ff, 0x414287ff, 0x414487ff, 0x404588ff, 0x404688ff, 0x3f4788ff, 0x3f4889ff, + 0x3e4989ff, 0x3e4a89ff, 0x3e4c8aff, 0x3d4d8aff, 0x3d4e8aff, 0x3c4f8aff, 0x3c508bff, + 0x3b518bff, 0x3b528bff, 0x3a538bff, 0x3a548cff, 0x39558cff, 0x39568cff, 0x38588cff, + 0x38598cff, 0x375a8cff, 0x375b8dff, 0x365c8dff, 0x365d8dff, 0x355e8dff, 0x355f8dff, + 0x34608dff, 0x34618dff, 0x33628dff, 0x33638dff, 0x32648eff, 0x32658eff, 0x31668eff, + 0x31678eff, 0x31688eff, 0x30698eff, 0x306a8eff, 0x2f6b8eff, 0x2f6c8eff, 0x2e6d8eff, + 0x2e6e8eff, 0x2e6f8eff, 0x2d708eff, 0x2d718eff, 0x2c718eff, 0x2c728eff, 0x2c738eff, + 0x2b748eff, 0x2b758eff, 0x2a768eff, 0x2a778eff, 0x2a788eff, 0x29798eff, 0x297a8eff, + 0x297b8eff, 0x287c8eff, 0x287d8eff, 0x277e8eff, 0x277f8eff, 0x27808eff, 0x26818eff, + 0x26828eff, 0x26828eff, 0x25838eff, 0x25848eff, 0x25858eff, 0x24868eff, 0x24878eff, + 0x23888eff, 0x23898eff, 0x238a8dff, 0x228b8dff, 0x228c8dff, 0x228d8dff, 0x218e8dff, + 0x218f8dff, 0x21908dff, 0x21918cff, 0x20928cff, 0x20928cff, 0x20938cff, 0x1f948cff, + 0x1f958bff, 0x1f968bff, 0x1f978bff, 0x1f988bff, 0x1f998aff, 0x1f9a8aff, 0x1e9b8aff, + 0x1e9c89ff, 0x1e9d89ff, 0x1f9e89ff, 0x1f9f88ff, 0x1fa088ff, 0x1fa188ff, 0x1fa187ff, + 0x1fa287ff, 0x20a386ff, 0x20a486ff, 0x21a585ff, 0x21a685ff, 0x22a785ff, 0x22a884ff, + 0x23a983ff, 
0x24aa83ff, 0x25ab82ff, 0x25ac82ff, 0x26ad81ff, 0x27ad81ff, 0x28ae80ff, + 0x29af7fff, 0x2ab07fff, 0x2cb17eff, 0x2db27dff, 0x2eb37cff, 0x2fb47cff, 0x31b57bff, + 0x32b67aff, 0x34b679ff, 0x35b779ff, 0x37b878ff, 0x38b977ff, 0x3aba76ff, 0x3bbb75ff, + 0x3dbc74ff, 0x3fbc73ff, 0x40bd72ff, 0x42be71ff, 0x44bf70ff, 0x46c06fff, 0x48c16eff, + 0x4ac16dff, 0x4cc26cff, 0x4ec36bff, 0x50c46aff, 0x52c569ff, 0x54c568ff, 0x56c667ff, + 0x58c765ff, 0x5ac864ff, 0x5cc863ff, 0x5ec962ff, 0x60ca60ff, 0x63cb5fff, 0x65cb5eff, + 0x67cc5cff, 0x69cd5bff, 0x6ccd5aff, 0x6ece58ff, 0x70cf57ff, 0x73d056ff, 0x75d054ff, + 0x77d153ff, 0x7ad151ff, 0x7cd250ff, 0x7fd34eff, 0x81d34dff, 0x84d44bff, 0x86d549ff, + 0x89d548ff, 0x8bd646ff, 0x8ed645ff, 0x90d743ff, 0x93d741ff, 0x95d840ff, 0x98d83eff, + 0x9bd93cff, 0x9dd93bff, 0xa0da39ff, 0xa2da37ff, 0xa5db36ff, 0xa8db34ff, 0xaadc32ff, + 0xaddc30ff, 0xb0dd2fff, 0xb2dd2dff, 0xb5de2bff, 0xb8de29ff, 0xbade28ff, 0xbddf26ff, + 0xc0df25ff, 0xc2df23ff, 0xc5e021ff, 0xc8e020ff, 0xcae11fff, 0xcde11dff, 0xd0e11cff, + 0xd2e21bff, 0xd5e21aff, 0xd8e219ff, 0xdae319ff, 0xdde318ff, 0xdfe318ff, 0xe2e418ff, + 0xe5e419ff, 0xe7e419ff, 0xeae51aff, 0xece51bff, 0xefe51cff, 0xf1e51dff, 0xf4e61eff, + 0xf6e620ff, 0xf8e621ff, 0xfbe723ff, 0xfde725ff, + ], + Self::Magma => [ + 0x000004ff, 0x010005ff, 0x010106ff, 0x010108ff, 0x020109ff, 0x02020bff, 0x02020dff, + 0x03030fff, 0x030312ff, 0x040414ff, 0x050416ff, 0x060518ff, 0x06051aff, 0x07061cff, + 0x08071eff, 0x090720ff, 0x0a0822ff, 0x0b0924ff, 0x0c0926ff, 0x0d0a29ff, 0x0e0b2bff, + 0x100b2dff, 0x110c2fff, 0x120d31ff, 0x130d34ff, 0x140e36ff, 0x150e38ff, 0x160f3bff, + 0x180f3dff, 0x19103fff, 0x1a1042ff, 0x1c1044ff, 0x1d1147ff, 0x1e1149ff, 0x20114bff, + 0x21114eff, 0x221150ff, 0x241253ff, 0x251255ff, 0x271258ff, 0x29115aff, 0x2a115cff, + 0x2c115fff, 0x2d1161ff, 0x2f1163ff, 0x311165ff, 0x331067ff, 0x341069ff, 0x36106bff, + 0x38106cff, 0x390f6eff, 0x3b0f70ff, 0x3d0f71ff, 0x3f0f72ff, 0x400f74ff, 0x420f75ff, + 0x440f76ff, 0x451077ff, 0x471078ff, 
0x491078ff, 0x4a1079ff, 0x4c117aff, 0x4e117bff, + 0x4f127bff, 0x51127cff, 0x52137cff, 0x54137dff, 0x56147dff, 0x57157eff, 0x59157eff, + 0x5a167eff, 0x5c167fff, 0x5d177fff, 0x5f187fff, 0x601880ff, 0x621980ff, 0x641a80ff, + 0x651a80ff, 0x671b80ff, 0x681c81ff, 0x6a1c81ff, 0x6b1d81ff, 0x6d1d81ff, 0x6e1e81ff, + 0x701f81ff, 0x721f81ff, 0x732081ff, 0x752181ff, 0x762181ff, 0x782281ff, 0x792282ff, + 0x7b2382ff, 0x7c2382ff, 0x7e2482ff, 0x802582ff, 0x812581ff, 0x832681ff, 0x842681ff, + 0x862781ff, 0x882781ff, 0x892881ff, 0x8b2981ff, 0x8c2981ff, 0x8e2a81ff, 0x902a81ff, + 0x912b81ff, 0x932b80ff, 0x942c80ff, 0x962c80ff, 0x982d80ff, 0x992d80ff, 0x9b2e7fff, + 0x9c2e7fff, 0x9e2f7fff, 0xa02f7fff, 0xa1307eff, 0xa3307eff, 0xa5317eff, 0xa6317dff, + 0xa8327dff, 0xaa337dff, 0xab337cff, 0xad347cff, 0xae347bff, 0xb0357bff, 0xb2357bff, + 0xb3367aff, 0xb5367aff, 0xb73779ff, 0xb83779ff, 0xba3878ff, 0xbc3978ff, 0xbd3977ff, + 0xbf3a77ff, 0xc03a76ff, 0xc23b75ff, 0xc43c75ff, 0xc53c74ff, 0xc73d73ff, 0xc83e73ff, + 0xca3e72ff, 0xcc3f71ff, 0xcd4071ff, 0xcf4070ff, 0xd0416fff, 0xd2426fff, 0xd3436eff, + 0xd5446dff, 0xd6456cff, 0xd8456cff, 0xd9466bff, 0xdb476aff, 0xdc4869ff, 0xde4968ff, + 0xdf4a68ff, 0xe04c67ff, 0xe24d66ff, 0xe34e65ff, 0xe44f64ff, 0xe55064ff, 0xe75263ff, + 0xe85362ff, 0xe95462ff, 0xea5661ff, 0xeb5760ff, 0xec5860ff, 0xed5a5fff, 0xee5b5eff, + 0xef5d5eff, 0xf05f5eff, 0xf1605dff, 0xf2625dff, 0xf2645cff, 0xf3655cff, 0xf4675cff, + 0xf4695cff, 0xf56b5cff, 0xf66c5cff, 0xf66e5cff, 0xf7705cff, 0xf7725cff, 0xf8745cff, + 0xf8765cff, 0xf9785dff, 0xf9795dff, 0xf97b5dff, 0xfa7d5eff, 0xfa7f5eff, 0xfa815fff, + 0xfb835fff, 0xfb8560ff, 0xfb8761ff, 0xfc8961ff, 0xfc8a62ff, 0xfc8c63ff, 0xfc8e64ff, + 0xfc9065ff, 0xfd9266ff, 0xfd9467ff, 0xfd9668ff, 0xfd9869ff, 0xfd9a6aff, 0xfd9b6bff, + 0xfe9d6cff, 0xfe9f6dff, 0xfea16eff, 0xfea36fff, 0xfea571ff, 0xfea772ff, 0xfea973ff, + 0xfeaa74ff, 0xfeac76ff, 0xfeae77ff, 0xfeb078ff, 0xfeb27aff, 0xfeb47bff, 0xfeb67cff, + 0xfeb77eff, 0xfeb97fff, 0xfebb81ff, 0xfebd82ff, 
0xfebf84ff, 0xfec185ff, 0xfec287ff, + 0xfec488ff, 0xfec68aff, 0xfec88cff, 0xfeca8dff, 0xfecc8fff, 0xfecd90ff, 0xfecf92ff, + 0xfed194ff, 0xfed395ff, 0xfed597ff, 0xfed799ff, 0xfed89aff, 0xfdda9cff, 0xfddc9eff, + 0xfddea0ff, 0xfde0a1ff, 0xfde2a3ff, 0xfde3a5ff, 0xfde5a7ff, 0xfde7a9ff, 0xfde9aaff, + 0xfdebacff, 0xfcecaeff, 0xfceeb0ff, 0xfcf0b2ff, 0xfcf2b4ff, 0xfcf4b6ff, 0xfcf6b8ff, + 0xfcf7b9ff, 0xfcf9bbff, 0xfcfbbdff, 0xfcfdbfff, + ], + Self::BentCoolWarm => [ + 0x3b4cc0ff, 0x3c4ec1ff, 0x3d4fc2ff, 0x3e50c2ff, 0x3f52c3ff, 0x4053c4ff, 0x4155c4ff, + 0x4256c5ff, 0x4357c5ff, 0x4459c6ff, 0x455ac7ff, 0x465bc7ff, 0x475dc8ff, 0x485ec8ff, + 0x495fc9ff, 0x4a61caff, 0x4b62caff, 0x4c64cbff, 0x4d65cbff, 0x4e66ccff, 0x4f68ccff, + 0x5169cdff, 0x526aceff, 0x536cceff, 0x546dcfff, 0x556ecfff, 0x5670d0ff, 0x5771d0ff, + 0x5972d1ff, 0x5a74d1ff, 0x5b75d2ff, 0x5c76d2ff, 0x5d78d3ff, 0x5f79d3ff, 0x607ad4ff, + 0x617cd4ff, 0x627dd5ff, 0x647ed5ff, 0x6580d6ff, 0x6681d6ff, 0x6782d6ff, 0x6983d7ff, + 0x6a85d7ff, 0x6b86d8ff, 0x6d87d8ff, 0x6e89d9ff, 0x6f8ad9ff, 0x718bdaff, 0x728ddaff, + 0x738edaff, 0x758fdbff, 0x7691dbff, 0x7792dcff, 0x7993dcff, 0x7a95dcff, 0x7b96ddff, + 0x7d97ddff, 0x7e98deff, 0x809adeff, 0x819bdeff, 0x829cdfff, 0x849edfff, 0x859fdfff, + 0x87a0e0ff, 0x88a2e0ff, 0x8aa3e1ff, 0x8ba4e1ff, 0x8ca5e1ff, 0x8ea7e2ff, 0x8fa8e2ff, + 0x91a9e2ff, 0x92abe3ff, 0x94ace3ff, 0x95ade3ff, 0x97afe4ff, 0x98b0e4ff, 0x9ab1e4ff, + 0x9bb2e5ff, 0x9db4e5ff, 0x9fb5e5ff, 0xa0b6e6ff, 0xa2b8e6ff, 0xa3b9e6ff, 0xa5bae6ff, + 0xa6bbe7ff, 0xa8bde7ff, 0xaabee7ff, 0xabbfe8ff, 0xadc1e8ff, 0xaec2e8ff, 0xb0c3e8ff, + 0xb2c4e9ff, 0xb3c6e9ff, 0xb5c7e9ff, 0xb7c8eaff, 0xb8caeaff, 0xbacbeaff, 0xbccceaff, + 0xbdcdebff, 0xbfcfebff, 0xc1d0ebff, 0xc2d1ecff, 0xc4d2ecff, 0xc6d4ecff, 0xc8d5ecff, + 0xc9d6edff, 0xcbd7edff, 0xcdd9edff, 0xcfdaedff, 0xd0dbeeff, 0xd2dceeff, 0xd4deeeff, + 0xd6dfeeff, 0xd7e0efff, 0xd9e1efff, 0xdbe3efff, 0xdde4efff, 0xdfe5f0ff, 0xe1e6f0ff, + 0xe2e8f0ff, 0xe4e9f0ff, 0xe6eaf1ff, 0xe8ebf1ff, 0xeaedf1ff, 0xeceef1ff, 
0xeeeff2ff, + 0xeff0f2ff, 0xf1f2f2ff, 0xf2f1f1ff, 0xf2f0efff, 0xf1eeedff, 0xf1edebff, 0xf1ebe8ff, + 0xf1eae6ff, 0xf0e8e4ff, 0xf0e7e2ff, 0xf0e5e0ff, 0xefe4deff, 0xefe2dbff, 0xefe1d9ff, + 0xeedfd7ff, 0xeeded5ff, 0xeedcd3ff, 0xeddbd1ff, 0xedd9cfff, 0xedd8cdff, 0xecd6cbff, + 0xecd5c9ff, 0xecd3c7ff, 0xebd2c4ff, 0xebd0c2ff, 0xebcfc0ff, 0xeacdbeff, 0xeaccbcff, + 0xe9cabaff, 0xe9c9b8ff, 0xe9c7b6ff, 0xe8c5b4ff, 0xe8c4b3ff, 0xe8c2b1ff, 0xe7c1afff, + 0xe7bfadff, 0xe6beabff, 0xe6bca9ff, 0xe6bba7ff, 0xe5b9a5ff, 0xe5b8a3ff, 0xe4b6a1ff, + 0xe4b59fff, 0xe4b39eff, 0xe3b19cff, 0xe3b09aff, 0xe2ae98ff, 0xe2ad96ff, 0xe2ab94ff, + 0xe1aa93ff, 0xe1a891ff, 0xe0a78fff, 0xe0a58dff, 0xdfa38cff, 0xdfa28aff, 0xdfa088ff, + 0xde9f86ff, 0xde9d85ff, 0xdd9c83ff, 0xdd9a81ff, 0xdc9880ff, 0xdc977eff, 0xdb957cff, + 0xdb947bff, 0xda9279ff, 0xda9077ff, 0xd98f76ff, 0xd98d74ff, 0xd98c72ff, 0xd88a71ff, + 0xd8886fff, 0xd7876eff, 0xd7856cff, 0xd6846bff, 0xd68269ff, 0xd58068ff, 0xd47f66ff, + 0xd47d65ff, 0xd37b63ff, 0xd37a62ff, 0xd27860ff, 0xd2775fff, 0xd1755dff, 0xd1735cff, + 0xd0725aff, 0xd07059ff, 0xcf6e58ff, 0xcf6c56ff, 0xce6b55ff, 0xcd6953ff, 0xcd6752ff, + 0xcc6651ff, 0xcc644fff, 0xcb624eff, 0xcb604dff, 0xca5f4bff, 0xc95d4aff, 0xc95b49ff, + 0xc85948ff, 0xc85746ff, 0xc75645ff, 0xc65444ff, 0xc65243ff, 0xc55041ff, 0xc54e40ff, + 0xc44c3fff, 0xc34a3eff, 0xc3483dff, 0xc2463cff, 0xc1443aff, 0xc14239ff, 0xc04038ff, + 0xc03e37ff, 0xbf3c36ff, 0xbe3a35ff, 0xbe3734ff, 0xbd3533ff, 0xbc3232ff, 0xbc3031ff, + 0xbb2d30ff, 0xba2b2fff, 0xba282eff, 0xb9252dff, 0xb8222cff, 0xb81e2bff, 0xb71b2aff, + 0xb61629ff, 0xb51228ff, 0xb50c27ff, 0xb40426ff, + ], + Self::BlackBody => [ + 0x000000ff, 0x030101ff, 0x070201ff, 0x0a0302ff, 0x0d0402ff, 0x100503ff, 0x120603ff, + 0x140704ff, 0x160804ff, 0x180905ff, 0x1a0a05ff, 0x1b0b06ff, 0x1d0b06ff, 0x1e0c07ff, + 0x200d08ff, 0x210e08ff, 0x220f09ff, 0x240f09ff, 0x25100aff, 0x26100aff, 0x28110bff, + 0x29110bff, 0x2b120cff, 0x2c120cff, 0x2e120dff, 0x2f130dff, 0x31130eff, 0x32130eff, + 0x34140fff, 
0x36140fff, 0x37140fff, 0x391510ff, 0x3a1510ff, 0x3c1510ff, 0x3e1611ff, + 0x3f1611ff, 0x411611ff, 0x421712ff, 0x441712ff, 0x461712ff, 0x471813ff, 0x491813ff, + 0x4b1813ff, 0x4c1914ff, 0x4e1914ff, 0x501914ff, 0x511914ff, 0x531a15ff, 0x551a15ff, + 0x561a15ff, 0x581a15ff, 0x5a1b16ff, 0x5b1b16ff, 0x5d1b16ff, 0x5f1b16ff, 0x611c17ff, + 0x621c17ff, 0x641c17ff, 0x661c17ff, 0x681d18ff, 0x691d18ff, 0x6b1d18ff, 0x6d1d18ff, + 0x6f1d19ff, 0x701e19ff, 0x721e19ff, 0x741e19ff, 0x761e1aff, 0x771e1aff, 0x791f1aff, + 0x7b1f1aff, 0x7d1f1bff, 0x7f1f1bff, 0x801f1bff, 0x821f1bff, 0x84201cff, 0x86201cff, + 0x88201cff, 0x89201cff, 0x8b201dff, 0x8d201dff, 0x8f201dff, 0x91211dff, 0x93211eff, + 0x94211eff, 0x96211eff, 0x98211fff, 0x9a211fff, 0x9c211fff, 0x9e211fff, 0xa02120ff, + 0xa12220ff, 0xa32220ff, 0xa52220ff, 0xa72221ff, 0xa92221ff, 0xab2221ff, 0xad2221ff, + 0xaf2222ff, 0xb12222ff, 0xb22222ff, 0xb32422ff, 0xb42622ff, 0xb52821ff, 0xb62a21ff, + 0xb72c21ff, 0xb82d21ff, 0xb92f20ff, 0xba3120ff, 0xbb3220ff, 0xbc341fff, 0xbd351fff, + 0xbe371fff, 0xbf381fff, 0xc03a1eff, 0xc13b1eff, 0xc23d1eff, 0xc33e1dff, 0xc4401dff, + 0xc5411cff, 0xc6421cff, 0xc7441cff, 0xc8451bff, 0xc9471bff, 0xca481aff, 0xcb491aff, + 0xcc4b19ff, 0xcd4c19ff, 0xce4d18ff, 0xcf4f18ff, 0xd05017ff, 0xd15217ff, 0xd25316ff, + 0xd35415ff, 0xd45515ff, 0xd55714ff, 0xd65813ff, 0xd75913ff, 0xd85b12ff, 0xd95c11ff, + 0xda5d10ff, 0xdb5f0fff, 0xdc600eff, 0xdd610dff, 0xde620cff, 0xdf640bff, 0xe06509ff, + 0xe16608ff, 0xe26807ff, 0xe36905ff, 0xe36b05ff, 0xe36d06ff, 0xe46e07ff, 0xe47007ff, + 0xe47208ff, 0xe47408ff, 0xe57609ff, 0xe5770aff, 0xe5790aff, 0xe57b0bff, 0xe57c0cff, + 0xe67e0cff, 0xe6800dff, 0xe6820eff, 0xe6830eff, 0xe6850fff, 0xe6870fff, 0xe78810ff, + 0xe78a11ff, 0xe78c11ff, 0xe78d12ff, 0xe78f13ff, 0xe79113ff, 0xe79214ff, 0xe89415ff, + 0xe89615ff, 0xe89716ff, 0xe89916ff, 0xe89a17ff, 0xe89c18ff, 0xe89e18ff, 0xe89f19ff, + 0xe8a11aff, 0xe8a21aff, 0xe9a41bff, 0xe9a61bff, 0xe9a71cff, 0xe9a91dff, 0xe9aa1dff, + 0xe9ac1eff, 0xe9ae1eff, 
0xe9af1fff, 0xe9b120ff, 0xe9b220ff, 0xe9b421ff, 0xe9b522ff, + 0xe9b722ff, 0xe9b923ff, 0xe9ba23ff, 0xe9bc24ff, 0xe9bd25ff, 0xe9bf25ff, 0xe9c026ff, + 0xe9c226ff, 0xe9c327ff, 0xe9c528ff, 0xe9c728ff, 0xe9c829ff, 0xe8ca2aff, 0xe8cb2aff, + 0xe8cd2bff, 0xe8ce2bff, 0xe8d02cff, 0xe8d12dff, 0xe8d32dff, 0xe8d52eff, 0xe8d62fff, + 0xe8d82fff, 0xe7d930ff, 0xe7db30ff, 0xe7dc31ff, 0xe7de32ff, 0xe7df32ff, 0xe7e133ff, + 0xe6e234ff, 0xe6e434ff, 0xe6e535ff, 0xe7e73cff, 0xe9e745ff, 0xeae84eff, 0xece957ff, + 0xedea5eff, 0xeeeb66ff, 0xf0ec6dff, 0xf1ec75ff, 0xf2ed7cff, 0xf3ee83ff, 0xf5ef89ff, + 0xf6f090ff, 0xf7f197ff, 0xf8f19eff, 0xf9f2a4ff, 0xf9f3abff, 0xfaf4b1ff, 0xfbf5b8ff, + 0xfcf6beff, 0xfcf7c5ff, 0xfdf8cbff, 0xfdf9d2ff, 0xfef9d8ff, 0xfefadfff, 0xfefbe5ff, + 0xfffcecff, 0xfffdf2ff, 0xfffef9ff, 0xffffffff, + ], + Self::ExtendedKindLmann => [ + 0x000000ff, 0x050004ff, 0x090009ff, 0x0d010dff, 0x100111ff, 0x130115ff, 0x160118ff, + 0x18011bff, 0x1a011eff, 0x1b0222ff, 0x1c0226ff, 0x1d022aff, 0x1d022eff, 0x1e0232ff, + 0x1e0335ff, 0x1e0339ff, 0x1e033dff, 0x1d0341ff, 0x1d0344ff, 0x1c0348ff, 0x1b044bff, + 0x1b044fff, 0x1a0452ff, 0x190455ff, 0x180458ff, 0x17045cff, 0x16055fff, 0x150562ff, + 0x140565ff, 0x130567ff, 0x12056aff, 0x12056dff, 0x11056fff, 0x0e0573ff, 0x080677ff, + 0x060878ff, 0x060b78ff, 0x060f77ff, 0x061276ff, 0x061674ff, 0x051972ff, 0x051c70ff, + 0x051f6dff, 0x05216bff, 0x052468ff, 0x052665ff, 0x052863ff, 0x052a60ff, 0x052c5eff, + 0x042e5bff, 0x043059ff, 0x043157ff, 0x043355ff, 0x043453ff, 0x043651ff, 0x04374fff, + 0x04394dff, 0x043a4cff, 0x043b4aff, 0x033d49ff, 0x033e47ff, 0x033f46ff, 0x034145ff, + 0x034243ff, 0x034342ff, 0x034441ff, 0x034540ff, 0x03473fff, 0x03483dff, 0x04493cff, + 0x044a3aff, 0x044b38ff, 0x044d37ff, 0x044e35ff, 0x044f33ff, 0x045031ff, 0x04512fff, + 0x04522dff, 0x04542bff, 0x045529ff, 0x045627ff, 0x045724ff, 0x045822ff, 0x04591fff, + 0x045b1dff, 0x045c1aff, 0x055d18ff, 0x055e15ff, 0x055f12ff, 0x05600fff, 0x05610dff, + 0x05620aff, 0x056408ff, 0x056506ff, 
0x066605ff, 0x086705ff, 0x0a6805ff, 0x0b6905ff, + 0x0d6a05ff, 0x0f6b05ff, 0x116c05ff, 0x146d05ff, 0x166e05ff, 0x1a6f05ff, 0x1d7005ff, + 0x207005ff, 0x247105ff, 0x287205ff, 0x2b7306ff, 0x2f7406ff, 0x337406ff, 0x377506ff, + 0x3b7606ff, 0x3f7606ff, 0x437706ff, 0x477706ff, 0x4c7806ff, 0x507806ff, 0x547906ff, + 0x587906ff, 0x5c7a06ff, 0x617a06ff, 0x657a06ff, 0x697b06ff, 0x6d7b06ff, 0x717b06ff, + 0x767b06ff, 0x7a7b06ff, 0x7e7b06ff, 0x827b06ff, 0x877b07ff, 0x8b7b07ff, 0x907b07ff, + 0x957a07ff, 0x9a7a07ff, 0xa07908ff, 0xa57808ff, 0xab7708ff, 0xb17608ff, 0xb77509ff, + 0xbd7309ff, 0xc47109ff, 0xca6f0aff, 0xd16c0aff, 0xd8690aff, 0xde660bff, 0xe5620bff, + 0xec5e0bff, 0xf35a0cff, 0xf45b1bff, 0xf55c25ff, 0xf55e2eff, 0xf56034ff, 0xf6623aff, + 0xf6633fff, 0xf66543ff, 0xf66747ff, 0xf6694aff, 0xf66b4dff, 0xf76d4fff, 0xf76f53ff, + 0xf77057ff, 0xf7725bff, 0xf77360ff, 0xf87565ff, 0xf8766aff, 0xf87870ff, 0xf87976ff, + 0xf97a7bff, 0xf97b81ff, 0xf97d87ff, 0xf97e8dff, 0xf97f93ff, 0xf98199ff, 0xf9829eff, + 0xf983a4ff, 0xf984a9ff, 0xf985afff, 0xf986b4ff, 0xf987baff, 0xf989bfff, 0xf98ac4ff, + 0xf98bc9ff, 0xfa8cceff, 0xfa8dd3ff, 0xfa8ed8ff, 0xfa8fddff, 0xfa90e1ff, 0xfa91e6ff, + 0xfa92ebff, 0xfa93efff, 0xfa94f3ff, 0xfa95f8ff, 0xf898faff, 0xf59bfaff, 0xf29ffaff, + 0xefa2fbff, 0xeca5fbff, 0xeaa8fbff, 0xe8abfbff, 0xe6adfbff, 0xe5b0fbff, 0xe3b2fbff, + 0xe2b4fbff, 0xe1b6fbff, 0xe0b8fcff, 0xe0bafcff, 0xdfbcfcff, 0xdfbefcff, 0xdebffcff, + 0xdec1fcff, 0xdec3fcff, 0xdec4fcff, 0xdfc6fcff, 0xdfc7fcff, 0xdfc9fcff, 0xe0cafcff, + 0xe0ccfdff, 0xe1cdfdff, 0xe2cffdff, 0xe2d0fdff, 0xe3d1fdff, 0xe4d3fdff, 0xe5d4fdff, + 0xe5d5fdff, 0xe6d7fdff, 0xe7d8fdff, 0xe7dafdff, 0xe7dbfdff, 0xe8ddfdff, 0xe8defdff, + 0xe8e0fdff, 0xe8e1feff, 0xe9e3feff, 0xe9e4feff, 0xe9e6feff, 0xe9e7feff, 0xe9e9feff, + 0xe9eafeff, 0xeaecfeff, 0xeaedfeff, 0xeaeffeff, 0xebf0feff, 0xebf2feff, 0xecf3feff, + 0xedf5feff, 0xedf6feff, 0xeef7feff, 0xeff9feff, 0xf0fafeff, 0xf2fbfeff, 0xf3fcfeff, + 0xf5fdffff, 0xf8feffff, 0xfbffffff, 0xffffffff, + ], + 
Self::KindLmann => [ + 0x000000ff, 0x050004ff, 0x090008ff, 0x0d010dff, 0x110110ff, 0x140114ff, 0x160117ff, + 0x19011aff, 0x1b011dff, 0x1d0220ff, 0x1e0223ff, 0x1f0226ff, 0x20022aff, 0x21022dff, + 0x220230ff, 0x230233ff, 0x240336ff, 0x250339ff, 0x25033cff, 0x26033fff, 0x260342ff, + 0x260344ff, 0x270347ff, 0x27044aff, 0x27044dff, 0x270450ff, 0x270453ff, 0x270456ff, + 0x270459ff, 0x27045dff, 0x270560ff, 0x270563ff, 0x260566ff, 0x26056aff, 0x25056dff, + 0x250570ff, 0x240674ff, 0x230677ff, 0x22067bff, 0x21067eff, 0x200681ff, 0x200684ff, + 0x1f0688ff, 0x1e078bff, 0x1d078eff, 0x1c0791ff, 0x1b0794ff, 0x1a0797ff, 0x19079aff, + 0x19079dff, 0x1808a0ff, 0x1808a3ff, 0x1408a6ff, 0x0f08aaff, 0x0809aeff, 0x080cafff, + 0x080fafff, 0x0813afff, 0x0816afff, 0x0819afff, 0x081caeff, 0x0820adff, 0x0823acff, + 0x0826aaff, 0x0829a8ff, 0x082ba6ff, 0x082ea5ff, 0x0831a3ff, 0x0833a0ff, 0x08359eff, + 0x08389cff, 0x073a9aff, 0x073c98ff, 0x073e95ff, 0x074093ff, 0x074291ff, 0x07448fff, + 0x07468dff, 0x07478bff, 0x074989ff, 0x074b87ff, 0x064c85ff, 0x064e84ff, 0x065082ff, + 0x065180ff, 0x06537fff, 0x06547dff, 0x06567bff, 0x06577aff, 0x065878ff, 0x065a77ff, + 0x065b76ff, 0x065d74ff, 0x065e73ff, 0x055f72ff, 0x066071ff, 0x056270ff, 0x05636eff, + 0x05646dff, 0x05666cff, 0x05676bff, 0x05686aff, 0x056969ff, 0x056b68ff, 0x056c67ff, + 0x056d66ff, 0x056e65ff, 0x057064ff, 0x057163ff, 0x057262ff, 0x067360ff, 0x06755fff, + 0x06765eff, 0x06775cff, 0x06785bff, 0x067a59ff, 0x067b58ff, 0x067c56ff, 0x067d54ff, + 0x067f53ff, 0x068051ff, 0x06814fff, 0x06824dff, 0x06844bff, 0x06854aff, 0x078648ff, + 0x068746ff, 0x078943ff, 0x078a41ff, 0x078b3fff, 0x078c3dff, 0x078e3bff, 0x078f38ff, + 0x079036ff, 0x079134ff, 0x079331ff, 0x07942fff, 0x07952cff, 0x07962aff, 0x079727ff, + 0x079925ff, 0x079a22ff, 0x089b1fff, 0x089c1dff, 0x089d1aff, 0x089f17ff, 0x08a014ff, + 0x08a112ff, 0x08a20fff, 0x08a30cff, 0x08a50aff, 0x08a608ff, 0x0ca708ff, 0x0fa808ff, + 0x11a908ff, 0x12aa08ff, 0x14ab08ff, 0x16ad08ff, 0x18ae08ff, 0x1aaf08ff, 0x1db008ff, 
+ 0x20b109ff, 0x23b209ff, 0x26b309ff, 0x29b409ff, 0x2db509ff, 0x30b609ff, 0x34b709ff, + 0x38b809ff, 0x3bb909ff, 0x3fba09ff, 0x43bb09ff, 0x47bc09ff, 0x4bbd09ff, 0x4fbe09ff, + 0x53be09ff, 0x57bf09ff, 0x5bc009ff, 0x5fc109ff, 0x63c109ff, 0x67c209ff, 0x6bc309ff, + 0x6fc409ff, 0x74c409ff, 0x78c509ff, 0x7cc60aff, 0x80c60aff, 0x85c70aff, 0x89c70aff, + 0x8dc80aff, 0x91c80aff, 0x96c90aff, 0x9ac90aff, 0x9eca0aff, 0xa3ca0aff, 0xa7ca0aff, + 0xabcb0aff, 0xafcb0aff, 0xb4cb0aff, 0xb8cc0aff, 0xbccc0aff, 0xc1cc0aff, 0xc5cd0aff, + 0xc9cd0aff, 0xcdcd0aff, 0xd1cd0aff, 0xd6cd0aff, 0xdacd0bff, 0xdfcd0bff, 0xe4cd0bff, + 0xe9cd0bff, 0xedcd0bff, 0xf3cd0cff, 0xf6cc39ff, 0xf7cd56ff, 0xf8cd69ff, 0xf9ce77ff, + 0xf9cf83ff, 0xfacf8dff, 0xfad095ff, 0xfad19dff, 0xfbd2a3ff, 0xfbd3a9ff, 0xfbd4aeff, + 0xfbd6b3ff, 0xfcd7b8ff, 0xfcd8bcff, 0xfcd9c0ff, 0xfcdac3ff, 0xfcdcc7ff, 0xfcddcaff, + 0xfddecdff, 0xfde0d0ff, 0xfde1d3ff, 0xfde2d5ff, 0xfde3d8ff, 0xfde5daff, 0xfde6ddff, + 0xfde8dfff, 0xfee9e1ff, 0xfeeae3ff, 0xfeece5ff, 0xfeede7ff, 0xfeeee9ff, 0xfef0ebff, + 0xfef1edff, 0xfef2efff, 0xfef4f1ff, 0xfef5f3ff, 0xfef7f5ff, 0xfff8f6ff, 0xfff9f8ff, + 0xfffbfaff, 0xfffcfcff, 0xfffefdff, 0xffffffff, + ], + Self::SmoothCoolWarm => [ + 0x3b4cc0ff, 0x3c4ec2ff, 0x3d50c3ff, 0x3e51c5ff, 0x3f53c7ff, 0x4055c8ff, 0x4257caff, + 0x4358cbff, 0x445accff, 0x455cceff, 0x465ecfff, 0x485fd1ff, 0x4961d2ff, 0x4a63d4ff, + 0x4b64d5ff, 0x4c66d6ff, 0x4e68d8ff, 0x4f6ad9ff, 0x506bdaff, 0x516ddbff, 0x536fddff, + 0x5470deff, 0x5572dfff, 0x5674e0ff, 0x5875e2ff, 0x5977e3ff, 0x5a78e4ff, 0x5b7ae5ff, + 0x5d7ce6ff, 0x5e7de7ff, 0x5f7fe8ff, 0x6181e9ff, 0x6282eaff, 0x6384ebff, 0x6585ecff, + 0x6687edff, 0x6788eeff, 0x698aefff, 0x6a8cf0ff, 0x6b8df0ff, 0x6d8ff1ff, 0x6e90f2ff, + 0x6f92f3ff, 0x7193f4ff, 0x7295f4ff, 0x7396f5ff, 0x7598f6ff, 0x7699f6ff, 0x779af7ff, + 0x799cf8ff, 0x7a9df8ff, 0x7b9ff9ff, 0x7da0f9ff, 0x7ea2faff, 0x80a3faff, 0x81a4fbff, + 0x82a6fbff, 0x84a7fcff, 0x85a8fcff, 0x86aafcff, 0x88abfdff, 0x89acfdff, 0x8baefdff, + 0x8caffeff, 
0x8db0feff, 0x8fb1feff, 0x90b2feff, 0x92b4feff, 0x93b5ffff, 0x94b6ffff, + 0x96b7ffff, 0x97b8ffff, 0x99baffff, 0x9abbffff, 0x9bbcffff, 0x9dbdffff, 0x9ebeffff, + 0x9fbfffff, 0xa1c0ffff, 0xa2c1ffff, 0xa3c2feff, 0xa5c3feff, 0xa6c4feff, 0xa8c5feff, + 0xa9c6feff, 0xaac7fdff, 0xacc8fdff, 0xadc9fdff, 0xaec9fcff, 0xb0cafcff, 0xb1cbfcff, + 0xb2ccfbff, 0xb4cdfbff, 0xb5cefaff, 0xb6cefaff, 0xb7cff9ff, 0xb9d0f9ff, 0xbad1f8ff, + 0xbbd1f8ff, 0xbdd2f7ff, 0xbed3f6ff, 0xbfd3f6ff, 0xc0d4f5ff, 0xc1d4f4ff, 0xc3d5f4ff, + 0xc4d6f3ff, 0xc5d6f2ff, 0xc6d7f1ff, 0xc8d7f1ff, 0xc9d8f0ff, 0xcad8efff, 0xcbd8eeff, + 0xccd9edff, 0xcdd9ecff, 0xcedaebff, 0xd0daeaff, 0xd1dae9ff, 0xd2dbe8ff, 0xd3dbe7ff, + 0xd4dbe6ff, 0xd5dbe5ff, 0xd6dce4ff, 0xd7dce3ff, 0xd8dce2ff, 0xd9dce1ff, 0xdadce0ff, + 0xdbdddeff, 0xdcddddff, 0xdddcdcff, 0xdedcdbff, 0xdfdcd9ff, 0xe1dbd8ff, 0xe2dad6ff, + 0xe3dad5ff, 0xe4d9d3ff, 0xe5d9d2ff, 0xe5d8d1ff, 0xe6d8cfff, 0xe7d7ceff, 0xe8d6ccff, + 0xe9d6cbff, 0xead5c9ff, 0xebd4c8ff, 0xebd3c6ff, 0xecd3c5ff, 0xedd2c3ff, 0xeed1c2ff, + 0xeed0c0ff, 0xefcfbfff, 0xefcebdff, 0xf0cebbff, 0xf1cdbaff, 0xf1ccb8ff, 0xf2cbb7ff, + 0xf2cab5ff, 0xf3c9b4ff, 0xf3c8b2ff, 0xf4c7b1ff, 0xf4c6afff, 0xf4c5adff, 0xf5c4acff, + 0xf5c3aaff, 0xf5c1a9ff, 0xf6c0a7ff, 0xf6bfa6ff, 0xf6bea4ff, 0xf6bda2ff, 0xf7bca1ff, + 0xf7ba9fff, 0xf7b99eff, 0xf7b89cff, 0xf7b79bff, 0xf7b599ff, 0xf7b497ff, 0xf7b396ff, + 0xf7b194ff, 0xf7b093ff, 0xf7af91ff, 0xf7ad90ff, 0xf7ac8eff, 0xf7ab8cff, 0xf7a98bff, + 0xf7a889ff, 0xf7a688ff, 0xf6a586ff, 0xf6a385ff, 0xf6a283ff, 0xf6a081ff, 0xf59f80ff, + 0xf59d7eff, 0xf59c7dff, 0xf49a7bff, 0xf4997aff, 0xf49778ff, 0xf39577ff, 0xf39475ff, + 0xf29274ff, 0xf29072ff, 0xf18f71ff, 0xf18d6fff, 0xf08b6eff, 0xf08a6cff, 0xef886bff, + 0xee8669ff, 0xee8568ff, 0xed8366ff, 0xed8165ff, 0xec7f63ff, 0xeb7d62ff, 0xea7c60ff, + 0xea7a5fff, 0xe9785dff, 0xe8765cff, 0xe7745bff, 0xe67259ff, 0xe57058ff, 0xe56f56ff, + 0xe46d55ff, 0xe36b54ff, 0xe26952ff, 0xe16751ff, 0xe0654fff, 0xdf634eff, 0xde614dff, + 0xdd5f4bff, 0xdc5d4aff, 
0xdb5b49ff, 0xda5947ff, 0xd85646ff, 0xd75445ff, 0xd65244ff, + 0xd55042ff, 0xd44e41ff, 0xd34c40ff, 0xd1493eff, 0xd0473dff, 0xcf453cff, 0xce433bff, + 0xcc4039ff, 0xcb3e38ff, 0xca3b37ff, 0xc83936ff, 0xc73635ff, 0xc63434ff, 0xc43132ff, + 0xc32e31ff, 0xc12b30ff, 0xc0282fff, 0xbf252eff, 0xbd222dff, 0xbc1e2cff, 0xba1a2bff, + 0xb91629ff, 0xb71128ff, 0xb60b27ff, 0xb40426ff, + ], + }; + + Color::create_palette(&xs) + .try_into() + .expect("Vector length is not 256") + } +} diff --git a/src/core/dataloader.rs b/src/misc/dataloader.rs similarity index 78% rename from src/core/dataloader.rs rename to src/misc/dataloader.rs index 5c35e54..724e5a7 100644 --- a/src/core/dataloader.rs +++ b/src/misc/dataloader.rs @@ -1,18 +1,18 @@ use anyhow::{anyhow, Result}; use image::DynamicImage; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::ProgressBar; +use log::{info, warn}; use std::collections::VecDeque; use std::path::{Path, PathBuf}; use std::sync::mpsc; +#[cfg(feature = "ffmpeg")] use video_rs::{ encode::{Encoder, Settings}, time::Time, Decoder, Url, }; -use crate::{ - build_progress_bar, string_now, Dir, Hub, Location, MediaType, CHECK_MARK, CROSS_MARK, -}; +use crate::{build_progress_bar, Hub, Location, MediaType}; type TempReturnType = (Vec, Vec); @@ -28,21 +28,25 @@ impl Iterator for DataLoaderIterator { fn next(&mut self) -> Option { match &self.progress_bar { None => self.receiver.recv().ok(), - Some(progress_bar) => match self.receiver.recv().ok() { - Some(item) => { - progress_bar.inc(self.batch_size); - Some(item) - } - None => { - progress_bar.set_prefix(" Iterated"); - progress_bar.set_style( - indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_FINISH_2) - .unwrap(), + Some(progress_bar) => { + match self.receiver.recv().ok() { + Some(item) => { + progress_bar.inc(self.batch_size); + Some(item) + } + None => { + progress_bar.set_prefix("Iterated"); + progress_bar.set_style( + match 
indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_FINISH_2) { + Ok(x) => x, + Err(err) => panic!("Failed to set style for progressbar in `DataLoaderIterator`: {}", err), + }, ); - progress_bar.finish(); - None + progress_bar.finish(); + None + } } - }, + } } } } @@ -55,7 +59,7 @@ impl IntoIterator for DataLoader { let progress_bar = if self.with_pb { build_progress_bar( self.nf, - " Iterating", + "Iterating", Some(&format!("{:?}", self.media_type)), crate::PROGRESS_BAR_STYLE_CYAN_2, ) @@ -93,6 +97,7 @@ pub struct DataLoader { receiver: mpsc::Receiver, /// Video decoder for handling video or stream data. + #[cfg(feature = "ffmpeg")] decoder: Option, /// Number of images or frames; `u64::MAX` is used for live streams (indicating no limit). @@ -102,10 +107,18 @@ pub struct DataLoader { with_pb: bool, } +impl TryFrom<&str> for DataLoader { + type Error = anyhow::Error; + + fn try_from(str: &str) -> Result { + Self::new(str) + } +} + impl DataLoader { pub fn new(source: &str) -> Result { - let span = tracing::span!(tracing::Level::INFO, "DataLoader-new"); - let _guard = span.enter(); + // TODO: multi-types + // Vec<&str> // Number of frames or stream let mut nf = 0; @@ -147,6 +160,21 @@ impl DataLoader { } // video decoder + #[cfg(not(feature = "ffmpeg"))] + { + match &media_type { + MediaType::Video(Location::Local) + | MediaType::Video(Location::Remote) + | MediaType::Stream => { + anyhow::bail!( + "Video processing requires the features: `ffmpeg`. 
\ + \nConsider enabling them by passing, e.g., `--features ffmpeg`" + ); + } + _ => {} + }; + } + #[cfg(feature = "ffmpeg")] let decoder = match &media_type { MediaType::Video(Location::Local) => Some(Decoder::new(source_path)?), MediaType::Video(Location::Remote) | MediaType::Stream => { @@ -157,6 +185,7 @@ impl DataLoader { }; // video & stream frames + #[cfg(feature = "ffmpeg")] if let Some(decoder) = &decoder { nf = match decoder.frames() { Err(_) => u64::MAX, @@ -166,7 +195,7 @@ impl DataLoader { } // summary - tracing::info!("{} Found {:?} x{}", CHECK_MARK, media_type, nf); + info!("Found {:?} x{}", media_type, nf); Ok(DataLoader { paths, @@ -174,6 +203,7 @@ impl DataLoader { bound: 50, receiver: mpsc::sync_channel(1).1, batch_size: 1, + #[cfg(feature = "ffmpeg")] decoder, nf, with_pb: true, @@ -190,6 +220,11 @@ impl DataLoader { self } + pub fn with_batch_size(mut self, x: usize) -> Self { + self.batch_size = x; + self + } + pub fn with_progress_bar(mut self, x: bool) -> Self { self.with_pb = x; self @@ -201,11 +236,19 @@ impl DataLoader { let batch_size = self.batch_size; let data = self.paths.take().unwrap_or_default(); let media_type = self.media_type.clone(); + #[cfg(feature = "ffmpeg")] let decoder = self.decoder.take(); // Spawn the producer thread std::thread::spawn(move || { - DataLoader::producer_thread(sender, data, batch_size, media_type, decoder); + DataLoader::producer_thread( + sender, + data, + batch_size, + media_type, + #[cfg(feature = "ffmpeg")] + decoder, + ); }); Ok(self) @@ -216,10 +259,8 @@ impl DataLoader { mut data: VecDeque, batch_size: usize, media_type: MediaType, - mut decoder: Option, + #[cfg(feature = "ffmpeg")] mut decoder: Option, ) { - let span = tracing::span!(tracing::Level::INFO, "DataLoader-producer-thread"); - let _guard = span.enter(); let mut yis: Vec = Vec::with_capacity(batch_size); let mut yps: Vec = Vec::with_capacity(batch_size); @@ -228,7 +269,7 @@ impl DataLoader { while let Some(path) = data.pop_front() { match 
Self::try_read(&path) { Err(err) => { - tracing::warn!("{} {:?} | {:?}", CROSS_MARK, path, err); + warn!("{:?} | {:?}", path, err); continue; } Ok(img) => { @@ -245,6 +286,7 @@ impl DataLoader { } } } + #[cfg(feature = "ffmpeg")] MediaType::Video(_) | MediaType::Stream => { if let Some(decoder) = decoder.as_mut() { let (w, h) = decoder.size(); @@ -279,12 +321,12 @@ impl DataLoader { } } } - _ => todo!(), + _ => unimplemented!(), } // Deal with remaining data if !yis.is_empty() && sender.send((yis, yps)).is_err() { - tracing::info!("Receiver dropped, stopping production"); + info!("Receiver dropped, stopping production"); } } @@ -319,13 +361,30 @@ impl DataLoader { // try to fetch from hub or local cache if !path.exists() { - let p = Hub::new()?.fetch(path.to_str().unwrap())?.commit()?; + let p = Hub::default().try_fetch(path.to_str().unwrap())?; path = PathBuf::from(&p); } let img = Self::read_into_rgb8(path)?; Ok(DynamicImage::from(img)) } + pub fn try_read_batch + std::fmt::Debug>( + paths: &[P], + ) -> Result> { + let images = paths + .iter() + .filter_map(|path| match Self::try_read(path) { + Ok(img) => Some(img), + Err(err) => { + warn!("Failed to read from: {:?}. Error: {:?}", path, err); + None + } + }) + .collect(); + + Ok(images) + } + fn read_into_rgb8>(path: P) -> Result { let path = path.as_ref(); let img = image::ImageReader::open(path) @@ -357,6 +416,7 @@ impl DataLoader { } /// Convert images into a video + #[cfg(feature = "ffmpeg")] pub fn is2v>(source: P, subs: &[&str], fps: usize) -> Result<()> { let paths = Self::load_from_folder(source.as_ref())?; if paths.is_empty() { @@ -364,12 +424,12 @@ impl DataLoader { } let mut encoder = None; let mut position = Time::zero(); - let saveout = Dir::Currnet + let saveout = crate::Dir::Current .raw_path_with_subs(subs)? 
- .join(format!("{}.mp4", string_now("-"))); + .join(format!("{}.mp4", crate::string_now("-"))); let pb = build_progress_bar( paths.len() as u64, - " Converting", + "Converting", Some(&format!("{:?}", MediaType::Video(Location::Local))), crate::PROGRESS_BAR_STYLE_CYAN_2, )?; @@ -404,9 +464,9 @@ impl DataLoader { } // update - pb.set_prefix(" Converted"); + pb.set_prefix("Converted"); pb.set_message(saveout.to_str().unwrap_or_default().to_string()); - pb.set_style(ProgressStyle::with_template( + pb.set_style(indicatif::ProgressStyle::with_template( crate::PROGRESS_BAR_STYLE_FINISH_4, )?); pb.finish(); diff --git a/src/misc/device.rs b/src/misc/device.rs new file mode 100644 index 0000000..e1029e1 --- /dev/null +++ b/src/misc/device.rs @@ -0,0 +1,63 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum Device { + Auto(usize), + Cpu(usize), + Cuda(usize), + TensorRT(usize), + CoreML(usize), + // Cann(usize), + // Acl(usize), + // Rocm(usize), + // Rknpu(usize), + // Openvino(usize), + // Onednn(usize), +} + +impl std::fmt::Display for Device { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::Auto(i) => format!("auto:{}", i), + Self::Cpu(i) => format!("cpu:{}", i), + Self::Cuda(i) => format!("cuda:{}", i), + Self::TensorRT(i) => format!("tensorrt:{}", i), + Self::CoreML(i) => format!("mps:{}", i), + }; + write!(f, "{}", x) + } +} + +impl TryFrom<&str> for Device { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + // device and its id + let d_id: Vec<&str> = s.trim().split(':').collect(); + let (d, id) = match d_id.len() { + 1 => (d_id[0], 0), + 2 => (d_id[0], d_id[1].parse::().unwrap_or(0)), + _ => anyhow::bail!( + "Fail to parse device string: {s}. Expect: `device:device_id` or `device`. e.g. 
`cuda:0` or `cuda`" + ), + }; + // TODO: device-id checking + match d.to_lowercase().as_str() { + "cpu" => Ok(Self::Cpu(id)), + "cuda" => Ok(Self::Cuda(id)), + "trt" | "tensorrt" => Ok(Self::TensorRT(id)), + "coreml" | "mps" => Ok(Self::CoreML(id)), + _ => anyhow::bail!("Unsupported device str: {s:?}."), + } + } +} + +impl Device { + pub fn id(&self) -> usize { + match self { + Device::Auto(i) => *i, + Device::Cpu(i) => *i, + Device::Cuda(i) => *i, + Device::TensorRT(i) => *i, + Device::CoreML(i) => *i, + } + } +} diff --git a/src/core/dir.rs b/src/misc/dir.rs similarity index 97% rename from src/core/dir.rs rename to src/misc/dir.rs index 7c36c56..9b260ff 100644 --- a/src/core/dir.rs +++ b/src/misc/dir.rs @@ -4,7 +4,7 @@ pub enum Dir { Home, Cache, Config, - Currnet, + Current, Document, Data, Download, @@ -15,7 +15,7 @@ pub enum Dir { impl Dir { pub fn saveout(subs: &[&str]) -> anyhow::Result { - Self::Currnet.raw_path_with_subs(subs) + Self::Current.raw_path_with_subs(subs) } /// Retrieves the base path for the specified directory type, optionally appending the `usls` subdirectory. 
@@ -30,7 +30,7 @@ impl Dir { Dir::Home => dirs::home_dir(), Dir::Cache => dirs::cache_dir(), Dir::Config => dirs::config_dir(), - Dir::Currnet => std::env::current_dir().ok(), + Dir::Current => std::env::current_dir().ok(), _ => None, }; diff --git a/src/misc/dtype.rs b/src/misc/dtype.rs new file mode 100644 index 0000000..81f0d50 --- /dev/null +++ b/src/misc/dtype.rs @@ -0,0 +1,114 @@ +use ort::tensor::TensorElementType; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum DType { + Auto, + Int8, + Int16, + Int32, + Int64, + Uint8, + Uint16, + Uint32, + Uint64, + Fp16, + Fp32, + Fp64, + Bf16, + Bool, + String, + Bnb4, + Q4, + Q4f16, +} + +impl TryFrom<&str> for DType { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "auto" | "dyn" => Ok(Self::Auto), + "u8" | "uint8" => Ok(Self::Uint8), + "u16" | "uint16" => Ok(Self::Uint16), + "u32" | "uint32" => Ok(Self::Uint32), + "u64" | "uint64" => Ok(Self::Uint64), + "i8" | "int8" => Ok(Self::Int8), + "i16" | "int16" => Ok(Self::Int16), + "i32" | "int32" => Ok(Self::Int32), + "i64" | "int64" => Ok(Self::Int64), + "f16" | "fp16" => Ok(Self::Fp16), + "f32" | "fp32" => Ok(Self::Fp32), + "f64" | "fp64" => Ok(Self::Fp64), + "b16" | "bf16" => Ok(Self::Bf16), + "q4f16" => Ok(Self::Q4f16), + "q4" => Ok(Self::Q4), + "bnb4" => Ok(Self::Bnb4), + x => anyhow::bail!("Unsupported Model DType: {}", x), + } + } +} + +impl std::fmt::Display for DType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::Auto => "auto", + Self::Int8 => "int8", + Self::Int16 => "int16", + Self::Int32 => "int32", + Self::Int64 => "int64", + Self::Uint8 => "uint8", + Self::Uint16 => "uint16", + Self::Uint32 => "uint32", + Self::Uint64 => "uint64", + Self::Fp16 => "fp16", + Self::Fp32 => "fp32", + Self::Fp64 => "fp64", + Self::Bf16 => "bf16", + Self::String => "string", + Self::Bool => "bool", + Self::Bnb4 => "bnb4", + Self::Q4 => "q4", +
Self::Q4f16 => "q4f16", + }; + write!(f, "{}", x) + } +} + +impl DType { + pub fn to_ort(&self) -> TensorElementType { + match self { + Self::Int8 => TensorElementType::Int8, + Self::Int16 => TensorElementType::Int16, + Self::Int32 => TensorElementType::Int32, + Self::Int64 => TensorElementType::Int64, + Self::Uint8 => TensorElementType::Uint8, + Self::Uint16 => TensorElementType::Uint16, + Self::Uint32 => TensorElementType::Uint32, + Self::Uint64 => TensorElementType::Uint64, + Self::Fp16 => TensorElementType::Float16, + Self::Fp32 => TensorElementType::Float32, + Self::Fp64 => TensorElementType::Float64, + Self::Bf16 => TensorElementType::Bfloat16, + _ => todo!(), + } + } + + pub fn from_ort(dtype: &TensorElementType) -> Self { + match dtype { + TensorElementType::Int8 => Self::Int8, + TensorElementType::Int16 => Self::Int16, + TensorElementType::Int32 => Self::Int32, + TensorElementType::Int64 => Self::Int64, + TensorElementType::Uint8 => Self::Uint8, + TensorElementType::Uint16 => Self::Uint16, + TensorElementType::Uint32 => Self::Uint32, + TensorElementType::Uint64 => Self::Uint64, + TensorElementType::Float16 => Self::Fp16, + TensorElementType::Float32 => Self::Fp32, + TensorElementType::Float64 => Self::Fp64, + TensorElementType::Bfloat16 => Self::Bf16, + TensorElementType::String => Self::String, + TensorElementType::Bool => Self::Bool, + } + } +} diff --git a/src/core/dynconf.rs b/src/misc/dynconf.rs similarity index 78% rename from src/core/dynconf.rs rename to src/misc/dynconf.rs index 3c14c68..fa715d2 100644 --- a/src/core/dynconf.rs +++ b/src/misc/dynconf.rs @@ -2,27 +2,23 @@ use std::ops::Index; /// Dynamic Confidences #[derive(Clone, PartialEq, PartialOrd)] -pub struct DynConf { - confs: Vec, -} +pub struct DynConf(Vec); impl Default for DynConf { fn default() -> Self { - Self { - confs: vec![0.4f32], - } + Self(vec![0.4f32]) } } impl std::fmt::Debug for DynConf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - 
f.debug_struct("").field("DynConf", &self.confs).finish() + f.debug_struct("").field("DynConf", &self.0).finish() } } impl std::fmt::Display for DynConf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_list().entries(self.confs.iter()).finish() + f.debug_list().entries(self.0.iter()).finish() } } @@ -30,7 +26,7 @@ impl Index for DynConf { type Output = f32; fn index(&self, i: usize) -> &Self::Output { - &self.confs[i] + &self.0[i] } } @@ -50,6 +46,6 @@ impl DynConf { confs }; - Self { confs } + Self(confs) } } diff --git a/src/misc/engine.rs b/src/misc/engine.rs new file mode 100644 index 0000000..a99ed27 --- /dev/null +++ b/src/misc/engine.rs @@ -0,0 +1,748 @@ +use aksr::Builder; +use anyhow::Result; +use half::{bf16, f16}; +use log::{error, info, warn}; +use ndarray::{Array, IxDyn}; +#[allow(unused_imports)] +use ort::{ + execution_providers::ExecutionProvider, + session::{ + builder::{GraphOptimizationLevel, SessionBuilder}, + Session, SessionInputValue, + }, + tensor::TensorElementType, + value::{DynValue, Value}, +}; +use prost::Message; +use std::collections::HashSet; + +use crate::{ + build_progress_bar, elapsed, human_bytes, onnx, DType, Device, Iiix, MinOptMax, Ops, Ts, Xs, X, +}; + +/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions. 
+#[derive(Builder, Debug, Clone)] +pub struct OrtTensorAttr { + pub names: Vec, + pub dtypes: Vec, + pub dimss: Vec>, +} + +#[derive(Debug)] +pub struct OnnxIo { + pub inputs: OrtTensorAttr, + pub outputs: OrtTensorAttr, + pub session: Session, + pub proto: onnx::ModelProto, +} + +#[derive(Debug, Builder)] +pub struct Engine { + pub file: String, + pub spec: String, + pub device: Device, + pub trt_fp16: bool, + #[args(inc = true)] + pub iiixs: Vec, + #[args(alias = "parameters")] + pub params: Option, + #[args(alias = "memory")] + pub wbmems: Option, + pub inputs_minoptmax: Vec>, + pub onnx: Option, + pub ts: Ts, + pub num_dry_run: usize, +} + +impl Default for Engine { + fn default() -> Self { + Self { + file: Default::default(), + device: Device::Cpu(0), + trt_fp16: false, + spec: Default::default(), + iiixs: Default::default(), + num_dry_run: 3, + params: None, + wbmems: None, + inputs_minoptmax: vec![], + onnx: None, + ts: Ts::default(), + } + } +} + +impl Engine { + pub fn build(mut self) -> Result { + let name = format!("[{}] ort_initialization", self.spec); + elapsed!(&name, self.ts, { + let proto = Self::load_onnx(self.file())?; + let graph = match &proto.graph { + Some(graph) => graph, + None => { + anyhow::bail!( + "No graph found in this proto. 
Invalid ONNX model: {}", + self.file() + ) + } + }; + + // params & mems + let byte_alignment = 16; + let mut params: usize = 0; + let mut wbmems: usize = 0; + let mut initializer_names: HashSet<&str> = HashSet::new(); + if !graph.initializer.is_empty() { + // from initializer + for tensor_proto in graph.initializer.iter() { + initializer_names.insert(&tensor_proto.name); + let param = tensor_proto.dims.iter().product::() as usize; + params += param; + let param = Ops::make_divisible(param, byte_alignment); + let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize); + let wbmem = param * n; + wbmems += wbmem; + } + } else { + // from node, workaround + for node in &graph.node { + for attr in &node.attribute { + if let Some(tensor) = &attr.t { + let param = tensor.dims.iter().product::() as usize; + params += param; + let param = Ops::make_divisible(param, byte_alignment); + let n = Self::nbytes_from_onnx_dtype_id(tensor.data_type as usize); + let wbmem = param * n; + wbmems += wbmem; + } + } + } + } + self.params = Some(params); + self.wbmems = Some(wbmems); + + // inputs & outputs + let inputs = Self::io_from_onnx_value_info(&initializer_names, &graph.input)?; + let outputs = Self::io_from_onnx_value_info(&initializer_names, &graph.output)?; + self.inputs_minoptmax = Self::build_ort_inputs(&inputs, self.iiixs())?; + + // session + ort::init().commit()?; + let session = self.build_session(&inputs)?; + + // onnxio + self.onnx = Some(OnnxIo { + inputs, + outputs, + proto, + session, + }); + }); + self.dry_run()?; + self.info(); + + Ok(self) + } + + pub fn dry_run(&mut self) -> Result<()> { + if self.num_dry_run > 0 { + // pb + let pb = build_progress_bar( + self.num_dry_run as u64, + "DryRun", + Some(self.spec()), + crate::PROGRESS_BAR_STYLE_CYAN_2, + )?; + + // dummy + let mut xs = Vec::new(); + for i in self.inputs_minoptmax().iter() { + let mut x: Vec = Vec::new(); + for i_ in i.iter() { + x.push(i_.opt()); + } + let x: Array = 
Array::ones(x).into_dyn(); + xs.push(X::from(x)); + } + let xs = Xs::from(xs); + + // run + for i in 0..self.num_dry_run { + pb.inc(1); + let name = format!("[{}] ort_dry_run_{}", self.spec, i); + elapsed!(&name, self.ts, { + self.run(xs.clone())?; + }); + } + + // update + pb.set_message(format!( + "{}({}) on {:?}", + self.spec, + match self.params { + Some(bytes) if bytes != 0 => { + human_bytes(bytes as f64, true) + } + _ => "Unknown".to_string(), + }, + self.device, + )); + pb.set_style(indicatif::ProgressStyle::with_template( + crate::PROGRESS_BAR_STYLE_FINISH, + )?); + pb.finish(); + } + Ok(()) + } + + pub fn run(&mut self, xs: Xs) -> Result { + let mut ys = xs.derive(); + if let Some(onnx) = &self.onnx { + // alignment + let xs_ = elapsed!(&format!("[{}] ort_preprocessing", self.spec), self.ts, { + let mut xs_ = Vec::new(); + for (dtype, x) in onnx.inputs.dtypes.iter().zip(xs.into_iter()) { + xs_.push(Into::>::into(Self::preprocess( + x, dtype, + )?)); + } + xs_ + }); + + // run + let outputs = elapsed!( + &format!("[{}] ort_inference", self.spec), + self.ts, + onnx.session.run(&xs_[..])? + ); + + // extract + elapsed!(&format!("[{}] ort_postprocessing", self.spec), self.ts, { + for (dtype, name) in onnx.outputs.dtypes.iter().zip(onnx.outputs.names.iter()) { + let y = Self::postprocess(&outputs[name.as_str()], dtype)?; + ys.push_kv(name.as_str(), X::from(y))?; + } + }); + Ok(ys) + } else { + anyhow::bail!("Failed to run with ONNXRuntime. 
No model info found."); + } + } + + fn preprocess(x: &X, dtype: &TensorElementType) -> Result { + let x = match dtype { + TensorElementType::Float32 => Value::from_array(x.view())?.into_dyn(), + TensorElementType::Float16 => { + Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() + } + TensorElementType::Float64 => Value::from_array(x.view())?.into_dyn(), + TensorElementType::Bfloat16 => { + Value::from_array(x.mapv(bf16::from_f32).view())?.into_dyn() + } + TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn(), + TensorElementType::Int16 => { + Value::from_array(x.mapv(|x_| x_ as i16).view())?.into_dyn() + } + TensorElementType::Int32 => { + Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() + } + TensorElementType::Int64 => { + Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() + } + TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn(), + TensorElementType::Uint16 => { + Value::from_array(x.mapv(|x_| x_ as u16).view())?.into_dyn() + } + TensorElementType::Uint32 => { + Value::from_array(x.mapv(|x_| x_ as u32).view())?.into_dyn() + } + TensorElementType::Uint64 => { + Value::from_array(x.mapv(|x_| x_ as u64).view())?.into_dyn() + } + TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn(), + _ => unimplemented!(), + }; + + Ok(x) + } + + fn postprocess(x: &DynValue, dtype: &TensorElementType) -> Result> { + fn _extract_and_convert(x: &DynValue, map_fn: impl Fn(T) -> f32) -> Array + where + T: Clone + 'static + ort::tensor::PrimitiveTensorElementType, + { + match x.try_extract_tensor::() { + Err(err) => { + error!("Failed to extract from ort outputs: {:?}", err); + Array::zeros(0).into_dyn() + } + Ok(x) => x.view().mapv(map_fn).into_owned(), + } + } + let x = match dtype { + TensorElementType::Float32 => _extract_and_convert::(x, |x| x), + TensorElementType::Float16 => _extract_and_convert::(x, f16::to_f32), + TensorElementType::Bfloat16 => 
_extract_and_convert::(x, bf16::to_f32), + TensorElementType::Float64 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int64 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int32 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int16 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int8 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint64 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint32 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint16 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint8 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Bool => _extract_and_convert::(x, |x| x as u8 as f32), + _ => return Err(anyhow::anyhow!("Unsupported ort tensor type: {:?}", dtype)), + }; + + Ok(x) + } + + #[allow(unused_variables)] + fn build_session(&mut self, inputs: &OrtTensorAttr) -> Result { + #[allow(unused_mut)] + let mut builder = Session::builder()?; + let compile_help = "Please compile ONNXRuntime with #EP"; + let feature_help = "#EP EP requires the features: `#FEATURE`. 
\ + \nConsider enabling them by passing, e.g., `--features #FEATURE`"; + + match self.device { + Device::TensorRT(id) => { + #[cfg(not(feature = "trt"))] + { + anyhow::bail!(feature_help + .replace("#EP", "TensorRT") + .replace("#FEATURE", "trt")); + } + + #[cfg(feature = "trt")] + { + // generate shapes + let mut spec_min = String::new(); + let mut spec_opt = String::new(); + let mut spec_max = String::new(); + for (i, name) in inputs.names.iter().enumerate() { + if i != 0 { + spec_min.push(','); + spec_opt.push(','); + spec_max.push(','); + } + let mut s_min = format!("{}:", name); + let mut s_opt = format!("{}:", name); + let mut s_max = format!("{}:", name); + for d in self.inputs_minoptmax[i].iter() { + let min_ = &format!("{}x", d.min()); + let opt_ = &format!("{}x", d.opt()); + let max_ = &format!("{}x", d.max()); + s_min += min_; + s_opt += opt_; + s_max += max_; + } + s_min.pop(); + s_opt.pop(); + s_max.pop(); + spec_min += &s_min; + spec_opt += &s_opt; + spec_max += &s_max; + } + + let p = crate::Dir::Cache.path_with_subs(&["trt-cache"])?; + let ep = ort::execution_providers::TensorRTExecutionProvider::default() + .with_device_id(id as i32) + .with_fp16(self.trt_fp16) + .with_engine_cache(true) + .with_engine_cache_path(p.to_str().unwrap()) + .with_timing_cache(false) + .with_profile_min_shapes(spec_min) + .with_profile_opt_shapes(spec_opt) + .with_profile_max_shapes(spec_max); + + match ep.is_available() { + Ok(true) => { + info!( + "Initial model serialization with TensorRT may require a wait..." 
+ ); + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register TensorRT: {}", err) + })?; + } + _ => { + anyhow::bail!(compile_help.replace("#EP", "TensorRT")) + } + } + } + } + Device::Cuda(id) => { + #[cfg(not(feature = "cuda"))] + { + anyhow::bail!(feature_help + .replace("#EP", "CUDA") + .replace("#FEATURE", "cuda")); + } + + #[cfg(feature = "cuda")] + { + let ep = ort::execution_providers::CUDAExecutionProvider::default() + .with_device_id(id as i32); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register CUDA: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "CUDA")), + } + } + } + Device::CoreML(id) => { + #[cfg(not(feature = "mps"))] + { + anyhow::bail!(feature_help + .replace("#EP", "CoreML") + .replace("#FEATURE", "mps")); + } + #[cfg(feature = "mps")] + { + let ep = ort::execution_providers::CoreMLExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register CoreML: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "CoreML")), + } + } + } + _ => { + let ep = ort::execution_providers::CPUExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder) + .map_err(|err| anyhow::anyhow!("Failed to register Cpu: {}", err))?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "Cpu")), + } + } + } + + // session + let session = builder + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(std::thread::available_parallelism()?.get())? 
+ .commit_from_file(self.file())?; + + Ok(session) + } + + fn build_ort_inputs(xs: &OrtTensorAttr, iiixs: &[Iiix]) -> Result>> { + // init + let mut ys: Vec> = xs + .dimss + .iter() + .map(|dims| dims.iter().map(|&x| MinOptMax::from(x)).collect()) + .collect(); + + // update from customized + for iiix in iiixs.iter() { + if let Some(x) = xs.dimss.get(iiix.i).and_then(|dims| dims.get(iiix.ii)) { + // dynamic + if *x == 0 { + ys[iiix.i][iiix.ii] = iiix.x.clone(); + } + } else { + anyhow::bail!( + "Cannot retrieve the {}-th dimension of the {}-th input.", + iiix.ii, + iiix.i, + ); + } + } + + // set batch size <- i00 + let batch_size: MinOptMax = if ys[0][0].is_dyn() { + 1.into() + } else { + ys[0][0].clone() + }; + + // deal with the dynamic axis + ys.iter_mut().enumerate().for_each(|(i, xs)| { + xs.iter_mut().enumerate().for_each(|(ii, x)| { + if x.is_dyn() { + let z = if ii == 0 { + batch_size.clone() + } else { + let z = MinOptMax::from(1); + warn!( + "Using dynamic shapes in inputs without specifying it: the {}-th input, the {}-th dimension. \ + Using {:?} by default. 
You should make it clear when using TensorRT.", + i + 1, ii + 1, z + ); + z + }; + *x = z; + } + }); + }); + + Ok(ys) + } + + #[allow(dead_code)] + fn nbytes_from_onnx_dtype_id(x: usize) -> usize { + match x { + 7 | 11 | 13 => 8, // i64, f64, u64 + 1 | 6 | 12 => 4, // f32, i32, u32 + 10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16 + 2 | 3 | 9 => 1, // u8, i8, bool + 8 => 4, // string(1~4) + _ => 1, // TODO: others + } + } + + #[allow(dead_code)] + fn nbytes_from_onnx_dtype(x: &TensorElementType) -> usize { + match x { + TensorElementType::Float64 | TensorElementType::Uint64 | TensorElementType::Int64 => 8, // i64, f64, u64 + TensorElementType::Float32 + | TensorElementType::Uint32 + | TensorElementType::Int32 + | TensorElementType::String => 4, // f32, i32, u32, string(1~4) + TensorElementType::Float16 + | TensorElementType::Bfloat16 + | TensorElementType::Int16 + | TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 + TensorElementType::Uint8 | TensorElementType::Int8 | TensorElementType::Bool => 1, // u8, i8, bool + } + } + + #[allow(dead_code)] + fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option { + match value { + 0 => None, + 1 => Some(TensorElementType::Float32), + 2 => Some(TensorElementType::Uint8), + 3 => Some(TensorElementType::Int8), + 4 => Some(TensorElementType::Uint16), + 5 => Some(TensorElementType::Int16), + 6 => Some(TensorElementType::Int32), + 7 => Some(TensorElementType::Int64), + 8 => Some(TensorElementType::String), + 9 => Some(TensorElementType::Bool), + 10 => Some(TensorElementType::Float16), + 11 => Some(TensorElementType::Float64), + 12 => Some(TensorElementType::Uint32), + 13 => Some(TensorElementType::Uint64), + 14 => None, // COMPLEX64 + 15 => None, // COMPLEX128 + 16 => Some(TensorElementType::Bfloat16), + _ => None, + } + } + + fn io_from_onnx_value_info( + initializer_names: &HashSet<&str>, + value_info: &[onnx::ValueInfoProto], + ) -> Result { + let mut dimss: Vec> = Vec::new(); + let mut dtypes: Vec = Vec::new(); + let mut 
names: Vec = Vec::new(); + for v in value_info.iter() { + if initializer_names.contains(v.name.as_str()) { + continue; + } + names.push(v.name.to_string()); + let dtype = match &v.r#type { + Some(dtype) => dtype, + None => continue, + }; + let dtype = match &dtype.value { + Some(dtype) => dtype, + None => continue, + }; + let tensor = match dtype { + onnx::type_proto::Value::TensorType(tensor) => tensor, + _ => continue, + }; + let tensor_type = tensor.elem_type; + let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) { + Some(dtype) => dtype, + None => continue, + }; + dtypes.push(tensor_type); + + let shapes = match &tensor.shape { + Some(shapes) => shapes, + None => continue, + }; + let mut shape_: Vec = Vec::new(); + for shape in shapes.dim.iter() { + match &shape.value { + None => continue, + Some(value) => match value { + onnx::tensor_shape_proto::dimension::Value::DimValue(x) => { + shape_.push(*x as _); + } + onnx::tensor_shape_proto::dimension::Value::DimParam(_) => { + shape_.push(0); + } + }, + } + } + dimss.push(shape_); + } + Ok(OrtTensorAttr { + dimss, + dtypes, + names, + }) + } + + pub fn load_onnx>(p: P) -> Result { + let f = std::fs::read(p.as_ref())?; + onnx::ModelProto::decode(f.as_slice()).map_err(|err| { + anyhow::anyhow!( + "Failed to read the ONNX model: The file might be incomplete or corrupted. 
More detailed: {}", + err + ) + }) + } + + pub fn batch(&self) -> &MinOptMax { + &self.inputs_minoptmax[0][0] + } + + pub fn is_batch_dyn(&self) -> bool { + self.batch().is_dyn() + } + + pub fn try_height(&self) -> Option<&MinOptMax> { + self.inputs_minoptmax.first().and_then(|x| x.get(2)) + } + + pub fn height(&self) -> &MinOptMax { + // unsafe + &self.inputs_minoptmax[0][2] + } + + pub fn is_height_dyn(&self) -> bool { + self.height().is_dyn() + } + + pub fn try_width(&self) -> Option<&MinOptMax> { + self.inputs_minoptmax.first().and_then(|x| x.get(3)) + } + + pub fn width(&self) -> &MinOptMax { + // unsafe + &self.inputs_minoptmax[0][3] + } + + pub fn is_width_dyn(&self) -> bool { + self.width().is_dyn() + } + + pub fn try_fetch(&self, key: &str) -> Option { + match self.onnx.as_ref().unwrap().session.metadata() { + Err(_) => None, + Ok(metadata) => metadata.custom(key).unwrap_or_default(), + } + } + + pub fn ir_version(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.ir_version as usize) + } + + pub fn opset_version(&self) -> Option { + self.onnx + .as_ref() + .map(|x| x.proto.opset_import[0].version as usize) + } + + pub fn producer_name(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.producer_name.clone()) + } + + pub fn producer_version(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.producer_version.clone()) + } + + pub fn model_version(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.model_version as usize) + } + + pub fn ishapes(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.inputs.dimss()) + } + + pub fn idimss(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.inputs.dimss()) + } + + pub fn inames(&self) -> Option<&[String]> { + self.onnx.as_ref().map(|x| x.inputs.names()) + } + + pub fn idtypes(&self) -> Option> { + self.onnx.as_ref().and_then(|x| { + x.inputs + .dtypes() + .iter() + .map(DType::from_ort) + .collect::>() + .into() + }) + } + + pub fn oshapes(&self) -> Option<&[Vec]> { + 
self.onnx.as_ref().map(|x| x.outputs.dimss()) + } + + pub fn odimss(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.outputs.dimss()) + } + + pub fn onames(&self) -> Option<&[String]> { + self.onnx.as_ref().map(|x| x.outputs.names()) + } + + pub fn odtypes(&self) -> Option> { + self.onnx.as_ref().and_then(|x| { + x.outputs + .dtypes() + .iter() + .map(DType::from_ort) + .collect::>() + .into() + }) + } + + pub fn profile(&self) { + self.ts.summary(); + } + + pub fn info(&self) { + let info = format!( + "Minimum Supported Ort Version: 1.{}.x, Opset Version: {}, Device: {}, Parameters: {}, Memory: {}", + ort::MINOR_VERSION, + self.opset_version().map_or("Unknown".to_string(), |x| x.to_string()), + self.device, + match self.params { + Some(bytes) if bytes != 0 => { + human_bytes(bytes as f64, true) + } + _ => "Unknown".to_string(), + }, + match self.wbmems { + Some(bytes) if bytes != 0 => { + human_bytes(bytes as f64, true) + } + _ => "Unknown".to_string(), + }, + ); + + info!("{}", info); + } +} diff --git a/src/misc/hub.rs b/src/misc/hub.rs new file mode 100644 index 0000000..acb2a43 --- /dev/null +++ b/src/misc/hub.rs @@ -0,0 +1,521 @@ +use anyhow::{Context, Result}; +use indicatif::{ProgressBar, ProgressStyle}; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use crate::{retry, Dir, PREFIX_LENGTH}; + +/// Represents a downloadable asset in a release +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Asset { + pub name: String, + pub browser_download_url: String, + pub size: u64, +} + +/// Represents a GitHub release +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Release { + pub tag_name: String, + pub assets: Vec, +} + +// / Manages interactions with a GitHub repository's releases +/// Provides an interface for managing GitHub releases, including downloading assets, +/// fetching release tags and file 
information, and handling caching. +/// +/// The `Hub` struct simplifies interactions with a GitHub repository by allowing users +/// to specify a repository owner and name, download files from releases, and manage +/// cached data to reduce redundant network requests. +/// +/// # Fields +/// - `owner`: The owner of the GitHub repository (e.g., `"jamjamjon"`). +/// - `repo`: The name of the GitHub repository (e.g., `"assets"`). +/// - `to`: The directory where downloaded files are stored, determined from a prioritized list +/// of available directories (e.g., cache, home, config, or current directory). +/// - `timeout`: Timeout duration for network requests, in seconds. +/// - `ttl`: Time-to-live duration for cached data, defining how long cache files remain valid. +/// - `max_attempts`: The maximum number of retry attempts for failed downloads or network operations. +/// +/// # Example +/// +/// ## 1. Download from a default GitHub release +/// Download a file by specifying its path relative to the release: +/// ```rust,ignore +/// let path = usls::Hub::default().try_fetch("images/bus.jpg")?; +/// println!("Fetched image to: {:?}", path); +/// ``` +/// +/// ## 2. Download from a specific GitHub release URL +/// Fetch a file directly using its full GitHub release URL: +/// ```rust,ignore +/// let path = usls::Hub::default() +/// .try_fetch("https://github.com/jamjamjon/assets/releases/download/images/bus.jpg")?; +/// println!("Fetched file to: {:?}", path); +/// ``` +/// +/// ## 3. 
Fetch available tags and files in a repository +/// List all release tags and the files associated with each tag: +/// ```rust,ignore +/// let hub = usls::Hub::default().with_owner("jamjamjon").with_repo("usls"); +/// for tag in hub.tags().iter() { +/// let files = hub.files(tag); +/// println!("Tag: {}, Files: {:?}", tag, files); +/// } +/// ``` +/// +/// # Default Behavior +/// By default, `Hub` interacts with the `jamjamjon/assets` repository, stores downloads in +/// an accessible directory, and applies a 10-minute cache expiration time. These settings +/// can be customized using the builder-like methods `with_owner`, `with_repo`, `with_ttl`, +/// `with_timeout`, and `with_max_attempts`. +/// +/// # Errors +/// Methods in `Hub` return `Result` types. Errors may occur due to invalid paths, failed +/// network requests, cache write failures, or mismatched file sizes during downloads. +/// +#[derive(Debug)] +pub struct Hub { + /// GitHub repository owner + owner: String, + + /// GitHub repository name + repo: String, + + /// Directory to store the downloaded file + to: Dir, + + /// Download timeout in seconds + timeout: u64, + + /// Time to live (cache duration) + ttl: Duration, + + /// The maximum number of retry attempts for failed downloads or network operations + max_attempts: u32, +} + +impl Default for Hub { + fn default() -> Self { + let owner = "jamjamjon".to_string(); + let repo = "assets".to_string(); + let to = [Dir::Cache, Dir::Home, Dir::Config, Dir::Current] + .into_iter() + .find(|dir| dir.path().is_ok()) + .expect( + "Unable to get cache directory, home directory, config directory, and current directory. Possible reason: \ + \n1. Unsupported OS \ + \n2. Directory does not exist \ + \n3. 
Insufficient permissions to access" + ); + + Self { + owner, + repo, + to, + timeout: 3000, + max_attempts: 3, + ttl: Duration::from_secs(10 * 60), + } + } +} + +impl Hub { + pub fn new(owner: &str, repo: &str) -> Self { + Self { + owner: owner.into(), + repo: repo.into(), + ..Default::default() + } + } + + /// Attempts to fetch a file from a local path or a GitHub release. + /// + /// The `try_fetch` method supports three main scenarios: + /// 1. **Local file**: If the provided string is a valid file path, the file is returned without downloading. + /// 2. **GitHub release URL**: If the input matches a valid GitHub release URL, the corresponding file is downloaded. + /// 3. **Default repository**: If no explicit URL is provided, the method uses the default or configured repository. + /// + /// # Parameters + /// - `s`: A string representing the file to fetch. This can be: + /// - A local file path. + /// - A GitHub release URL (e.g., `https://github.com/owner/repo/releases/download/tag/file`). + /// - A `/` format for fetching from the default repository. + /// + /// # Returns + /// - `Result`: On success, returns the path to the fetched file. + /// + /// # Errors + /// - Returns an error if: + /// - The file cannot be found locally. + /// - The URL or tag is invalid. + /// - Network operations fail after the maximum retry attempts. 
+ /// + /// # Example + /// ```rust,ignore + /// let mut hub = Hub::default(); + /// + /// // Fetch a file from a local path + /// let local_path = hub.try_fetch("local/path/to/file").expect("File not found"); + /// + /// // Fetch a file from a GitHub release URL + /// let url_path = hub.try_fetch("https://github.com/owner/repo/releases/download/tag/file") + /// .expect("Failed to fetch file"); + /// + /// // Fetch a file using the default repository + /// let default_repo_path = hub.try_fetch("v1.0.0/file").expect("Failed to fetch file"); + /// ``` + pub fn try_fetch(&mut self, s: &str) -> Result { + #[derive(Default, Debug, aksr::Builder)] + struct Pack { + // owner: String, + // repo: String, + url: String, + tag: String, + file_name: String, + file_size: Option, + } + let mut pack = Pack::default(); + + // saveout + let p = PathBuf::from(s); + let saveout = if p.exists() { + // => Local file + p + } else if let Some((owner_, repo_, tag_, file_name_)) = Self::is_valid_github_release_url(s) + { + // => Valid GitHub release URL + // keep original owner, repo and tag + let saveout = self + .to + .path_with_subs(&[&owner_, &repo_, &tag_])? + .join(&file_name_); + + pack = pack.with_url(s).with_tag(&tag_).with_file_name(&file_name_); + if let Some(n) = retry!(self.max_attempts, Self::fetch_get_response(s))? + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + { + pack = pack.with_file_size(n); + } + + saveout + } else { + // => Default hub + + // Fetch releases + let releases = match Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl) { + Err(err) => anyhow::bail!( + "Failed to download: No releases found in this repo. 
Error: {}", + err + ), + Ok(releases) => releases, + }; + + // Check remote + match s.split_once('/') { + Some((tag_, file_name_)) => { + // Validate the tag + let tags: Vec = releases.iter().map(|x| x.tag_name.clone()).collect(); + if !tags.contains(&tag_.to_string()) { + anyhow::bail!( + "Failed to download: Tag `{}` not found in GitHub releases. Available tags: {:?}", + tag_, + tags + ); + } else { + // Validate the file + if let Some(release) = releases.iter().find(|r| r.tag_name == tag_) { + let files: Vec<&str> = + release.assets.iter().map(|x| x.name.as_str()).collect(); + if !files.contains(&file_name_) { + anyhow::bail!( + "Failed to download: The file `{}` is missing in tag `{}`. Available files: {:?}", + file_name_, + tag_, + files + ); + } else { + for f_ in release.assets.iter() { + if f_.name.as_str() == file_name_ { + pack = pack.with_url(&f_.browser_download_url).with_tag(tag_).with_file_name(file_name_).with_file_size(f_.size); + break; + } + } + } + } + + self.to.path_with_subs(&[tag_])?.join(file_name_) + } + } + _ => anyhow::bail!( + "Failed to download file from github releases due to invalid format. Expected: /, got: {}", + s + ), + } + }; + + // Commit the downloaded file, downloading if necessary + if !pack.url.is_empty() { + // Download if the file does not exist or if the size of file does not match + if saveout.is_file() { + match pack.file_size { + None => { + log::warn!("Failed to retrieve the remote file size. \ + Download will be skipped, which may cause issues. \ + Please verify your network connection or ensure the local file is valid and complete." + ); + } + Some(file_size) => { + if std::fs::metadata(&saveout)?.len() != file_size { + log::debug!( + "Local file size does not match remote. Starting download." + ); + retry!( + self.max_attempts, + 1000, + 3000, + Self::download( + &pack.url, + &saveout, + Some(&format!("{}/{}", pack.tag, pack.file_name)), + ) + )?; + } else { + log::debug!("Local file size matches remote. 
No download required."); + } + } + } + } else { + log::debug!("Starting remote file download..."); + retry!( + self.max_attempts, + 1000, + 3000, + Self::download( + &pack.url, + &saveout, + Some(&format!("{}/{}", pack.tag, pack.file_name)), + ) + )?; + } + } + + saveout + .to_str() + .map(|s| s.to_string()) + .with_context(|| format!("Failed to convert PathBuf: {:?} to String", saveout)) + } + + /// Fetch releases from GitHub and cache them + fn fetch_and_cache_releases(url: &str, cache_path: &Path) -> Result { + let response = retry!(3, Self::fetch_get_response(url))?; + let body = response + .into_string() + .context("Failed to read response body")?; + + // Ensure cache directory exists + let parent_dir = cache_path + .parent() + .context("Invalid cache path: no parent directory found")?; + std::fs::create_dir_all(parent_dir) + .with_context(|| format!("Failed to create cache directory: {:?}", parent_dir))?; + + // Create temporary file + let mut temp_file = tempfile::NamedTempFile::new_in(parent_dir) + .context("Failed to create temporary cache file")?; + + // Write data to temporary file + temp_file + .write_all(body.as_bytes()) + .context("Failed to write to temporary cache file")?; + + // Persist temporary file as the cache + temp_file.persist(cache_path).with_context(|| { + format!("Failed to persist temporary cache file to {:?}", cache_path) + })?; + + Ok(body) + } + + pub fn tags(&self) -> Vec { + Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl) + .unwrap_or_default() + .into_iter() + .map(|x| x.tag_name) + .collect() + } + + pub fn files(&self, tag: &str) -> Vec { + Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl) + .unwrap_or_default() + .into_iter() + .find(|r| r.tag_name == tag) + .map(|a| a.assets.iter().map(|x| x.name.clone()).collect()) + .unwrap_or_default() + } + + pub fn is_file_expired>(file: P, ttl: &Duration) -> Result { + let file = file.as_ref(); + let y = if !file.exists() { + log::debug!("No cache found, 
fetching data from GitHub"); + true + } else { + match std::fs::metadata(file)?.modified() { + Err(_) => { + log::debug!("Cannot get file modified time, fetching new data from GitHub"); + true + } + Ok(modified_time) => { + if std::time::SystemTime::now().duration_since(modified_time)? < *ttl { + log::debug!("Using cached data"); + false + } else { + log::debug!("Cache expired, fetching new data from GitHub"); + true + } + } + } + }; + Ok(y) + } + + /// Download a file from a github release to a specified path with a progress bar + pub fn download + std::fmt::Debug>( + src: &str, + dst: P, + message: Option<&str>, + ) -> Result<()> { + let resp = Self::fetch_get_response(src)?; + let ntotal = resp + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .context("Content-Length header is missing or invalid")?; + + let pb = ProgressBar::new(ntotal); + pb.set_style( + ProgressStyle::with_template( + "{prefix:.cyan.bold} {msg} |{bar}| ({percent_precise}%, {binary_bytes}/{binary_total_bytes}, {binary_bytes_per_sec})", + )? 
+ .progress_chars("██ "), + ); + pb.set_prefix(format!("{:>PREFIX_LENGTH$}", "Fetching")); + pb.set_message(message.unwrap_or_default().to_string()); + + let mut reader = resp.into_reader(); + let mut buffer = [0; 512]; + let mut downloaded_bytes = 0usize; + let mut file = std::fs::File::create(&dst) + .with_context(|| format!("Failed to create destination file: {:?}", dst))?; + + loop { + let bytes_read = reader.read(&mut buffer)?; + if bytes_read == 0 { + break; + } + file.write_all(&buffer[..bytes_read]) + .context("Failed to write to file")?; + downloaded_bytes += bytes_read; + pb.inc(bytes_read as u64); + } + + // check size + if downloaded_bytes as u64 != ntotal { + anyhow::bail!("The downloaded file is incomplete."); + } + + // update + pb.set_prefix("Downloaded"); + pb.set_style(ProgressStyle::with_template( + crate::PROGRESS_BAR_STYLE_FINISH_3, + )?); + pb.finish(); + + Ok(()) + } + + fn fetch_get_response(url: &str) -> anyhow::Result { + let response = ureq::get(url) + .call() + .map_err(|err| anyhow::anyhow!("Failed to GET response from {}: {}", url, err))?; + + if response.status() != 200 { + anyhow::bail!("Failed to fetch data from remote due to: {:?}", response); + } + + Ok(response) + } + + fn cache_file(owner: &str, repo: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(format!("{}-{}", owner, repo)); + format!(".{:x}", hasher.finalize()) + } + + fn get_releases(owner: &str, repo: &str, to: &Dir, ttl: &Duration) -> Result> { + let cache = to.path()?.join(Self::cache_file(owner, repo)); + let is_file_expired = Self::is_file_expired(&cache, ttl)?; + let body = if is_file_expired { + let gh_api_release = + format!("https://api.github.com/repos/{}/{}/releases", owner, repo); + Self::fetch_and_cache_releases(&gh_api_release, &cache)? + } else { + std::fs::read_to_string(&cache)? + }; + + Ok(serde_json::from_str(&body)?) 
+ } + + pub(crate) fn is_valid_github_release_url( + url: &str, + ) -> Option<(String, String, String, String)> { + let re = + Regex::new(r"^https://github\.com/([^/]+)/([^/]+)/releases/download/([^/]+)/([^/]+)$") + .expect("Failed to compile the regex for GitHub release URL pattern"); + + if let Some(caps) = re.captures(url) { + let owner = caps.get(1).map_or("", |m| m.as_str()); + let repo = caps.get(2).map_or("", |m| m.as_str()); + let tag = caps.get(3).map_or("", |m| m.as_str()); + let file = caps.get(4).map_or("", |m| m.as_str()); + + Some(( + owner.to_string(), + repo.to_string(), + tag.to_string(), + file.to_string(), + )) + } else { + None + } + } + + pub fn with_owner(mut self, owner: &str) -> Self { + self.owner = owner.to_string(); + self + } + + pub fn with_repo(mut self, repo: &str) -> Self { + self.repo = repo.to_string(); + self + } + + pub fn with_ttl(mut self, x: u64) -> Self { + self.ttl = std::time::Duration::from_secs(x); + self + } + + pub fn with_timeout(mut self, x: u64) -> Self { + self.timeout = x; + self + } + + pub fn with_max_attempts(mut self, x: u32) -> Self { + self.max_attempts = x; + self + } +} diff --git a/src/misc/iiix.rs b/src/misc/iiix.rs new file mode 100644 index 0000000..6db1626 --- /dev/null +++ b/src/misc/iiix.rs @@ -0,0 +1,15 @@ +use crate::MinOptMax; + +/// A struct for input composed of the i-th input, the ii-th dimension, and the value. +#[derive(Clone, Debug, Default)] +pub struct Iiix { + pub i: usize, + pub ii: usize, + pub x: MinOptMax, +} + +impl From<(usize, usize, MinOptMax)> for Iiix { + fn from((i, ii, x): (usize, usize, MinOptMax)) -> Self { + Self { i, ii, x } + } +} diff --git a/src/misc/kind.rs b/src/misc/kind.rs new file mode 100644 index 0000000..4519427 --- /dev/null +++ b/src/misc/kind.rs @@ -0,0 +1,18 @@ +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub enum Kind { + // Do we really need this? 
+ Vision, + Language, + VisionLanguage, +} + +impl std::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::Vision => "visual", + Self::Language => "textual", + Self::VisionLanguage => "vl", + }; + write!(f, "{}", x) + } +} diff --git a/src/misc/labels.rs b/src/misc/labels.rs new file mode 100644 index 0000000..0415615 --- /dev/null +++ b/src/misc/labels.rs @@ -0,0 +1,1155 @@ +pub const COCO_SKELETONS_16: [(usize, usize); 16] = [ + (0, 1), + (0, 2), + (1, 3), + (2, 4), + (5, 6), + (5, 11), + (6, 12), + (11, 12), + (5, 7), + (6, 8), + (7, 9), + (8, 10), + (11, 13), + (12, 14), + (13, 15), + (14, 16), +]; + +pub const COCO_KEYPOINTS_NAMES_17: [&str; 17] = [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", +]; + +pub const COCO_CLASS_NAMES_80: [&str; 80] = [ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + 
"clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +]; + +pub const BODY_PARTS_NAMES_28: [&str; 28] = [ + "Background", + "Apparel", + "Face Neck", + "Hair", + "Left Foot", + "Left Hand", + "Left Lower Arm", + "Left Lower Leg", + "Left Shoe", + "Left Sock", + "Left Upper Arm", + "Left Upper Leg", + "Lower Clothing", + "Right Foot", + "Right Hand", + "Right Lower Arm", + "Right Lower Leg", + "Right Shoe", + "Right Sock", + "Right Upper Arm", + "Right Upper Leg", + "Torso", + "Upper Clothing", + "Lower Lip", + "Upper Lip", + "Lower Teeth", + "Upper Teeth", + "Tongue", +]; + +pub const IMAGENET_NAMES_1K: [&str; 1000] = [ + "tench, Tinca tinca", + "goldfish, Carassius auratus", + "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "tiger shark, Galeocerdo cuvieri", + "hammerhead, hammerhead shark", + "electric ray, crampfish, numbfish, torpedo", + "stingray", + "cock", + "hen", + "ostrich, Struthio camelus", + "brambling, Fringilla montifringilla", + "goldfinch, Carduelis carduelis", + "house finch, linnet, Carpodacus mexicanus", + "junco, snowbird", + "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "robin, American robin, Turdus migratorius", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel, dipper", + "kite", + "bald eagle, American eagle, Haliaeetus leucocephalus", + "vulture", + "great grey owl, great gray owl, Strix nebulosa", + "European fire salamander, Salamandra salamandra", + "common newt, Triturus vulgaris", + "eft", + "spotted salamander, Ambystoma maculatum", + "axolotl, mud puppy, Ambystoma mexicanum", + "bullfrog, Rana catesbeiana", + "tree frog, tree-frog", + "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "loggerhead, loggerhead turtle, Caretta caretta", + "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "mud turtle", + "terrapin", + "box turtle, box tortoise", + "banded gecko", + "common iguana, iguana, 
Iguana iguana", + "American chameleon, anole, Anolis carolinensis", + "whiptail, whiptail lizard", + "agama", + "frilled lizard, Chlamydosaurus kingi", + "alligator lizard", + "Gila monster, Heloderma suspectum", + "green lizard, Lacerta viridis", + "African chameleon, Chamaeleo chamaeleon", + "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "African crocodile, Nile crocodile, Crocodylus niloticus", + "American alligator, Alligator mississipiensis", + "triceratops", + "thunder snake, worm snake, Carphophis amoenus", + "ringneck snake, ring-necked snake, ring snake", + "hognose snake, puff adder, sand viper", + "green snake, grass snake", + "king snake, kingsnake", + "garter snake, grass snake", + "water snake", + "vine snake", + "night snake, Hypsiglena torquata", + "boa constrictor, Constrictor constrictor", + "rock python, rock snake, Python sebae", + "Indian cobra, Naja naja", + "green mamba", + "sea snake", + "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "sidewinder, horned rattlesnake, Crotalus cerastes", + "trilobite", + "harvestman, daddy longlegs, Phalangium opilio", + "scorpion", + "black and gold garden spider, Argiope aurantia", + "barn spider, Araneus cavaticus", + "garden spider, Aranea diademata", + "black widow, Latrodectus mactans", + "tarantula", + "wolf spider, hunting spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse, partridge, Bonasa umbellus", + "prairie chicken, prairie grouse, prairie fowl", + "peacock", + "quail", + "partridge", + "African grey, African gray, Psittacus erithacus", + "macaw", + "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser, Mergus serrator", + "goose", + "black swan, Cygnus atratus", + "tusker", + "echidna, spiny 
anteater, anteater", + "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "wallaby, brush kangaroo", + "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "wombat", + "jellyfish", + "sea anemone, anemone", + "brain coral", + "flatworm, platyhelminth", + "nematode, nematode worm, roundworm", + "conch", + "snail", + "slug", + "sea slug, nudibranch", + "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "chambered nautilus, pearly nautilus, nautilus", + "Dungeness crab, Cancer magister", + "rock crab, Cancer irroratus", + "fiddler crab", + "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "crayfish, crawfish, crawdad, crawdaddy", + "hermit crab", + "isopod", + "white stork, Ciconia ciconia", + "black stork, Ciconia nigra", + "spoonbill", + "flamingo", + "little blue heron, Egretta caerulea", + "American egret, great white heron, Egretta albus", + "bittern", + "crane", + "limpkin, Aramus pictus", + "European gallinule, Porphyrio porphyrio", + "American coot, marsh hen, mud hen, water hen, Fulica americana", + "bustard", + "ruddy turnstone, Arenaria interpres", + "red-backed sandpiper, dunlin, Erolia alpina", + "redshank, Tringa totanus", + "dowitcher", + "oystercatcher, oyster catcher", + "pelican", + "king penguin, Aptenodytes patagonica", + "albatross, mollymawk", + "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "dugong, Dugong dugon", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog, Maltese terrier, Maltese", + "Pekinese, Pekingese, Peke", + "shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound, Afghan", + "basset, basset hound", + 
"beagle", + "bloodhound, sleuthhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound, Walker foxhound", + "English foxhound", + "redbone", + "borzoi, Russian wolfhound", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound, Ibizan Podenco", + "Norwegian elkhound, elkhound", + "otterhound, otter hound", + "saluki, gazelle hound", + "scottish deerhound, deerhound", + "Weimaraner", + "staffordshire bullterrier, Staffordshire bull terrier", + "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "sealyham terrier, Sealyham", + "Airedale, Airedale terrier", + "cairn, cairn terrier", + "Australian terrier", + "Dandie Dinmont, Dandie Dinmont terrier", + "Boston bull, Boston terrier", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "scotch terrier, Scottish terrier, Scottie", + "Tibetan terrier, chrysanthemum dog", + "silky terrier, Sydney silky", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa, Lhasa apso", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla, Hungarian pointer", + "English setter", + "Irish setter, red setter", + "Gordon setter", + "Brittany spaniel", + "clumber, clumber spaniel", + "English springer, English springer spaniel", + "Welsh springer spaniel", + "cocker spaniel, English cocker spaniel, cocker", + "sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog, bobtail", + "shetland sheepdog, Shetland sheep dog, Shetland", + "collie", + "Border collie", + "Bouvier des Flandres, Bouviers des 
Flandres", + "Rottweiler", + "German shepherd, German shepherd dog, German police dog, alsatian", + "Doberman, Doberman pinscher", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "saint Bernard, St Bernard", + "Eskimo dog, husky", + "malamute, malemute, Alaskan malamute", + "siberian husky", + "dalmatian, coach dog, carriage dog", + "affenpinscher, monkey pinscher, monkey dog", + "basenji", + "pug, pug-dog", + "Leonberg", + "Newfoundland, Newfoundland dog", + "Great Pyrenees", + "samoyed, Samoyede", + "Pomeranian", + "chow, chow chow", + "keeshond", + "Brabancon griffon", + "Pembroke, Pembroke Welsh corgi", + "Cardigan, Cardigan Welsh corgi", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf, grey wolf, gray wolf, Canis lupus", + "white wolf, Arctic wolf, Canis lupus tundrarum", + "red wolf, maned wolf, Canis rufus, Canis niger", + "coyote, prairie wolf, brush wolf, Canis latrans", + "dingo, warrigal, warragal, Canis dingo", + "dhole, Cuon alpinus", + "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "hyena, hyaena", + "red fox, Vulpes vulpes", + "kit fox, Vulpes macrotis", + "Arctic fox, white fox, Alopex lagopus", + "grey fox, gray fox, Urocyon cinereoargenteus", + "tabby, tabby cat", + "tiger cat", + "Persian cat", + "siamese cat, Siamese", + "Egyptian cat", + "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "lynx, catamount", + "leopard, Panthera pardus", + "snow leopard, ounce, Panthera uncia", + "jaguar, panther, Panthera onca, Felis onca", + "lion, king of beasts, Panthera leo", + "tiger, Panthera tigris", + "cheetah, chetah, Acinonyx jubatus", + "brown bear, bruin, Ursus arctos", + "American black bear, black bear, Ursus americanus, Euarctos americanus", + "ice bear, polar bear, Ursus Maritimus, Thalarctos 
maritimus", + "sloth bear, Melursus ursinus, Ursus ursinus", + "mongoose", + "meerkat, mierkat", + "tiger beetle", + "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "ground beetle, carabid beetle", + "long-horned beetle, longicorn, longicorn beetle", + "leaf beetle, chrysomelid", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant, emmet, pismire", + "grasshopper, hopper", + "cricket", + "walking stick, walkingstick, stick insect", + "cockroach, roach", + "mantis, mantid", + "cicada, cicala", + "leafhopper", + "lacewing, lacewing fly", + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "damselfly", + "admiral", + "ringlet, ringlet butterfly", + "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "cabbage butterfly", + "sulphur butterfly, sulfur butterfly", + "lycaenid, lycaenid butterfly", + "starfish, sea star", + "sea urchin", + "sea cucumber, holothurian", + "wood rabbit, cottontail, cottontail rabbit", + "hare", + "Angora, Angora rabbit", + "hamster", + "porcupine, hedgehog", + "fox squirrel, eastern fox squirrel, Sciurus niger", + "marmot", + "beaver", + "guinea pig, Cavia cobaya", + "sorrel", + "zebra", + "hog, pig, grunter, squealer, Sus scrofa", + "wild boar, boar, Sus scrofa", + "warthog", + "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "ox", + "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "bison", + "ram, tup", + "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "ibex, Capra ibex", + "hartebeest", + "impala, Aepyceros melampus", + "gazelle", + "Arabian camel, dromedary, Camelus dromedarius", + "llama", + "weasel", + "mink", + "polecat, fitch, foulmart, foumart, Mustela putorius", + "black-footed ferret, ferret, Mustela nigripes", + "otter", + "skunk, polecat, wood pussy", + "badger", + "armadillo", + "three-toed sloth, ai, Bradypus 
tridactylus", + "orangutan, orang, orangutang, Pongo pygmaeus", + "gorilla, Gorilla gorilla", + "chimpanzee, chimp, Pan troglodytes", + "gibbon, Hylobates lar", + "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "guenon, guenon monkey", + "patas, hussar monkey, Erythrocebus patas", + "baboon", + "macaque", + "langur", + "colobus, colobus monkey", + "proboscis monkey, Nasalis larvatus", + "marmoset", + "capuchin, ringtail, Cebus capucinus", + "howler monkey, howler", + "titi, titi monkey", + "spider monkey, Ateles geoffroyi", + "squirrel monkey, Saimiri sciureus", + "Madagascar cat, ring-tailed lemur, Lemur catta", + "indri, indris, Indri indri, Indri brevicaudatus", + "Indian elephant, Elephas maximus", + "African elephant, Loxodonta africana", + "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "barracouta, snoek", + "eel", + "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "rock beauty, Holocanthus tricolor", + "anemone fish", + "sturgeon", + "gar, garfish, garpike, billfish, Lepisosteus osseus", + "lionfish", + "puffer, pufferfish, blowfish, globefish", + "abacus", + "abaya", + "academic gown, academic robe, judge's robe", + "accordion, piano accordion, squeeze box", + "acoustic guitar", + "aircraft carrier, carrier, flattop, attack aircraft carrier", + "airliner", + "airship, dirigible", + "altar", + "ambulance", + "amphibian, amphibious vehicle", + "analog clock", + "apiary, bee house", + "apron", + "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "assault rifle, assault gun", + "backpack, back pack, knapsack, packsack, rucksack, haversack", + "bakery, bakeshop, bakehouse", + "balance beam, beam", + "balloon", + "ballpoint, ballpoint pen, ballpen, Biro", + "Band Aid", + "banjo", + "bannister, banister, balustrade, balusters, handrail", + "barbell", + "barber chair", + 
"barbershop", + "barn", + "barometer", + "barrel, cask", + "barrow, garden cart, lawn cart, wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap, swimming cap", + "bath towel", + "bathtub, bathing tub, bath, tub", + "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "beacon, lighthouse, beacon light, pharos", + "beaker", + "bearskin, busby, shako", + "beer bottle", + "beer glass", + "bell cote, bell cot", + "bib", + "bicycle-built-for-two, tandem bicycle, tandem", + "bikini, two-piece", + "binder, ring-binder", + "binoculars, field glasses, opera glasses", + "birdhouse", + "boathouse", + "bobsled, bobsleigh, bob", + "bolo tie, bolo, bola tie, bola", + "bonnet, poke bonnet", + "bookcase", + "bookshop, bookstore, bookstall", + "bottlecap", + "bow", + "bow tie, bow-tie, bowtie", + "brass, memorial tablet, plaque", + "brassiere, bra, bandeau", + "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "breastplate, aegis, egis", + "broom", + "bucket, pail", + "buckle", + "bulletproof vest", + "bullet train, bullet", + "butcher shop, meat market", + "cab, hack, taxi, taxicab", + "caldron, cauldron", + "candle, taper, wax light", + "cannon", + "canoe", + "can opener, tin opener", + "cardigan", + "car mirror", + "carousel, carrousel, merry-go-round, roundabout, whirligig", + "carpenter's kit, tool kit", + "carton", + "car wheel", + "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello, violoncello", + "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "chain", + "chainlink fence", + "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "chain saw, chainsaw", + "chest", + "chiffonier, commode", + "chime, bell, gong", + "china cabinet, china closet", + "Christmas stocking", + "church, church 
building", + "cinema, movie theater, movie theatre, movie house, picture palace", + "cleaver, meat cleaver, chopper", + "cliff dwelling", + "cloak", + "clog, geta, patten, sabot", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil, spiral, volute, whorl, helix", + "combination lock", + "computer keyboard, keypad", + "confectionery, confectionary, candy store", + "container ship, containership, container vessel", + "convertible", + "corkscrew, bottle screw", + "cornet, horn, trumpet, trump", + "cowboy boot", + "cowboy hat, ten-gallon hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib, cot", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam, dike, dyke", + "desk", + "desktop computer", + "dial telephone, dial phone", + "diaper, nappy, napkin", + "digital clock", + "digital watch", + "dining table, board", + "dishrag, dishcloth", + "dishwasher, dish washer, dishwashing machine", + "disk brake, disc brake", + "dock, dockage, docking facility", + "dogsled, dog sled, dog sleigh", + "dome", + "doormat, welcome mat", + "drilling platform, offshore rig", + "drum, membranophone, tympan", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan, blower", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa, boa", + "file, file cabinet, filing cabinet", + "fireboat", + "fire engine, fire truck", + "fire screen, fireguard", + "flagpole, flagstaff", + "flute, transverse flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn, horn", + "frying pan, frypan, skillet", + "fur coat", + "garbage truck, dustcart", + "gasmask, respirator, gas helmet", + "gas pump, gasoline pump, petrol pump, island dispenser", + "goblet", + "go-kart", + "golf ball", + "golfcart, golf cart", + "gondola", + "gong, tam-tam", + "gown", + "grand piano, grand", + "greenhouse, nursery, glasshouse", 
+ "grille, radiator grille", + "grocery store, grocery, food market, market", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "hand-held computer, hand-held microcomputer", + "handkerchief, hankie, hanky, hankey", + "hard disc, hard disk, fixed disk", + "harmonica, mouth organ, harp, mouth harp", + "harp", + "harvester, reaper", + "hatchet", + "holster", + "home theater, home theatre", + "honeycomb", + "hook, claw", + "hoopskirt, crinoline", + "horizontal bar, high bar", + "horse cart, horse-cart", + "hourglass", + "iPod", + "iron, smoothing iron", + "jack-o'-lantern", + "jean, blue jean, denim", + "jeep, landrover", + "jersey, T-shirt, tee shirt", + "jigsaw puzzle", + "jinrikisha, ricksha, rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat, laboratory coat", + "ladle", + "lampshade, lamp shade", + "laptop, laptop computer", + "lawn mower, mower", + "lens cap, lens cover", + "letter opener, paper knife, paperknife", + "library", + "lifeboat", + "lighter, light, igniter, ignitor", + "limousine, limo", + "liner, ocean liner", + "lipstick, lip rouge", + "Loafer", + "lotion", + "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "loupe, jeweler's loupe", + "lumbermill, sawmill", + "magnetic compass", + "mailbag, postbag", + "mailbox, letter box", + "maillot", + "maillot, tank suit", + "manhole cover", + "maraca", + "marimba, xylophone", + "mask", + "matchstick", + "maypole", + "maze, labyrinth", + "measuring cup", + "medicine chest, medicine cabinet", + "megalith, megalithic structure", + "microphone, mike", + "microwave, microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt, mini", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home, manufactured home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + 
"motor scooter, scooter", + "mountain bike, all-terrain bike, off-roader", + "mountain tent", + "mouse, computer mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook, notebook computer", + "obelisk", + "oboe, hautboy, hautbois", + "ocarina, sweet potato", + "odometer, hodometer, mileometer, milometer", + "oil filter", + "organ, pipe organ", + "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle, boat paddle", + "paddlewheel, paddle wheel", + "padlock", + "paintbrush", + "pajama, pyjama, pj's, jammies", + "palace", + "panpipe, pandean pipe, syrinx", + "paper towel", + "parachute, chute", + "parallel bars, bars", + "park bench", + "parking meter", + "passenger car, coach, carriage", + "patio, terrace", + "pay-phone, pay-station", + "pedestal, plinth, footstall", + "pencil box, pencil case", + "pencil sharpener", + "perfume, essence", + "Petri dish", + "photocopier", + "pick, plectrum, plectron", + "pickelhaube", + "picket fence, paling", + "pickup, pickup truck", + "pier", + "piggy bank, penny bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate, pirate ship", + "pitcher, ewer", + "plane, carpenter's plane, woodworking plane", + "planetarium", + "plastic bag", + "plate rack", + "plow, plough", + "plunger, plumber's helper", + "Polaroid camera, Polaroid Land camera", + "pole", + "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "poncho", + "pool table, billiard table, snooker table", + "pop bottle, soda bottle", + "pot, flowerpot", + "potter's wheel", + "power drill", + "prayer rug, prayer mat", + "printer", + "prison, prison house", + "projectile, missile", + "projector", + "puck, hockey puck", + "punching bag, punch bag, punching ball, punchball", + "purse", + "quill, quill pen", + "quilt, comforter, comfort, puff", + "racer, race car, racing car", + "racket, racquet", + "radiator", + 
"radio, wireless", + "radio telescope, radio reflector", + "rain barrel", + "recreational vehicle, RV, R.V.", + "reel", + "reflex camera", + "refrigerator, icebox", + "remote control, remote", + "restaurant, eating house, eating place, eatery", + "revolver, six-gun, six-shooter", + "rifle", + "rocking chair, rocker", + "rotisserie", + "rubber eraser, rubber, pencil eraser", + "rugby ball", + "rule, ruler", + "running shoe", + "safe", + "safety pin", + "saltshaker, salt shaker", + "sandal", + "sarong", + "sax, saxophone", + "scabbard", + "scale, weighing machine", + "school bus", + "schooner", + "scoreboard", + "screen, CRT screen", + "screw", + "screwdriver", + "seat belt, seatbelt", + "sewing machine", + "shield, buckler", + "shoe shop, shoe-shop, shoe store", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule, slipstick", + "sliding door", + "slot, one-armed bandit", + "snorkel", + "snowmobile", + "snowplow, snowplough", + "soap dispenser", + "soccer ball", + "sock", + "solar dish, solar collector, solar furnace", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web, spider's web", + "spindle", + "sports car, sport car", + "spotlight, spot", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch, stop watch", + "stove", + "strainer", + "streetcar, tram, tramcar, trolley, trolley car", + "stretcher", + "studio couch, day bed", + "stupa, tope", + "submarine, pigboat, sub, U-boat", + "suit, suit of clothes", + "sundial", + "sunglass", + "sunglasses, dark glasses, shades", + "sunscreen, sunblock, sun blocker", + "suspension bridge", + "swab, swob, mop", + "sweatshirt", + "swimming trunks, bathing trunks", + "swing", + "switch, electric switch, electrical switch", + "syringe", + "table lamp", + "tank, army tank, armored 
combat vehicle, armoured combat vehicle", + "tape player", + "teapot", + "teddy, teddy bear", + "television, television system", + "tennis ball", + "thatch, thatched roof", + "theater curtain, theatre curtain", + "thimble", + "thresher, thrasher, threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop, tobacconist shop, tobacconist", + "toilet seat", + "torch", + "totem pole", + "tow truck, tow car, wrecker", + "toyshop", + "tractor", + "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "tray", + "trench coat", + "tricycle, trike, velocipede", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus, trolley coach, trackless trolley", + "trombone", + "tub, vat", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle, monocycle", + "upright, upright piano", + "vacuum, vacuum cleaner", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin, fiddle", + "volleyball", + "waffle iron", + "wall clock", + "wallet, billfold, notecase, pocketbook", + "wardrobe, closet, press", + "warplane, military plane", + "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "washer, automatic washer, washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool, woolen, woollen", + "worm fence, snake fence, snake-rail fence, Virginia fence", + "wreck", + "yawl", + "yurt", + "web site, website, internet site, site", + "comic book", + "crossword puzzle, crossword", + "street sign", + "traffic light, traffic signal, stoplight", + "book jacket, dust cover, dust jacket, dust wrapper", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot, hotpot", + "trifle", + "ice cream, icecream", + "ice lolly, lolly, lollipop, popsicle", + "French loaf", + "bagel, beigel", + "pretzel", + "cheeseburger", + "hotdog, hot dog, red hot", + 
"mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini, courgette", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber, cuke", + "artichoke, globe artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple, ananas", + "banana", + "jackfruit, jak, jack", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce, chocolate syrup", + "dough", + "meat loaf, meatloaf", + "pizza, pizza pie", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff, drop, drop-off", + "coral reef", + "geyser", + "lakeside, lakeshore", + "promontory, headland, head, foreland", + "sandbar, sand bar", + "seashore, coast, seacoast, sea-coast", + "valley, vale", + "volcano", + "ballplayer, baseball player", + "groom, bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "corn", + "acorn", + "hip, rose hip, rosehip", + "buckeye, horse chestnut, conker", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn, carrion fungus", + "earthstar", + "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "bolete", + "ear, spike, capitulum", + "toilet tissue, toilet paper, bathroom tissue" +]; diff --git a/src/core/logits_sampler.rs b/src/misc/logits_sampler.rs similarity index 91% rename from src/core/logits_sampler.rs rename to src/misc/logits_sampler.rs index 5867fd7..5795834 100644 --- a/src/core/logits_sampler.rs +++ b/src/misc/logits_sampler.rs @@ -1,8 +1,7 @@ use anyhow::Result; use rand::distributions::{Distribution, WeightedIndex}; -/// Logits Sampler -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LogitsSampler { temperature: f32, p: f32, @@ -32,7 +31,7 @@ impl LogitsSampler { self } - pub fn decode(&mut self, logits: &[f32]) -> Result { + pub fn decode(&self, logits: &[f32]) -> 
Result { if self.p == 0.0 { self.search_by_argmax(logits) } else { @@ -40,7 +39,7 @@ impl LogitsSampler { } } - fn search_by_argmax(&mut self, logits: &[f32]) -> Result { + fn search_by_argmax(&self, logits: &[f32]) -> Result { // no need to do softmax let (token_id, _) = logits .iter() @@ -50,7 +49,7 @@ impl LogitsSampler { Ok(token_id as u32) } - fn sample_by_topp(&mut self, logits: &[f32]) -> Result { + fn sample_by_topp(&self, logits: &[f32]) -> Result { let logits = self.softmax(logits); let mut logits: Vec<(usize, f32)> = logits .iter() diff --git a/src/core/media.rs b/src/misc/media.rs similarity index 100% rename from src/core/media.rs rename to src/misc/media.rs index 23cee6a..ee76c69 100644 --- a/src/core/media.rs +++ b/src/misc/media.rs @@ -1,14 +1,5 @@ use crate::{AUDIO_EXTENSIONS, IMAGE_EXTENSIONS, STREAM_PROTOCOLS, VIDEO_EXTENSIONS}; -#[derive(Debug, Clone)] -pub enum MediaType { - Image(Location), - Video(Location), - Audio(Location), - Stream, - Unknown, -} - #[derive(Debug, Clone)] pub enum Location { Local, @@ -21,6 +12,15 @@ pub enum StreamType { Live, } +#[derive(Debug, Clone)] +pub enum MediaType { + Image(Location), + Video(Location), + Audio(Location), + Stream, + Unknown, +} + impl MediaType { pub fn from_path>(path: P) -> Self { let extension = path diff --git a/src/core/min_opt_max.rs b/src/misc/min_opt_max.rs similarity index 97% rename from src/core/min_opt_max.rs rename to src/misc/min_opt_max.rs index c1c5b45..e4f47ca 100644 --- a/src/core/min_opt_max.rs +++ b/src/misc/min_opt_max.rs @@ -1,8 +1,13 @@ +use aksr::Builder; + /// A value composed of Min-Opt-Max -#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] +#[derive(Builder, Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] pub struct MinOptMax { + #[args(setter = false)] min: usize, + #[args(setter = false)] opt: usize, + #[args(setter = false)] max: usize, } @@ -17,18 +22,6 @@ impl Default for MinOptMax { } impl MinOptMax { - pub fn min(&self) -> usize { - self.min - } - - pub fn 
opt(&self) -> usize { - self.opt - } - - pub fn max(&self) -> usize { - self.max - } - pub fn ones() -> Self { Default::default() } @@ -90,6 +83,7 @@ impl MinOptMax { } } +// TODO: min = 1????? impl From for MinOptMax { fn from(opt: i32) -> Self { let opt = opt.max(0) as usize; @@ -99,6 +93,7 @@ impl From for MinOptMax { } } +// TODO: min = 1????? impl From for MinOptMax { fn from(opt: i64) -> Self { let opt = opt.max(0) as usize; @@ -134,6 +129,7 @@ impl From for MinOptMax { } } +// TODO: min = 1????? impl From for MinOptMax { fn from(opt: isize) -> Self { let opt = opt.max(0) as usize; diff --git a/src/misc/mod.rs b/src/misc/mod.rs new file mode 100644 index 0000000..2f89319 --- /dev/null +++ b/src/misc/mod.rs @@ -0,0 +1,59 @@ +mod annotator; +mod color; +mod colormap256; +mod dataloader; +mod device; +mod dir; +mod dtype; +mod dynconf; +mod engine; +mod hub; +mod iiix; +mod kind; +mod labels; +mod logits_sampler; +mod media; +mod min_opt_max; +pub(crate) mod onnx; +mod ops; +mod options; +mod processor; +mod retry; +mod scale; +mod task; +mod ts; +mod utils; +mod version; +#[cfg(feature = "ffmpeg")] +mod viewer; + +pub use annotator::Annotator; +pub use color::Color; +pub use colormap256::*; +pub use dataloader::DataLoader; +pub use device::Device; +pub use dir::Dir; +pub use dtype::DType; +pub use dynconf::DynConf; +pub use engine::*; +pub use hub::Hub; +pub use iiix::Iiix; +pub use kind::Kind; +pub use labels::*; +pub use logits_sampler::LogitsSampler; +pub use media::*; +pub use min_opt_max::MinOptMax; +pub use ops::*; +pub use options::*; +pub use processor::*; +pub use scale::Scale; +pub use task::Task; +pub use ts::Ts; +pub use utils::*; +pub use version::Version; +#[cfg(feature = "ffmpeg")] +pub use viewer::Viewer; + +// re-export +#[cfg(feature = "ffmpeg")] +pub use minifb::Key; diff --git a/src/core/onnx.rs b/src/misc/onnx.rs similarity index 99% rename from src/core/onnx.rs rename to src/misc/onnx.rs index d88dc84..33bdfc0 100644 --- a/src/core/onnx.rs 
+++ b/src/misc/onnx.rs @@ -866,6 +866,7 @@ pub mod type_proto { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] + #[allow(clippy::enum_variant_names)] pub enum Value { /// The type of a tensor. #[prost(message, tag = "1")] @@ -945,6 +946,7 @@ pub struct FunctionProto { /// that is not defined by the default value but an explicit enum number. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +#[allow(clippy::enum_variant_names)] pub enum Version { /// proto3 requires the first enum value to be zero. /// We add this just to appease the compiler. diff --git a/src/core/ops.rs b/src/misc/ops.rs similarity index 65% rename from src/core/ops.rs rename to src/misc/ops.rs index 7b60fb3..e30a929 100644 --- a/src/core/ops.rs +++ b/src/misc/ops.rs @@ -1,4 +1,4 @@ -//! Some processing functions to image and ndarray. +//! Some processing functions. use anyhow::Result; use fast_image_resize::{ @@ -7,11 +7,11 @@ use fast_image_resize::{ FilterType, ResizeAlg, ResizeOptions, Resizer, }; use image::{DynamicImage, GenericImageView}; -use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, IxDyn}; +use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, Ix2, IxDyn}; use rayon::prelude::*; pub enum Ops<'a> { - Resize(&'a [DynamicImage], u32, u32, &'a str), + FitExact(&'a [DynamicImage], u32, u32, &'a str), Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool), Normalize(f32, f32), Standardize(&'a [f32], &'a [f32], usize), @@ -80,11 +80,20 @@ impl Ops<'_> { dim: usize, ) -> Result> { if mean.len() != std.len() { - anyhow::bail!("`standardize`: `mean` and `std` lengths are not equal. Mean length: {}, Std length: {}.", mean.len(), std.len()); + anyhow::bail!( + "`standardize`: `mean` and `std` lengths are not equal. 
Mean length: {}, Std length: {}.", + mean.len(), + std.len() + ); } let shape = x.shape(); if dim >= shape.len() || shape[dim] != mean.len() { - anyhow::bail!("`standardize`: Dimension mismatch. `dim` is {} but shape length is {} or `mean` length is {}.", dim, shape.len(), mean.len()); + anyhow::bail!( + "`standardize`: Dimension mismatch. `dim` is {} but shape length is {} or `mean` length is {}.", + dim, + shape.len(), + mean.len() + ); } let mut shape = vec![1; shape.len()]; shape[dim] = mean.len(); @@ -122,6 +131,23 @@ impl Ops<'_> { Ok(concatenate(Axis(d), &[x.view(), y.view()])?) } + pub fn concat(xs: &[Array], d: usize) -> Result> { + let xs = xs.iter().map(|x| x.view()).collect::>(); + Ok(concatenate(Axis(d), &xs)?) + } + + pub fn dot2(x: &Array, other: &Array) -> Result>> { + // (m, ndim) * (n, ndim).t => (m, n) + let query = x.to_owned().into_dimensionality::()?; + let gallery = other.to_owned().into_dimensionality::()?; + let matrix = query.dot(&gallery.t()); + let exps = matrix.mapv(|x| x.exp()); + let stds = exps.sum_axis(Axis(1)); + let matrix = exps / stds.insert_axis(Axis(1)); + let matrix: Vec> = matrix.axis_iter(Axis(0)).map(|row| row.to_vec()).collect(); + Ok(matrix) + } + pub fn insert_axis(x: Array, d: usize) -> Result> { if x.shape().len() < d { anyhow::bail!( @@ -155,7 +181,8 @@ impl Ops<'_> { } pub fn make_divisible(x: usize, divisor: usize) -> usize { - (x + divisor - 1) / divisor * divisor + // (x + divisor - 1) / divisor * divisor + x.div_ceil(divisor) * divisor } // deprecated @@ -167,11 +194,6 @@ impl Ops<'_> { mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle) } - // pub fn argmax(xs: Array, d: usize, keep_dims: bool) -> Result> { - // let mask = Array::zeros(xs.raw_dim()); - // todo!(); - // } - pub fn interpolate_3d( xs: Array, tw: f32, @@ -238,14 +260,36 @@ impl Ops<'_> { }; resizer.resize(&src, &mut dst, &options)?; - // u8*2 -> f32 - let mask_f32: Vec = dst - .into_vec() - .chunks_exact(4) - 
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect(); + // u8 -> f32 + Self::u8_slice_to_f32(&dst.into_vec()) + } + + pub fn u8_slice_to_f32(data: &[u8]) -> Result> { + let size_in_bytes = 4; + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { + let data: &[f32] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, elem_count) }; - Ok(mask_f32) + Ok(data.to_vec()) + } else { + let mut c: Vec = Vec::with_capacity(elem_count); + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); + c.set_len(elem_count) + } + + Ok(c) + } + } + + pub fn f32_slice_to_u8(mut vs: Vec) -> Vec { + let size_in_bytes = 4; + let length = vs.len() * size_in_bytes; + let capacity = vs.capacity() * size_in_bytes; + let ptr = vs.as_mut_ptr() as *mut u8; + std::mem::forget(vs); + unsafe { Vec::from_raw_parts(ptr, length, capacity) } } pub fn resize_luma8_u8( @@ -285,6 +329,26 @@ impl Ops<'_> { )) } + pub fn resize_rgb( + x: &DynamicImage, + th: u32, + tw: u32, + resizer: &mut Resizer, + options: &ResizeOptions, + ) -> Result> { + let buffer = if x.dimensions() == (tw, th) { + x.to_rgb8().into_raw() + } else { + let mut dst = Image::new(tw, th, PixelType::U8x3); + resizer.resize(x, &mut dst, options)?; + dst.into_vec() + }; + let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)? 
+ .mapv(|x| x as f32) + .into_dyn(); + Ok(y) + } + pub fn resize( xs: &[DynamicImage], th: u32, @@ -294,20 +358,65 @@ impl Ops<'_> { let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); let (mut resizer, options) = Self::build_resizer_filter(filter)?; for (idx, x) in xs.iter().enumerate() { - let buffer = if x.dimensions() == (tw, th) { - x.to_rgb8().into_raw() - } else { - let mut dst = Image::new(tw, th, PixelType::U8x3); - resizer.resize(x, &mut dst, &options)?; - dst.into_vec() - }; - let y_ = - Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32); - ys.slice_mut(s![idx, .., .., ..]).assign(&y_); + let y = Self::resize_rgb(x, th, tw, &mut resizer, &options)?; + ys.slice_mut(s![idx, .., .., ..]).assign(&y); } Ok(ys) } + #[allow(clippy::too_many_arguments)] + pub fn letterbox_rgb( + x: &DynamicImage, + th: u32, + tw: u32, + bg: u8, + resize_by: &str, + center: bool, + resizer: &mut Resizer, + options: &ResizeOptions, + ) -> Result> { + let (w0, h0) = x.dimensions(); + let buffer = if w0 == tw && h0 == th { + x.to_rgb8().into_raw() + } else { + let (w, h) = match resize_by { + "auto" => { + let r = (tw as f32 / w0 as f32).min(th as f32 / h0 as f32); + ( + (w0 as f32 * r).round() as u32, + (h0 as f32 * r).round() as u32, + ) + } + "height" => (th * w0 / h0, th), + "width" => (tw, tw * h0 / w0), + _ => anyhow::bail!("ModelConfig for `letterbox`: width, height, auto"), + }; + + let mut dst = Image::from_vec_u8( + tw, + th, + vec![bg; 3 * th as usize * tw as usize], + PixelType::U8x3, + )?; + let (l, t) = if center { + if w == tw { + (0, (th - h) / 2) + } else { + ((tw - w) / 2, 0) + } + } else { + (0, 0) + }; + let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; + resizer.resize(x, &mut dst_cropped, options)?; + dst.into_vec() + }; + let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)? 
+ .mapv(|x| x as f32) + .into_dyn(); + Ok(y) + } + pub fn letterbox( xs: &[DynamicImage], th: u32, @@ -319,47 +428,9 @@ impl Ops<'_> { ) -> Result> { let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); let (mut resizer, options) = Self::build_resizer_filter(filter)?; - for (idx, x) in xs.iter().enumerate() { - let (w0, h0) = x.dimensions(); - let buffer = if w0 == tw && h0 == th { - x.to_rgb8().into_raw() - } else { - let (w, h) = match resize_by { - "auto" => { - let r = (tw as f32 / w0 as f32).min(th as f32 / h0 as f32); - ( - (w0 as f32 * r).round() as u32, - (h0 as f32 * r).round() as u32, - ) - } - "height" => (th * w0 / h0, th), - "width" => (tw, tw * h0 / w0), - _ => anyhow::bail!("Options for `letterbox`: width, height, auto"), - }; - - let mut dst = Image::from_vec_u8( - tw, - th, - vec![bg; 3 * th as usize * tw as usize], - PixelType::U8x3, - )?; - let (l, t) = if center { - if w == tw { - (0, (th - h) / 2) - } else { - ((tw - w) / 2, 0) - } - } else { - (0, 0) - }; - let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; - resizer.resize(x, &mut dst_cropped, &options)?; - dst.into_vec() - }; - let y_ = - Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32); - ys.slice_mut(s![idx, .., .., ..]).assign(&y_); + let y = Self::letterbox_rgb(x, th, tw, bg, resize_by, center, &mut resizer, &options)?; + ys.slice_mut(s![idx, .., .., ..]).assign(&y); } Ok(ys) } diff --git a/src/misc/options.rs b/src/misc/options.rs new file mode 100644 index 0000000..bc98179 --- /dev/null +++ b/src/misc/options.rs @@ -0,0 +1,461 @@ +//! 
Options for everthing + +use aksr::Builder; +use anyhow::Result; +use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; + +use crate::{ + models::{SamKind, YOLOPredsFormat}, + DType, Device, Engine, Hub, Iiix, Kind, LogitsSampler, MinOptMax, Processor, ResizeMode, Scale, + Task, Version, +}; + +/// Options for building models and inference +#[derive(Builder, Debug, Clone)] +pub struct Options { + // Model configs + pub model_file: String, + pub model_name: &'static str, + pub model_device: Device, + pub model_dtype: DType, + pub model_version: Option, + pub model_task: Option, + pub model_scale: Option, + pub model_kind: Option, + pub model_iiixs: Vec, + pub model_spec: String, + pub model_num_dry_run: usize, + pub trt_fp16: bool, + pub profile: bool, + + // Processor configs + #[args(setter = false)] + pub image_width: u32, + #[args(setter = false)] + pub image_height: u32, + pub resize_mode: ResizeMode, + pub resize_filter: &'static str, + pub padding_value: u8, + pub letterbox_center: bool, + pub normalize: bool, + pub image_std: Vec, + pub image_mean: Vec, + pub nchw: bool, + pub unsigned: bool, + + // Names + pub class_names: Option>, + pub class_names_2: Option>, + pub class_names_3: Option>, + pub keypoint_names: Option>, + pub keypoint_names_2: Option>, + pub keypoint_names_3: Option>, + pub text_names: Option>, + pub text_names_2: Option>, + pub text_names_3: Option>, + pub category_names: Option>, + pub category_names_2: Option>, + pub category_names_3: Option>, + + // Confs + pub class_confs: Vec, + pub class_confs_2: Vec, + pub class_confs_3: Vec, + pub keypoint_confs: Vec, + pub keypoint_confs_2: Vec, + pub keypoint_confs_3: Vec, + pub text_confs: Vec, + pub text_confs_2: Vec, + pub text_confs_3: Vec, + + // For classification + pub apply_softmax: Option, + + // For detection + #[args(alias = "nc")] + pub num_classes: Option, + #[args(alias = "nk")] + pub num_keypoints: Option, + #[args(alias = "nm")] + pub num_masks: Option, + 
pub iou: Option, + pub iou_2: Option, + pub iou_3: Option, + pub apply_nms: Option, + pub find_contours: bool, + pub yolo_preds_format: Option, + pub classes_excluded: Vec, + pub classes_retained: Vec, + pub min_width: Option, + pub min_height: Option, + + // Language models related + pub model_max_length: Option, + pub tokenizer_file: Option, + pub config_file: Option, + pub special_tokens_map_file: Option, + pub tokenizer_config_file: Option, + pub generation_config_file: Option, + pub vocab_file: Option, // vocab.json file + pub vocab_txt: Option, // vacab.txt file, not kv pairs + pub temperature: f32, + pub topp: f32, + + // For DB + pub unclip_ratio: Option, + pub binary_thresh: Option, + + // For SAM + pub sam_kind: Option, + pub low_res_mask: Option, +} + +impl Default for Options { + fn default() -> Self { + Self { + model_file: Default::default(), + model_name: Default::default(), + model_version: Default::default(), + model_task: Default::default(), + model_scale: Default::default(), + model_kind: Default::default(), + model_device: Device::Cpu(0), + model_dtype: DType::Auto, + model_spec: Default::default(), + model_iiixs: Default::default(), + model_num_dry_run: 3, + trt_fp16: true, + profile: false, + normalize: true, + image_mean: vec![], + image_std: vec![], + image_height: 640, + image_width: 640, + padding_value: 114, + resize_mode: ResizeMode::FitExact, + resize_filter: "Bilinear", + letterbox_center: false, + nchw: true, + unsigned: false, + class_names: None, + class_names_2: None, + class_names_3: None, + category_names: None, + category_names_2: None, + category_names_3: None, + keypoint_names: None, + keypoint_names_2: None, + keypoint_names_3: None, + text_names: None, + text_names_2: None, + text_names_3: None, + class_confs: vec![0.3f32], + class_confs_2: vec![0.3f32], + class_confs_3: vec![0.3f32], + keypoint_confs: vec![0.3f32], + keypoint_confs_2: vec![0.5f32], + keypoint_confs_3: vec![0.5f32], + text_confs: vec![0.4f32], + 
text_confs_2: vec![0.4f32], + text_confs_3: vec![0.4f32], + apply_softmax: Some(false), + num_classes: None, + num_keypoints: None, + num_masks: None, + iou: None, + iou_2: None, + iou_3: None, + find_contours: false, + yolo_preds_format: None, + classes_excluded: vec![], + classes_retained: vec![], + apply_nms: None, + model_max_length: None, + tokenizer_file: None, + config_file: None, + special_tokens_map_file: None, + tokenizer_config_file: None, + generation_config_file: None, + vocab_file: None, + vocab_txt: None, + min_width: None, + min_height: None, + unclip_ratio: Some(1.5), + binary_thresh: Some(0.2), + sam_kind: None, + low_res_mask: None, + temperature: 1., + topp: 0., + } + } +} + +impl Options { + pub fn new() -> Self { + Default::default() + } + + pub fn to_engine(&self) -> Result { + Engine { + file: self.model_file.clone(), + spec: self.model_spec.clone(), + device: self.model_device, + trt_fp16: self.trt_fp16, + iiixs: self.model_iiixs.clone(), + num_dry_run: self.model_num_dry_run, + ..Default::default() + } + .build() + } + + pub fn to_processor(&self) -> Result { + let logits_sampler = LogitsSampler::new() + .with_temperature(self.temperature) + .with_topp(self.topp); + + // try to build tokenizer + let tokenizer = match self.model_kind { + Some(Kind::Language) | Some(Kind::VisionLanguage) => Some(self.try_build_tokenizer()?), + _ => None, + }; + + // try to build vocab from `vocab.txt` + let vocab: Vec = match &self.vocab_txt { + Some(x) => { + let file = if !std::path::PathBuf::from(&x).exists() { + Hub::default().try_fetch(&format!("{}/{}", self.model_name, x))? + } else { + x.to_string() + }; + std::fs::read_to_string(file)? 
+ .lines() + .map(|line| line.to_string()) + .collect() + } + None => vec![], + }; + + Ok(Processor { + image_width: self.image_width, + image_height: self.image_height, + resize_mode: self.resize_mode.clone(), + resize_filter: self.resize_filter, + padding_value: self.padding_value, + do_normalize: self.normalize, + image_mean: self.image_mean.clone(), + image_std: self.image_std.clone(), + nchw: self.nchw, + unsigned: self.unsigned, + tokenizer, + vocab, + logits_sampler: Some(logits_sampler), + ..Default::default() + }) + } + + pub fn commit(mut self) -> Result { + // Identify the local model or fetch the remote model + + if std::path::PathBuf::from(&self.model_file).exists() { + // Local + self.model_spec = format!( + "{}/{}", + self.model_name, + crate::try_fetch_stem(&self.model_file)? + ); + } else { + // Remote + if self.model_file.is_empty() && self.model_name.is_empty() { + anyhow::bail!("Neither `model_name` nor `model_file` were specified. Faild to fetch model from remote.") + } + + // Load + match Hub::is_valid_github_release_url(&self.model_file) { + Some((owner, repo, tag, _file_name)) => { + let stem = crate::try_fetch_stem(&self.model_file)?; + self.model_spec = + format!("{}/{}-{}-{}-{}", self.model_name, owner, repo, tag, stem); + self.model_file = Hub::default().try_fetch(&self.model_file)?; + } + None => { + // special yolo case + if self.model_file.is_empty() && self.model_name == "yolo" { + // [version]-[scale]-[task] + let mut y = String::new(); + if let Some(x) = self.model_version() { + y.push_str(&x.to_string()); + } + if let Some(x) = self.model_scale() { + y.push_str(&format!("-{}", x)); + } + if let Some(x) = self.model_task() { + y.push_str(&format!("-{}", x.yolo_str())); + } + y.push_str(".onnx"); + self.model_file = y; + } + + // append dtype to model file + match self.model_dtype { + d @ (DType::Auto | DType::Fp32) => { + if self.model_file.is_empty() { + self.model_file = format!("{}.onnx", d); + } + } + dtype => { + if 
self.model_file.is_empty() { + self.model_file = format!("{}.onnx", dtype); + } else { + let pos = self.model_file.len() - 5; // .onnx + let suffix = self.model_file.split_off(pos); + self.model_file = + format!("{}-{}{}", self.model_file, dtype, suffix); + } + } + } + + let stem = crate::try_fetch_stem(&self.model_file)?; + self.model_spec = format!("{}/{}", self.model_name, stem); + self.model_file = Hub::default() + .try_fetch(&format!("{}/{}", self.model_name, self.model_file))?; + } + } + + // let stem = crate::try_fetch_stem(&self.model_file)?; + // self.model_spec = format!("{}/{}", self.model_name, stem); + // self.model_file = + // Hub::default().try_fetch(&format!("{}/{}", self.model_name, self.model_file))?; + } + + Ok(self) + } + + pub fn with_batch_size(mut self, x: usize) -> Self { + self.model_iiixs.push(Iiix::from((0, 0, x.into()))); + self + } + + pub fn with_image_height(mut self, x: u32) -> Self { + self.image_height = x; + self.model_iiixs.push(Iiix::from((0, 2, x.into()))); + self + } + + pub fn with_image_width(mut self, x: u32) -> Self { + self.image_width = x; + self.model_iiixs.push(Iiix::from((0, 3, x.into()))); + self + } + + pub fn with_model_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { + self.model_iiixs.push(Iiix::from((i, ii, x))); + self + } + + pub fn exclude_classes(mut self, xs: &[usize]) -> Self { + self.classes_retained.clear(); + self.classes_excluded.extend_from_slice(xs); + self + } + + pub fn retain_classes(mut self, xs: &[usize]) -> Self { + self.classes_excluded.clear(); + self.classes_retained.extend_from_slice(xs); + self + } + + pub fn try_build_tokenizer(&self) -> Result { + let mut hub = Hub::default(); + // config file + // TODO: save configs? 
+ let pad_id = match hub.try_fetch( + self.tokenizer_config_file + .as_ref() + .unwrap_or(&format!("{}/config.json", self.model_name)), + ) { + Ok(x) => { + let config: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(x)?)?; + config["pad_token_id"].as_u64().unwrap_or(0) as u32 + } + Err(_err) => 0u32, + }; + + // tokenizer_config file + let mut max_length = None; + let mut pad_token = String::from("[PAD]"); + match hub.try_fetch( + self.tokenizer_config_file + .as_ref() + .unwrap_or(&format!("{}/tokenizer_config.json", self.model_name)), + ) { + Err(_) => {} + Ok(x) => { + let tokenizer_config: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(x)?)?; + max_length = tokenizer_config["model_max_length"].as_u64(); + pad_token = tokenizer_config["pad_token"] + .as_str() + .unwrap_or("[PAD]") + .to_string(); + } + } + + // tokenizer file + let mut tokenizer: tokenizers::Tokenizer = tokenizers::Tokenizer::from_file( + hub.try_fetch( + self.tokenizer_file + .as_ref() + .unwrap_or(&format!("{}/tokenizer.json", self.model_name)), + )?, + ) + .map_err(|_| anyhow::anyhow!("No `tokenizer.json` found"))?; + + // TODO: padding + // if `max_length` specified: use `Fixed` strategy + // else: use `BatchLongest` strategy + // TODO: if sequence_length is dynamic, `BatchLongest` is fine + let tokenizer = match self.model_max_length { + Some(n) => { + let n = match max_length { + None => n, + Some(x) => x.min(n), + }; + tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::Fixed(n as _), + pad_token, + pad_id, + ..Default::default() + })) + .clone() + } + None => match max_length { + Some(n) => tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: n as _, + ..Default::default() + })) + .map_err(|err| anyhow::anyhow!("Failed to truncate: {}", err))? 
+ .clone(), + None => tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .clone(), + }, + }; + + // TODO: generation_config.json & special_tokens_map file + + Ok(tokenizer.into()) + } +} diff --git a/src/misc/processor.rs b/src/misc/processor.rs new file mode 100644 index 0000000..3ada0bb --- /dev/null +++ b/src/misc/processor.rs @@ -0,0 +1,366 @@ +use anyhow::Result; +use fast_image_resize::{ + images::{CroppedImageMut, Image}, + pixels::PixelType, + FilterType, ResizeAlg, ResizeOptions, Resizer, +}; +use image::{DynamicImage, GenericImageView}; +use ndarray::{s, Array, Axis}; +use tokenizers::{Encoding, Tokenizer}; + +use crate::{LogitsSampler, X}; + +#[derive(Debug, Clone)] +pub enum ResizeMode { + FitExact, // StretchToFit + FitWidth, + FitHeight, + FitAdaptive, + Letterbox, +} + +#[derive(aksr::Builder, Debug, Clone)] +pub struct Processor { + pub image_width: u32, // target image width + pub image_height: u32, // target image height + pub image0s_size: Vec<(u32, u32)>, // original image height and width + pub scale_factors_hw: Vec>, + pub resize_mode: ResizeMode, + pub resize_filter: &'static str, + pub padding_value: u8, + pub do_normalize: bool, + pub image_mean: Vec, + pub image_std: Vec, + pub nchw: bool, + pub tokenizer: Option, + pub vocab: Vec, + pub unsigned: bool, + pub logits_sampler: Option, +} + +impl Default for Processor { + fn default() -> Self { + Self { + image0s_size: vec![], + image_width: 0, + image_height: 0, + scale_factors_hw: vec![], + resize_mode: ResizeMode::FitAdaptive, + resize_filter: "Bilinear", + padding_value: 114, + do_normalize: true, + image_mean: vec![], + image_std: vec![], + nchw: true, + tokenizer: Default::default(), + vocab: vec![], + unsigned: false, + logits_sampler: None, + } + } +} + +impl Processor { + pub fn reset_image0_status(&mut self) { + self.scale_factors_hw.clear(); + self.image0s_size.clear(); + } + + pub fn 
process_images(&mut self, xs: &[DynamicImage]) -> Result { + // reset + self.reset_image0_status(); + + let mut x = self.resize_batch(xs)?; + if self.do_normalize { + x = x.normalize(0., 255.)?; + } + if !self.image_std.is_empty() && !self.image_mean.is_empty() { + x = x.standardize(&self.image_mean, &self.image_std, 3)?; + } + if self.nchw { + x = x.nhwc2nchw()?; + } + + // Cope with padding problem + if self.unsigned { + x = x.unsigned(); + } + Ok(x) + } + + pub fn encode_text(&self, x: &str, skip_special_tokens: bool) -> Result { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .encode(x, skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer encode error: {}", err)) + } + + pub fn encode_texts(&self, xs: &[&str], skip_special_tokens: bool) -> Result> { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .encode_batch(xs.to_vec(), skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer encode_batch error: {}", err)) + } + + pub fn encode_text_ids(&self, x: &str, skip_special_tokens: bool) -> Result> { + let ids: Vec = if x.is_empty() { + vec![0.0f32] + } else { + self.encode_text(x, skip_special_tokens)? + .get_ids() + .iter() + .map(|x| *x as f32) + .collect() + }; + + Ok(ids) + } + + pub fn encode_texts_ids( + &self, + xs: &[&str], + skip_special_tokens: bool, + ) -> Result>> { + let ids: Vec> = if xs.is_empty() { + vec![vec![0.0f32]] + } else { + self.encode_texts(xs, skip_special_tokens)? + .into_iter() + .map(|encoding| encoding.get_ids().iter().map(|x| *x as f32).collect()) + .collect() + }; + + Ok(ids) + } + + pub fn encode_text_tokens(&self, x: &str, skip_special_tokens: bool) -> Result> { + Ok(self + .encode_text(x, skip_special_tokens)? + .get_tokens() + .to_vec()) + } + + pub fn encode_texts_tokens( + &self, + xs: &[&str], + skip_special_tokens: bool, + ) -> Result>> { + Ok(self + .encode_texts(xs, skip_special_tokens)? 
+ .into_iter() + .map(|encoding| encoding.get_tokens().to_vec()) + .collect()) + } + + pub fn decode_tokens(&self, ids: &[u32], skip_special_tokens: bool) -> Result { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .decode(ids, skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer decode error: {}", err)) + } + + pub fn decode_tokens_batch2( + &self, + ids: &[&[u32]], + skip_special_tokens: bool, + ) -> Result> { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .decode_batch(ids, skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer decode_batch error: {}", err)) + } + + pub fn decode_tokens_batch( + &self, + ids: &[Vec], + skip_special_tokens: bool, + ) -> Result> { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .decode_batch( + &ids.iter().map(|x| x.as_slice()).collect::>(), + skip_special_tokens, + ) + .map_err(|err| anyhow::anyhow!("Tokenizer decode_batch error: {}", err)) + } + + pub fn par_generate( + &self, + logits: &X, + token_ids: &mut [Vec], + eos_token_id: u32, + ) -> Result<(bool, Vec)> { + // token ids + // let mut token_ids: Vec> = vec![vec![]; self.encoder.batch()]; + // let mut finished = vec![false; self.encoder.batch()]; + let batch = token_ids.len(); + let mut finished = vec![false; batch]; + let mut last_tokens: Vec = vec![0.; batch]; + // let mut logits_sampler = LogitsSampler::new(); + + // decode each token for each batch + for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = self + .logits_sampler + .as_ref() + .expect("No `LogitsSampler` specified!") + .decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; + + if token_id == eos_token_id { + finished[i] = true; + } else { + token_ids[i].push(token_id); + } + + // update + last_tokens[i] = token_id as f32; + } + } + + // all finished? 
+ Ok((finished.iter().all(|&x| x), last_tokens)) + } + + pub fn build_resizer_filter(ty: &str) -> Result<(Resizer, ResizeOptions)> { + let ty = match ty.to_lowercase().as_str() { + "box" => FilterType::Box, + "bilinear" => FilterType::Bilinear, + "hamming" => FilterType::Hamming, + "catmullrom" => FilterType::CatmullRom, + "mitchell" => FilterType::Mitchell, + "gaussian" => FilterType::Gaussian, + "lanczos3" => FilterType::Lanczos3, + x => anyhow::bail!("Unsupported resizer's filter type: {}", x), + }; + Ok(( + Resizer::new(), + ResizeOptions::new().resize_alg(ResizeAlg::Convolution(ty)), + )) + } + + pub fn resize(&mut self, x: &DynamicImage) -> Result { + if self.image_width + self.image_height == 0 { + anyhow::bail!( + "Invalid target height: {} or width: {}.", + self.image_height, + self.image_width + ); + } + + let buffer = match x.dimensions() { + (w, h) if (w, h) == (self.image_height, self.image_width) => { + self.image0s_size.push((h, w)); + self.scale_factors_hw.push(vec![1., 1.]); + x.to_rgb8().into_raw() + } + (w0, h0) => { + self.image0s_size.push((h0, w0)); + let (mut resizer, options) = Self::build_resizer_filter(self.resize_filter)?; + + if let ResizeMode::FitExact = self.resize_mode { + let mut dst = Image::new(self.image_width, self.image_height, PixelType::U8x3); + resizer.resize(x, &mut dst, &options)?; + self.scale_factors_hw.push(vec![ + (self.image_height as f32 / h0 as f32), + (self.image_width as f32 / w0 as f32), + ]); + + dst.into_vec() + } else { + let (w, h) = match self.resize_mode { + ResizeMode::Letterbox | ResizeMode::FitAdaptive => { + let r = (self.image_width as f32 / w0 as f32) + .min(self.image_height as f32 / h0 as f32); + self.scale_factors_hw.push(vec![r, r]); + + ( + (w0 as f32 * r).round() as u32, + (h0 as f32 * r).round() as u32, + ) + } + ResizeMode::FitHeight => { + let r = self.image_height as f32 / h0 as f32; + self.scale_factors_hw.push(vec![1.0, r]); + ((r * w0 as f32).round() as u32, self.image_height) + } + 
ResizeMode::FitWidth => { + // scale factor + let r = self.image_width as f32 / w0 as f32; + self.scale_factors_hw.push(vec![r, 1.0]); + (self.image_width, (r * h0 as f32).round() as u32) + } + + _ => unreachable!(), + }; + + let mut dst = Image::from_vec_u8( + self.image_width, + self.image_height, + vec![ + self.padding_value; + 3 * self.image_height as usize * self.image_width as usize + ], + PixelType::U8x3, + )?; + let (l, t) = if let ResizeMode::Letterbox = self.resize_mode { + if w == self.image_width { + (0, (self.image_height - h) / 2) + } else { + ((self.image_width - w) / 2, 0) + } + } else { + (0, 0) + }; + + let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; + resizer.resize(x, &mut dst_cropped, &options)?; + dst.into_vec() + } + } + }; + + let y = Array::from_shape_vec( + (self.image_height as usize, self.image_width as usize, 3), + buffer, + )? + .mapv(|x| x as f32) + .into_dyn(); + + Ok(y.into()) + } + + pub fn resize_batch(&mut self, xs: &[DynamicImage]) -> Result { + // TODO: par resize + if xs.is_empty() { + anyhow::bail!("Found no input images.") + } + + let mut ys = Array::ones(( + xs.len(), + self.image_height as usize, + self.image_width as usize, + 3, + )) + .into_dyn(); + + xs.iter().enumerate().try_for_each(|(idx, x)| { + let y = self.resize(x)?; + ys.slice_mut(s![idx, .., .., ..]).assign(&y); + anyhow::Ok(()) + })?; + + Ok(ys.into()) + } +} diff --git a/src/misc/retry.rs b/src/misc/retry.rs new file mode 100644 index 0000000..82f63c8 --- /dev/null +++ b/src/misc/retry.rs @@ -0,0 +1,144 @@ +/// A macro to retry an expression multiple times with configurable delays between attempts. +/// +/// This macro supports three forms: +/// +/// 1. `retry!(max_attempts, base_delay, max_delay, code)` +/// - Customizes the retry behavior: +/// - `max_attempts`: Maximum number of retry attempts. Set to `0` for infinite retries. +/// - `base_delay`: Initial delay (in milliseconds) before retrying. Delays increase exponentially. 
/// A macro to retry an expression multiple times with configurable delays between attempts.
///
/// This macro supports three forms:
///
/// 1. `retry!(max_attempts, base_delay, max_delay, code)`
///    - Customizes the retry behavior:
///      - `max_attempts`: Maximum number of retry attempts. Set to `0` for infinite retries.
///      - `base_delay`: Initial delay (in milliseconds) before retrying. Delays increase exponentially.
///      - `max_delay`: Maximum delay (in milliseconds) between retries. Must be greater than
///        `base_delay`.
///
/// 2. `retry!(max_attempts, code)`
///    - Retries the provided `code` up to `max_attempts` times, with a default base delay of `80ms`
///      and a maximum delay of `1000ms` between attempts.
///
/// 3. `retry!(code)`
///    - Retries the provided `code` indefinitely until it succeeds, using a default base delay of `80ms`
///      and a maximum delay of `1000ms` between attempts.
///
/// The numeric arguments are cast to `f64` and rounded, so `3.9` means four attempts.
/// The macro uses `anyhow::bail!` and therefore has to be invoked inside a function
/// that returns an `anyhow`-compatible `Result`.
///
/// # Examples
///
/// ## Example 1: Retry with a default delay configuration
/// ```rust,ignore
/// use anyhow::Result;
/// use usls::retry;
///
/// fn main() -> Result<()> {
///     println!(
///         "{}",
///         retry!(3.9, {
///             Err::<usize, anyhow::Error>(anyhow::anyhow!("Failed message"))
///         })?
///     );
///     Ok(())
/// }
/// ```
///
/// ## Example 2: Retry until a random condition is met
/// ```rust
/// use anyhow::Result;
/// use usls::retry;
///
/// fn main() -> Result<()> {
///     let _n = retry!({
///         let n = rand::random::<f32>();
///         if n < 0.7 {
///             Err(anyhow::anyhow!(format!("Random failure: {}", n)))
///         } else {
///             Ok(1)
///         }
///     })?;
///     Ok(())
/// }
/// ```
///
/// ## Example 3: Retry with custom delays and a stateful condition
/// ```rust
/// use anyhow::Result;
/// use usls::retry;
///
/// fn main() -> Result<()> {
///     let mut cnt = 5;
///     fn example_function(cnt: usize) -> Result<usize> {
///         if cnt < 10 {
///             anyhow::bail!("Failed")
///         } else {
///             Ok(42)
///         }
///     }
///
///     println!(
///         "Result: {}",
///         retry!(20, 10, 100, {
///             cnt += 1;
///             example_function(cnt)
///         })?
///     );
///     Ok(())
/// }
/// ```
#[macro_export]
macro_rules! retry {
    ($code:expr) => {
        $crate::retry!(0, 80, 1000, $code)
    };
    ($max_attempts:expr, $code:expr) => {
        $crate::retry!($max_attempts, 80, 1000, $code)
    };
    ($max_attempts:expr, $base_delay:expr, $max_delay:expr, $code:expr) => {{
        let max_attempts: u64 = ($max_attempts as f64).round() as u64;
        let base_delay: u64 = ($base_delay as f64).round() as u64;
        let max_delay: u64 = ($max_delay as f64).round() as u64;
        if base_delay == 0 {
            anyhow::bail!(
                "[retry!] `base_delay` cannot be zero. Received: {}",
                $base_delay
            );
        }
        if max_delay == 0 {
            anyhow::bail!(
                "[retry!] `max_delay` cannot be zero. Received: {}",
                $max_delay
            );
        }
        if max_delay <= base_delay {
            anyhow::bail!(
                "[retry!] `max_delay`: {} must be greater than `base_delay`: {}.",
                $max_delay,
                $base_delay
            );
        }

        let mut n: u64 = 1;
        loop {
            match $code {
                Ok(result) => {
                    log::debug!("[retry!] Attempt {} succeeded.", n);
                    break Ok::<_, anyhow::Error>(result);
                }
                Err(err) => {
                    let message = format!(
                        "[retry!] Attempt {}/{} failed with error: {:?}.",
                        n,
                        if max_attempts == 0 {
                            "inf".to_string()
                        } else {
                            max_attempts.to_string()
                        },
                        err,
                    );
                    if max_attempts > 0 && n >= max_attempts {
                        log::error!("{} Stopping after {} attempts.", message, n);
                        anyhow::bail!(err);
                    }

                    // Exponential back-off capped at `max_delay`. The factor and
                    // the product saturate, so a long run of failures (possible
                    // with infinite retries) cannot overflow the shift or the
                    // multiplication.
                    let factor = u32::try_from(n - 1)
                        .ok()
                        .and_then(|shift| 1u64.checked_shl(shift))
                        .unwrap_or(u64::MAX);
                    let delay = base_delay.saturating_mul(factor).min(max_delay);
                    let delay = std::time::Duration::from_millis(delay);
                    log::debug!("{} Retrying in {:?}..", message, delay);
                    std::thread::sleep(delay);
                    n += 1;
                }
            }
        }
    }};
}
Ok(Self::E), + "x" | "extra-large" => Ok(Self::X), + "g" | "giant" => Ok(Self::G), + "p" | "pico" => Ok(Self::P), + "a" | "atto" => Ok(Self::A), + "f" | "femto" => Ok(Self::F), + x => anyhow::bail!("Unsupported model scale: {:?}", x), + } + } +} diff --git a/src/core/task.rs b/src/misc/task.rs similarity index 75% rename from src/core/task.rs rename to src/misc/task.rs index 8090625..80e5c33 100644 --- a/src/core/task.rs +++ b/src/misc/task.rs @@ -1,7 +1,5 @@ -#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)] +#[derive(Debug, Copy, Clone, Ord, Eq, PartialOrd, PartialEq)] pub enum Task { - Untitled, - /// Image classification task. /// Input: image /// Output: a label representing the class of the image @@ -27,13 +25,14 @@ pub enum Task { /// Input: image /// Output: bounding boxes (bboxes), class labels, and optional scores for the detected objects ObjectDetection, + OrientedObjectDetection, + Obb, /// Open set detection task, detecting and classifying objects in an image, with the ability to handle unseen or unknown objects. /// Input: image /// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores /// Open set detection task, with String query - OpenSetDetection(String), - + OpenSetDetection(&'static str), /// Task for generating brief descriptions of dense regions in the image. /// Input: image /// Output: bounding boxes (bboxes), brief phrase labels, and optional scores for detected regions @@ -44,12 +43,16 @@ pub enum Task { /// Input: image /// Output: coordinates of detected keypoints KeypointsDetection, + Pose, /// Semantic segmentation task, segmenting the image into different semantic regions. /// Input: image /// Output: per-pixel class labels indicating object or background SemanticSegmentation, + ImageFeatureExtraction, + TextFeatureExtraction, + /// Instance segmentation task, detecting and segmenting individual object instances. 
/// Input: image /// Output: pixel masks for each object instance @@ -94,12 +97,12 @@ pub enum Task { /// Input: image and text /// Output: image region and the corresponding phrase /// caption to phrase grounding - CaptionToPhraseGrounding(String), + CaptionToPhraseGrounding(&'static str), /// Referring expression segmentation task, segmenting objects in the image based on a text description. /// Input: image and referring expression /// Output: a segmentation mask for the object referred to by the text - ReferringExpressionSegmentation(String), + ReferringExpressionSegmentation(&'static str), /// Region-to-segmentation task, similar to combining object detection with segmentation (e.g., YOLO + SAM). /// Input: image and region proposals @@ -122,7 +125,7 @@ pub enum Task { /// Visual question answering (VQA) task, answering questions related to an image. /// Input: image and question text /// Output: the answer to the question - Vqa(String), + Vqa(&'static str), /// Optical character recognition (OCR) task, recognizing text in an image. 
/// Input: image @@ -135,10 +138,59 @@ pub enum Task { OcrWithRegion, } +impl std::fmt::Display for Task { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::ImageClassification => "image-classification", + Self::ObjectDetection => "object-detection", + Self::Pose => "pose", + Self::KeypointsDetection => "pose-detection", + Self::InstanceSegmentation => "instance-segmentation", + Self::Obb => "obb", + Self::OrientedObjectDetection => "oriented-object-detection", + Self::DepthEstimation => "depth", + Self::Caption(0) => "caption", + Self::Caption(1) => "detailed-caption", + Self::Caption(2) => "more-detailed-caption", + Self::ImageTagging => "image-tagging", + Self::Ocr => "ocr", + Self::OcrWithRegion => "ocr-with-region", + Self::Vqa(_) => "vqa", + _ => todo!(), + }; + write!(f, "{}", x) + } +} + +impl TryFrom<&str> for Task { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "cls" | "classify" | "classification" => Ok(Self::ImageClassification), + "det" | "od" | "detect" => Ok(Self::ObjectDetection), + "kpt" | "pose" => Ok(Self::KeypointsDetection), + "seg" | "segment" => Ok(Self::InstanceSegmentation), + "obb" => Ok(Self::OrientedObjectDetection), + _ => todo!(), // x => anyhow::bail!("Unsupported model task: {}", x), + } + } +} + impl Task { + pub fn yolo_str(&self) -> &'static str { + match self { + Self::ImageClassification => "cls", + Self::ObjectDetection => "det", + Self::Pose | Self::KeypointsDetection => "pose", + Self::InstanceSegmentation => "seg", + Self::Obb | Self::OrientedObjectDetection => "obb", + x => unimplemented!("Unsupported YOLO Task: {}", x), + } + } + pub fn prompt_for_florence2(&self) -> anyhow::Result { let prompt = match self { - Self::Untitled => anyhow::bail!("No task specified."), Self::Caption(0) => "What does the image describe?".to_string(), Self::Caption(1) => "Describe in detail what is shown in the image.".to_string(), 
Self::Caption(2) => "Describe with a paragraph what is shown in the image.".to_string(), @@ -178,7 +230,7 @@ impl Task { x0, y0, x1, y1 ) } - _ => anyhow::bail!("Unsupported task."), + x => anyhow::bail!("Unsupported Florence2 task: {:?}", x), }; Ok(prompt) diff --git a/src/misc/ts.rs b/src/misc/ts.rs new file mode 100644 index 0000000..dbc55c3 --- /dev/null +++ b/src/misc/ts.rs @@ -0,0 +1,393 @@ +use std::collections::HashMap; +use std::time::Duration; + +/// A macro to measure the execution time of a given code block and optionally log the result. +#[macro_export] +macro_rules! elapsed { + ($code:expr) => {{ + let t = std::time::Instant::now(); + let ret = $code; + let duration = t.elapsed(); + (duration, ret) + }}; + ($label:expr, $ts:expr, $code:expr) => {{ + let t = std::time::Instant::now(); + let ret = $code; + let duration = t.elapsed(); + $ts.push($label, duration); + ret + }}; +} + +#[derive(aksr::Builder, Debug, Default, Clone, PartialEq)] +pub struct Ts { + // { k1: [d1,d1,d1,..], k2: [d2,d2,d2,..], k3: [d3,d3,d3,..], ..} + map: HashMap>, + names: Vec, +} + +impl std::ops::Index<&str> for Ts { + type Output = Vec; + + fn index(&self, index: &str) -> &Self::Output { + self.map.get(index).expect("Index was not found in `Ts`") + } +} + +impl std::ops::Index for Ts { + type Output = Vec; + + fn index(&self, index: usize) -> &Self::Output { + self.names + .get(index) + .and_then(|key| self.map.get(key)) + .expect("Index was not found in `Ts`") + } +} + +impl Ts { + pub fn summary(&self) { + let decimal_places = 4; + let place_holder = '-'; + let width_count = 10; + let width_time = 15; + let width_task = self + .names + .iter() + .map(|s| s.len()) + .max() + .map(|x| x + 8) + .unwrap_or(60); + + let sep = "-".repeat(width_task + 66); + + // cols + println!( + "\n\n{: Self { + let mut names = Vec::new(); + let mut map: HashMap> = HashMap::new(); + for x in xs.iter() { + names.extend_from_slice(x.names()); + map.extend(x.map().to_owned()); + } + + Self { 
names, map } + } + + pub fn push(&mut self, k: &str, v: Duration) { + if !self.names.contains(&k.to_string()) { + self.names.push(k.to_string()); + } + self.map + .entry(k.to_string()) + .and_modify(|x| x.push(v)) + .or_insert(vec![v]); + } + + pub fn numit(&self) -> anyhow::Result { + // num of iterations + if self.names.is_empty() { + anyhow::bail!("Empty Ts"); + } + + let len = self[0].len(); + for v in self.map.values() { + if v.len() != len { + anyhow::bail!( + "Invalid Ts: The number of elements in each values entry is inconsistent" + ); + } + } + + Ok(len) + } + + pub fn is_valid(&self) -> bool { + let mut iter = self.map.values(); + if let Some(first) = iter.next() { + let len = first.len(); + iter.all(|v| v.len() == len) + } else { + true + } + } + + pub fn sum_by_index(&self, i: usize) -> Duration { + self[i].iter().sum::() + } + + pub fn sum_by_key(&self, i: &str) -> Duration { + self[i].iter().sum::() + } + + pub fn avg_by_index(&self, i: usize) -> anyhow::Result { + let len = self[i].len(); + if len == 0 { + anyhow::bail!("Cannot compute average for an empty duration vector.") + } else { + Ok(self.sum_by_index(i) / len as u32) + } + } + + pub fn avg_by_key(&self, i: &str) -> anyhow::Result { + let len = self[i].len(); + if len == 0 { + anyhow::bail!("Cannot compute average for an empty duration vector.") + } else { + Ok(self.sum_by_key(i) / len as u32) + } + } + + pub fn sum_column(&self, i: usize) -> Duration { + self.map + .values() + .filter_map(|vec| vec.get(i)) + .copied() + .sum() + } + + pub fn sum(&self) -> Duration { + self.map.values().flat_map(|vec| vec.iter()).copied().sum() + } + + pub fn avg(&self) -> anyhow::Result { + self.names.iter().map(|x| self.avg_by_key(x)).sum() + } + + pub fn skip(mut self, n: usize) -> Self { + self.map.iter_mut().for_each(|(_, vec)| { + *vec = vec.iter().skip(n).copied().collect(); + }); + self + } + + pub fn clear(&mut self) { + self.names.clear(); + self.map.clear(); + } + + pub fn is_empty(&self) -> bool { 
+ self.names.is_empty() && self.map.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_push_and_indexing() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts["task1"], vec![Duration::new(1, 0), Duration::new(2, 0)]); + assert_eq!(ts["task2"], vec![Duration::new(3, 0)]); + } + + #[test] + fn test_numit() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(4, 0)); + + assert_eq!(ts.numit().unwrap(), 2); + } + + #[test] + fn test_is_valid() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert!(!ts.is_valid()); + + ts.push("task2", Duration::new(4, 0)); + ts.push("task3", Duration::new(5, 0)); + + assert!(!ts.is_valid()); + + ts.push("task3", Duration::new(6, 0)); + assert!(ts.is_valid()); + } + + #[test] + fn test_sum_by_index() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum_by_index(0), Duration::new(3, 0)); // 1 + 2 + assert_eq!(ts.sum_by_index(1), Duration::new(9, 0)); // 1 + 2 + } + + #[test] + fn test_sum_by_key() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum_by_key("task1"), Duration::new(3, 0)); // 1 + 2 + assert_eq!(ts.sum_by_key("task2"), Duration::new(9, 0)); // 1 + 2 + } + + #[test] + fn test_avg_by_index() 
{ + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(2, 0)); + ts.push("task2", Duration::new(2, 0)); + ts.push("task3", Duration::new(2, 0)); + + assert_eq!(ts.avg_by_index(0).unwrap(), Duration::new(1, 500_000_000)); + assert_eq!(ts.avg_by_index(1).unwrap(), Duration::new(2, 0)); + assert_eq!(ts.avg_by_index(2).unwrap(), Duration::new(2, 0)); + } + + #[test] + fn test_avg_by_key() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + + let avg = ts.avg_by_key("task1").unwrap(); + assert_eq!(avg, Duration::new(1, 500_000_000)); + } + + #[test] + fn test_sum_column() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum_column(0), Duration::new(4, 0)); // 1 + 3 + } + + #[test] + fn test_sum() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum(), Duration::new(6, 0)); + } + + #[test] + fn test_avg() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(4, 0)); + + assert_eq!(ts.avg().unwrap(), Duration::new(5, 0)); + } + + #[test] + fn test_skip() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(4, 0)); + ts.push("task2", Duration::new(4, 0)); + + let ts_skipped = ts.skip(1); + + assert_eq!(ts_skipped["task1"], vec![Duration::new(2, 0)]); + assert_eq!( + ts_skipped["task2"], + vec![Duration::new(4, 0), Duration::new(4, 0)] + ); + + let ts_skipped = ts_skipped.skip(1); + + 
assert!(ts_skipped["task1"].is_empty()); + assert_eq!(ts_skipped["task2"], vec![Duration::new(4, 0)]); + } + + #[test] + fn test_clear() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task2", Duration::new(2, 0)); + + ts.clear(); + assert!(ts.names.is_empty()); + assert!(ts.map.is_empty()); + } +} diff --git a/src/utils/mod.rs b/src/misc/utils.rs similarity index 57% rename from src/utils/mod.rs rename to src/misc/utils.rs index 543d13c..c243618 100644 --- a/src/utils/mod.rs +++ b/src/misc/utils.rs @@ -3,17 +3,7 @@ use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; -pub mod colormap256; -pub mod names; -mod quantizer; - -pub use colormap256::*; -pub use names::*; -pub use quantizer::Quantizer; - -pub(crate) const CHECK_MARK: &str = "✅"; -pub(crate) const CROSS_MARK: &str = "❌"; -pub(crate) const SAFE_CROSS_MARK: &str = "❎"; +pub(crate) const PREFIX_LENGTH: usize = 12; pub(crate) const NETWORK_PREFIXES: &[&str] = &[ "http://", "https://", "ftp://", "ftps://", "sftp://", "rtsp://", "mms://", "mmsh://", "rtmp://", "rtmps://", "file://", @@ -27,35 +17,54 @@ pub(crate) const STREAM_PROTOCOLS: &[&str] = &[ "rtsp://", "rtsps://", "rtspu://", "rtmp://", "rtmps://", "hls://", "http://", "https://", ]; pub(crate) const PROGRESS_BAR_STYLE_CYAN: &str = - "{prefix:.cyan.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; + "{prefix:>12.cyan.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; pub(crate) const PROGRESS_BAR_STYLE_GREEN: &str = - "{prefix:.green.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; + "{prefix:>12.green.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; pub(crate) const PROGRESS_BAR_STYLE_CYAN_2: &str = - "{prefix:.cyan.bold} {human_pos}/{human_len} |{bar}| {msg}"; + "{prefix:>12.cyan.bold} {human_pos}/{human_len} |{bar}| {msg}"; pub(crate) const PROGRESS_BAR_STYLE_CYAN_3: &str = - "{prefix:.cyan.bold} |{bar}| 
{human_pos}/{human_len} {msg}"; + "{prefix:>12.cyan.bold} |{bar}| {human_pos}/{human_len} {msg}"; pub(crate) const PROGRESS_BAR_STYLE_GREEN_2: &str = - "{prefix:.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; + "{prefix:>12.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; pub(crate) const PROGRESS_BAR_STYLE_FINISH: &str = - "{prefix:.green.bold} {msg} for {human_len} iterations in {elapsed}"; + "{prefix:>12.green.bold} {msg} for {human_len} iterations in {elapsed}"; pub(crate) const PROGRESS_BAR_STYLE_FINISH_2: &str = - "{prefix:.green.bold} {msg} x{human_len} in {elapsed}"; + "{prefix:>12.green.bold} {msg} x{human_len} in {elapsed}"; pub(crate) const PROGRESS_BAR_STYLE_FINISH_3: &str = - "{prefix:.green.bold} {msg} ({binary_total_bytes}) in {elapsed}"; -pub(crate) const PROGRESS_BAR_STYLE_FINISH_4: &str = "{prefix:.green.bold} {msg} in {elapsed}"; + "{prefix:>12.green.bold} {msg} ({binary_total_bytes}) in {elapsed}"; +pub(crate) const PROGRESS_BAR_STYLE_FINISH_4: &str = "{prefix:>12.green.bold} {msg} in {elapsed}"; + +pub(crate) fn try_fetch_stem>(p: P) -> anyhow::Result { + let p = p.as_ref(); + let stem = p + .file_stem() + .ok_or(anyhow::anyhow!( + "Failed to get the `file_stem` of `model_file`: {:?}", + p + ))? + .to_str() + .ok_or(anyhow::anyhow!("Failed to convert from `&OsStr` to `&str`"))?; + + Ok(stem.to_string()) +} + +pub fn human_bytes(size: f64, use_binary: bool) -> String { + let units = if use_binary { + ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"] + } else { + ["B", "KB", "MB", "GB", "TB", "PB", "EB"] + }; -pub fn human_bytes(size: f64) -> String { - let units = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; let mut size = size; let mut unit_index = 0; - let k = 1024.; + let k = if use_binary { 1024. } else { 1000. 
}; while size >= k && unit_index < units.len() - 1 { size /= k; unit_index += 1; } - format!("{:.1} {}", size, units[unit_index]) + format!("{:.2} {}", size, units[unit_index]) } pub(crate) fn string_random(n: usize) -> String { @@ -75,7 +84,7 @@ pub(crate) fn string_now(delimiter: &str) -> String { t_now.format(&fmt).to_string() } -pub fn build_progress_bar( +pub(crate) fn build_progress_bar( n: u64, prefix: &str, msg: Option<&str>, @@ -83,7 +92,7 @@ pub fn build_progress_bar( ) -> anyhow::Result { let pb = ProgressBar::new(n); pb.set_style(ProgressStyle::with_template(style_temp)?.progress_chars("██ ")); - pb.set_prefix(prefix.to_string()); + pb.set_prefix(format!("{:>PREFIX_LENGTH$}", prefix)); pb.set_message(msg.unwrap_or_default().to_string()); Ok(pb) diff --git a/src/misc/version.rs b/src/misc/version.rs new file mode 100644 index 0000000..022f39e --- /dev/null +++ b/src/misc/version.rs @@ -0,0 +1,43 @@ +#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, Default)] +pub struct Version(pub u8, pub u8); + +impl std::fmt::Display for Version { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = if self.1 == 0 { + format!("v{}", self.0) + } else { + format!("v{}.{}", self.0, self.1) + }; + write!(f, "{}", x) + } +} + +impl From<(u8, u8)> for Version { + fn from((x, y): (u8, u8)) -> Self { + Self(x, y) + } +} + +impl From for Version { + fn from(x: f32) -> Self { + let x = format!("{:?}", x); + let x: Vec = x + .as_str() + .split('.') + .map(|x| x.parse::().unwrap_or(0)) + .collect(); + Self(x[0], x[1]) + } +} + +impl From for Version { + fn from(x: u8) -> Self { + Self(x, 0) + } +} + +impl Version { + pub fn new(x: u8, y: u8) -> Self { + Self(x, y) + } +} diff --git a/src/core/viewer.rs b/src/misc/viewer.rs similarity index 92% rename from src/core/viewer.rs rename to src/misc/viewer.rs index 982fc8a..cb37c77 100644 --- a/src/core/viewer.rs +++ b/src/misc/viewer.rs @@ -1,13 +1,12 @@ use anyhow::Result; use image::DynamicImage; +use 
log::info; use minifb::{Window, WindowOptions}; use video_rs::{ encode::{Encoder, Settings}, time::Time, }; -use crate::{string_now, Dir, Key}; - pub struct Viewer<'a> { name: &'a str, window: Option, @@ -107,8 +106,9 @@ impl Viewer<'_> { let (w, h) = frame.dimensions(); if self.writer.is_none() { let settings = Settings::preset_h264_yuv420p(w as _, h as _, false); - let saveout = Dir::saveout(&["runs"])?.join(format!("{}.mp4", string_now("-"))); - tracing::info!("Video will be save to: {:?}", saveout); + let saveout = + crate::Dir::saveout(&["runs"])?.join(format!("{}.mp4", crate::string_now("-"))); + info!("Video will be save to: {:?}", saveout); self.writer = Some(Encoder::new(saveout, settings)?); } @@ -138,7 +138,7 @@ impl Viewer<'_> { match &mut self.writer { Some(writer) => writer.finish()?, None => { - tracing::info!("Found no video writer. No need to release."); + info!("Found no video writer. No need to release."); } } Ok(()) @@ -152,7 +152,7 @@ impl Viewer<'_> { } } - pub fn is_key_pressed(&self, key: Key) -> bool { + pub fn is_key_pressed(&self, key: crate::Key) -> bool { if let Some(window) = &self.window { window.is_key_down(key) } else { @@ -161,7 +161,7 @@ impl Viewer<'_> { } pub fn is_esc_pressed(&self) -> bool { - self.is_key_pressed(Key::Escape) + self.is_key_pressed(crate::Key::Escape) } pub fn resizable(mut self, x: bool) -> Self { diff --git a/src/models/beit/README.md b/src/models/beit/README.md new file mode 100644 index 0000000..5fe7073 --- /dev/null +++ b/src/models/beit/README.md @@ -0,0 +1,12 @@ +# BEiT: BERT Pre-Training of Image Transformers + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/microsoft/unilm/tree/master/beit) + + +## Example + +Refer to the [example](../../../examples/beit) + + diff --git a/src/models/beit/config.rs b/src/models/beit/config.rs new file mode 100644 index 0000000..41341d8 --- /dev/null +++ b/src/models/beit/config.rs @@ -0,0 +1,26 @@ +use 
crate::IMAGENET_NAMES_1K; + +/// Model configuration for `BEiT` +impl crate::Options { + pub fn beit() -> Self { + Self::default() + .with_model_name("beit") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn beit_base() -> Self { + Self::beit().with_model_file("b.onnx") + } + + pub fn beit_large() -> Self { + Self::beit().with_model_file("l.onnx") + } +} diff --git a/src/models/beit/mod.rs b/src/models/beit/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/beit/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/blip.rs b/src/models/blip.rs deleted file mode 100644 index 72c8925..0000000 --- a/src/models/blip.rs +++ /dev/null @@ -1,155 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::s; -use std::io::Write; -use tokenizers::Tokenizer; - -use crate::{ - Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs, X, Y, -}; - -#[derive(Debug)] -pub struct Blip { - pub textual: OrtEngine, - pub visual: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch_visual: MinOptMax, - pub batch_textual: MinOptMax, - tokenizer: TokenizerStream, -} - -impl Blip { - pub fn new(options_visual: Options, options_textual: Options) -> Result { - let mut visual = OrtEngine::new(&options_visual)?; - let mut textual = OrtEngine::new(&options_textual)?; - let (batch_visual, batch_textual, height, width) = ( - visual.batch().to_owned(), - textual.batch().to_owned(), - visual.height().to_owned(), - visual.width().to_owned(), - ); - - let tokenizer = options_textual - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - let tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to 
build tokenizer: {:?}", err), - Ok(x) => x, - }; - - let tokenizer = TokenizerStream::new(tokenizer); - visual.dry_run()?; - textual.dry_run()?; - Ok(Self { - textual, - visual, - batch_visual, - batch_textual, - height, - width, - tokenizer, - }) - } - - pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize( - &[0.48145466, 0.4578275, 0.40821073], - &[0.26862954, 0.2613026, 0.2757771], - 3, - ), - Ops::Nhwc2nchw, - ])?; - let ys = self.visual.run(Xs::from(xs_))?; - Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned()))) - } - - pub fn caption(&mut self, xs: &Y, prompt: Option<&str>, show: bool) -> Result> { - let mut ys: Vec = Vec::new(); - let image_embeds = match xs.embedding() { - Some(x) => X::from(x.data().to_owned()), - None => anyhow::bail!("No image embeddings found."), - }; - let image_embeds_attn_mask = X::ones(&[self.batch_visual(), image_embeds.dims()[1]]); - - let mut y_text = String::new(); - - // conditional - let mut input_ids = match prompt { - None => { - if show { - print!("[Unconditional]: "); - } - vec![0.0f32] - } - Some(prompt) => { - let encodings = match self.tokenizer.tokenizer().encode(prompt, false) { - Err(err) => anyhow::bail!("{}", err), - Ok(x) => x, - }; - let ids: Vec = encodings.get_ids().iter().map(|x| *x as f32).collect(); - if show { - print!("[Conditional]: {} ", prompt); - } - y_text.push_str(&format!("{} ", prompt)); - ids - } - }; - - let mut logits_sampler = LogitsSampler::new(); - loop { - let input_ids_nd = X::from(input_ids.to_owned()) - .insert_axis(0)? 
- .repeat(0, self.batch_textual())?; - let input_ids_attn_mask = X::ones(input_ids_nd.dims()); - - let y = self.textual.run(Xs::from(vec![ - input_ids_nd, - input_ids_attn_mask, - image_embeds.clone(), - image_embeds_attn_mask.clone(), - ]))?; // N, length, vocab_size - let y = y[0].slice(s!(0, -1.., ..)); - let logits = y.slice(s!(0, ..)).to_vec(); - let token_id = logits_sampler.decode(&logits)?; - input_ids.push(token_id as f32); - - // SEP - if token_id == 102 { - break; - } - - // streaming generation - if let Some(t) = self.tokenizer.next_token(token_id as u32)? { - y_text.push_str(&t); - if show { - print!("{t}"); - // std::thread::sleep(std::time::Duration::from_millis(5)); - } - std::io::stdout().flush()?; - } - } - if show { - println!(); - } - self.tokenizer.clear(); - ys.push(Y::default().with_texts(&[y_text])); - Ok(ys) - } - - pub fn batch_visual(&self) -> usize { - self.batch_visual.opt() - } - - pub fn batch_textual(&self) -> usize { - self.batch_textual.opt() - } -} diff --git a/src/models/blip/README.md b/src/models/blip/README.md new file mode 100644 index 0000000..2585e4d --- /dev/null +++ b/src/models/blip/README.md @@ -0,0 +1,18 @@ +# BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/salesforce/BLIP) + +## TODO + +- [x] Image-Text Captioning +- [ ] Visual Question Answering (VQA) +- [ ] Image-Text Retrieval +- [ ] TensorRT Support for Textual Model + +## Example + +Refer to the [example](../../../examples/blip) + + diff --git a/src/models/blip/config.rs b/src/models/blip/config.rs new file mode 100644 index 0000000..2248a9e --- /dev/null +++ b/src/models/blip/config.rs @@ -0,0 +1,34 @@ +/// Model configuration for `BLIP` +impl crate::Options { + pub fn blip() -> Self { + Self::default().with_model_name("blip").with_batch_size(1) + } + + #[allow(clippy::excessive_precision)] + pub fn 
blip_visual() -> Self { + Self::blip() + .with_model_kind(crate::Kind::Vision) + .with_model_ixx(0, 2, 384.into()) + .with_model_ixx(0, 3, 384.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.26130258, 0.27577711]) + .with_resize_filter("Bilinear") + .with_normalize(true) + } + + pub fn blip_textual() -> Self { + Self::blip().with_model_kind(crate::Kind::Language) + } + + pub fn blip_v1_base_caption_visual() -> Self { + Self::blip_visual() + .with_model_version(1.0.into()) + .with_model_file("v1-base-caption-visual.onnx") + } + + pub fn blip_v1_base_caption_textual() -> Self { + Self::blip_textual() + .with_model_version(1.0.into()) + .with_model_file("v1-base-caption-textual.onnx") + } +} diff --git a/src/models/blip/impl.rs b/src/models/blip/impl.rs new file mode 100644 index 0000000..295fe5f --- /dev/null +++ b/src/models/blip/impl.rs @@ -0,0 +1,130 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; + +use crate::{ + elapsed, + models::{BaseModelTextual, BaseModelVisual}, + LogitsSampler, Options, Ts, Xs, Ys, X, Y, +}; + +#[derive(Debug, Builder)] +pub struct Blip { + visual: BaseModelVisual, + textual: BaseModelTextual, + ts: Ts, + max_length: usize, + eos_token_id: u32, +} + +impl Blip { + pub fn new(options_visual: Options, options_textual: Options) -> Result { + let visual = BaseModelVisual::new(options_visual)?; + let textual = BaseModelTextual::new(options_textual)?; + let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]); + let max_length = 512; + let eos_token_id = 102; + + Ok(Self { + textual, + visual, + ts, + max_length, + eos_token_id, + }) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + self.visual.encode(xs) + } + + pub fn encode_texts(&mut self, text: Option<&str>) -> Result>> { + let input_ids = self + .textual + .processor() + .encode_text_ids(text.unwrap_or_default(), false)?; + Ok(vec![input_ids.clone(); 
self.batch()]) + } + + pub fn forward(&mut self, images: &[DynamicImage], text: Option<&str>) -> Result { + let image_embeds = elapsed!("encode_images", self.ts, { self.encode_images(images)? }); + let ys = elapsed!("generate", self.ts, { self.generate(&image_embeds, text)? }); + + Ok(ys) + } + + pub fn generate(&mut self, image_embeds: &X, text: Option<&str>) -> Result { + // encode texts + let mut token_ids = self.encode_texts(text)?; + + // generate + let logits_sampler = LogitsSampler::new(); + let mut finished = vec![false; self.batch()]; + for _ in 0..self.max_length { + let input_ids_nd = token_ids + .iter() + .map(|tokens| X::from(tokens.clone()).insert_axis(0)) + .collect::, _>>()?; + + let input_ids_nd = X::concat(&input_ids_nd, 0)?; + let input_ids_attn_mask = X::ones(input_ids_nd.dims()); + + // decode + let outputs = self.textual.inference(Xs::from(vec![ + input_ids_nd, + input_ids_attn_mask, + image_embeds.clone(), + X::ones(&[self.visual().batch(), image_embeds.dims()[1]]), // image_embeds_attn_mask + ]))?; + + // decode each token for each batch + for (i, logit) in outputs[0].axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = logits_sampler.decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; + if token_id == self.eos_token_id { + finished[i] = true; + } + token_ids[i].push(token_id as f32); + } else { + token_ids[i].push(self.eos_token_id as f32); + } + } + + if finished.iter().all(|&x| x) { + break; + } + } + + // batch decode + let texts = self.textual.processor().decode_tokens_batch( + &token_ids + .into_iter() + .map(|v| v.into_iter().map(|x| x as u32).collect::>()) + .collect::>>(), + true, + )?; + + let ys = texts + .into_iter() + .map(|x| Y::default().with_texts(&[x.into()])) + .collect::>() + .into(); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn batch(&self) -> usize { + self.visual.batch() as _ + } +} diff --git a/src/models/blip/mod.rs 
b/src/models/blip/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/blip/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/clip.rs b/src/models/clip.rs deleted file mode 100644 index f9bdee6..0000000 --- a/src/models/clip.rs +++ /dev/null @@ -1,107 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::Array2; -use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer}; - -use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; - -#[derive(Debug)] -pub struct Clip { - pub textual: OrtEngine, - pub visual: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch_visual: MinOptMax, - pub batch_textual: MinOptMax, - tokenizer: Tokenizer, - context_length: usize, -} - -impl Clip { - pub fn new(options_visual: Options, options_textual: Options) -> Result { - let context_length = 77; - let mut visual = OrtEngine::new(&options_visual)?; - let mut textual = OrtEngine::new(&options_textual)?; - let (batch_visual, batch_textual, height, width) = ( - visual.inputs_minoptmax()[0][0].to_owned(), - textual.inputs_minoptmax()[0][0].to_owned(), - visual.inputs_minoptmax()[0][2].to_owned(), - visual.inputs_minoptmax()[0][3].to_owned(), - ); - - let tokenizer = options_textual - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - - let mut tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::Fixed(context_length), - direction: PaddingDirection::Right, - pad_to_multiple_of: None, - pad_id: 0, - pad_type_id: 0, - pad_token: "[PAD]".to_string(), - })); - - visual.dry_run()?; - textual.dry_run()?; - - Ok(Self { - textual, - visual, - batch_visual, - batch_textual, - height, - width, - tokenizer, - context_length, - }) - } - - pub fn encode_images(&mut self, xs: 
&[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize( - &[0.48145466, 0.4578275, 0.40821073], - &[0.26862954, 0.2613026, 0.2757771], - 3, - ), - Ops::Nhwc2nchw, - ])?; - let ys = self.visual.run(Xs::from(xs_))?; - Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned()))) - } - - pub fn encode_texts(&mut self, texts: &[String]) -> Result { - let encodings = match self.tokenizer.encode_batch(texts.to_owned(), false) { - Err(err) => anyhow::bail!("{:?}", err), - Ok(x) => x, - }; - let xs: Vec = encodings - .iter() - .flat_map(|i| i.get_ids().iter().map(|&b| b as f32)) - .collect(); - let xs = Array2::from_shape_vec((texts.len(), self.context_length), xs)?.into_dyn(); - let xs = X::from(xs); - let ys = self.textual.run(Xs::from(xs))?; - Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned()))) - } - - pub fn batch_visual(&self) -> usize { - self.batch_visual.opt() - } - - pub fn batch_textual(&self) -> usize { - self.batch_textual.opt() - } -} diff --git a/src/models/clip/README.md b/src/models/clip/README.md new file mode 100644 index 0000000..8bc962e --- /dev/null +++ b/src/models/clip/README.md @@ -0,0 +1,12 @@ +# CLIP + +CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision. 
+ +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/openai/CLIP) + + +## Example + +Refer to the [example](../../../examples/clip) diff --git a/src/models/clip/config.rs b/src/models/clip/config.rs new file mode 100644 index 0000000..0454261 --- /dev/null +++ b/src/models/clip/config.rs @@ -0,0 +1,71 @@ +use crate::Kind; + +/// Model configuration for `CLIP` +impl crate::Options { + pub fn clip() -> Self { + Self::default() + .with_model_name("clip") + .with_model_ixx(0, 0, 1.into()) + } + + pub fn clip_visual() -> Self { + Self::clip() + .with_model_kind(Kind::Vision) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) + } + + pub fn clip_textual() -> Self { + Self::clip() + .with_model_kind(Kind::Language) + .with_model_max_length(77) + } + + pub fn clip_vit_b16_visual() -> Self { + Self::clip_visual().with_model_file("vit-b16-visual.onnx") + } + + pub fn clip_vit_b16_textual() -> Self { + Self::clip_textual().with_model_file("vit-b16-textual.onnx") + } + + pub fn clip_vit_b32_visual() -> Self { + Self::clip_visual().with_model_file("vit-b32-visual.onnx") + } + + pub fn clip_vit_b32_textual() -> Self { + Self::clip_textual().with_model_file("vit-b32-textual.onnx") + } + + pub fn clip_vit_l14_visual() -> Self { + Self::clip_visual().with_model_file("vit-l14-visual.onnx") + } + + pub fn clip_vit_l14_textual() -> Self { + Self::clip_textual().with_model_file("vit-l14-textual.onnx") + } + + pub fn jina_clip_v1() -> Self { + Self::default() + .with_model_name("jina-clip-v1") + .with_model_ixx(0, 0, 1.into()) + } + + pub fn jina_clip_v1_visual() -> Self { + Self::jina_clip_v1() + .with_model_kind(Kind::Vision) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) + 
.with_model_file("visual.onnx") + } + + pub fn jina_clip_v1_textual() -> Self { + Self::jina_clip_v1() + .with_model_kind(Kind::Language) + .with_model_file("textual.onnx") + } +} diff --git a/src/models/clip/impl.rs b/src/models/clip/impl.rs new file mode 100644 index 0000000..24f6abb --- /dev/null +++ b/src/models/clip/impl.rs @@ -0,0 +1,149 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Array2; + +use crate::{elapsed, Engine, Options, Processor, Ts, Xs, X}; + +#[derive(Debug, Builder)] +pub struct ClipVisual { + engine: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + ts: Ts, +} + +impl ClipVisual { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&224.into()).opt(), + engine.try_width().unwrap_or(&224.into()).opt(), + engine.ts.clone(), + ); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + processor, + ts, + }) + } + + pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? 
}); + let x = elapsed!("visual-postprocess", self.ts, { xs[0].to_owned() }); + + Ok(x) + } +} + +#[derive(Debug, Builder)] +pub struct ClipTextual { + engine: Engine, + batch: usize, + processor: Processor, + ts: Ts, +} + +impl ClipTextual { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, ts) = (engine.batch().opt(), engine.ts.clone()); + let processor = options.to_processor()?; + + Ok(Self { + engine, + batch, + processor, + ts, + }) + } + + pub fn preprocess(&self, xs: &[&str]) -> Result { + let encodings: Vec = self + .processor + .encode_texts_ids(xs, false)? // skip_special_tokens + .into_iter() + .flatten() + .collect(); + + let x: X = Array2::from_shape_vec((xs.len(), encodings.len() / xs.len()), encodings)? + .into_dyn() + .into(); + + Ok(x.into()) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode_texts(&mut self, xs: &[&str]) -> Result { + let xs = elapsed!("textual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("textual-inference", self.ts, { self.inference(xs)? }); + let x = elapsed!("textual-postprocess", self.ts, { xs[0].to_owned() }); + + Ok(x) + } +} + +#[derive(Debug, Builder)] +pub struct Clip { + textual: ClipTextual, + visual: ClipVisual, + ts: Ts, +} + +impl Clip { + pub fn new(options_visual: Options, options_textual: Options) -> Result { + let visual = ClipVisual::new(options_visual)?; + let textual = ClipTextual::new(options_textual)?; + // let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]); + let ts = Ts::default(); + + Ok(Self { + textual, + visual, + ts, + }) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let x = elapsed!("encode_images", self.ts, { self.visual.encode_images(xs)? }); + Ok(x) + } + + pub fn encode_texts(&mut self, xs: &[&str]) -> Result { + let x = elapsed!("encode_texts", self.ts, { self.textual.encode_texts(xs)? 
}); + Ok(x) + } + + pub fn summary(&mut self) { + // self.ts.clear(); + // self.ts = Ts::merge(&[&self.ts, self.visual.ts(), self.textual.ts()]); + self.ts.summary(); + self.visual.ts().summary(); + self.textual.ts().summary(); + } +} diff --git a/src/models/clip/mod.rs b/src/models/clip/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/clip/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/convnext/README.md b/src/models/convnext/README.md new file mode 100644 index 0000000..2975866 --- /dev/null +++ b/src/models/convnext/README.md @@ -0,0 +1,9 @@ +# ConvNeXt: A ConvNet for the 2020s + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/facebookresearch/ConvNeXt) + +## Example + +Refer to the [example](../../../examples/convnext/) diff --git a/src/models/convnext/config.rs b/src/models/convnext/config.rs new file mode 100644 index 0000000..a917513 --- /dev/null +++ b/src/models/convnext/config.rs @@ -0,0 +1,66 @@ +use crate::IMAGENET_NAMES_1K; + +/// Model configuration for `ConvNeXt` +impl crate::Options { + pub fn convnext() -> Self { + Self::default() + .with_model_name("convnext") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn convnext_v1_tiny() -> Self { + Self::convnext().with_model_file("v1-t.onnx") + } + + pub fn convnext_v1_small() -> Self { + Self::convnext().with_model_file("v1-s.onnx") + } + + pub fn convnext_v1_base() -> Self { + Self::convnext().with_model_file("v1-b.onnx") + } + + pub fn convnext_v1_large() -> Self { + Self::convnext().with_model_file("v1-l.onnx") + } + + pub fn convnext_v2_atto() -> Self { + 
Self::convnext().with_model_file("v2-a.onnx") + } + + pub fn convnext_v2_femto() -> Self { + Self::convnext().with_model_file("v2-f.onnx") + } + + pub fn convnext_v2_pico() -> Self { + Self::convnext().with_model_file("v2-p.onnx") + } + + pub fn convnext_v2_nano() -> Self { + Self::convnext().with_model_file("v2-n.onnx") + } + + pub fn convnext_v2_tiny() -> Self { + Self::convnext().with_model_file("v2-t.onnx") + } + + pub fn convnext_v2_small() -> Self { + Self::convnext().with_model_file("v2-s.onnx") + } + + pub fn convnext_v2_base() -> Self { + Self::convnext().with_model_file("v2-b.onnx") + } + + pub fn convnext_v2_large() -> Self { + Self::convnext().with_model_file("v2-l.onnx") + } +} diff --git a/src/models/convnext/mod.rs b/src/models/convnext/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/convnext/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/d_fine/README.md b/src/models/d_fine/README.md new file mode 100644 index 0000000..5c07b7d --- /dev/null +++ b/src/models/d_fine/README.md @@ -0,0 +1,9 @@ +# D-FINE: Redefine Regression Task of DETRs as Fine‑grained Distribution Refinement + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/manhbd-22022602/D-FINE) + +## Example + +Refer to the [example](../../../examples/d-fine) diff --git a/src/models/d_fine/config.rs b/src/models/d_fine/config.rs new file mode 100644 index 0000000..ea2cea7 --- /dev/null +++ b/src/models/d_fine/config.rs @@ -0,0 +1,42 @@ +/// Model configuration for `d_fine` +impl crate::Options { + pub fn d_fine() -> Self { + Self::rtdetr().with_model_name("d-fine") + } + + pub fn d_fine_n_coco() -> Self { + Self::d_fine().with_model_file("n-coco.onnx") + } + + pub fn d_fine_s_coco() -> Self { + Self::d_fine().with_model_file("s-coco.onnx") + } + + pub fn d_fine_m_coco() -> Self { + Self::d_fine().with_model_file("m-coco.onnx") + } + + pub fn d_fine_l_coco() -> Self { + 
Self::d_fine().with_model_file("l-coco.onnx") + } + + pub fn d_fine_x_coco() -> Self { + Self::d_fine().with_model_file("x-coco.onnx") + } + + pub fn d_fine_s_coco_obj365() -> Self { + Self::d_fine().with_model_file("s-obj2coco.onnx") + } + + pub fn d_fine_m_coco_obj365() -> Self { + Self::d_fine().with_model_file("m-obj2coco.onnx") + } + + pub fn d_fine_l_coco_obj365() -> Self { + Self::d_fine().with_model_file("l-obj2coco.onnx") + } + + pub fn d_fine_x_coco_obj365() -> Self { + Self::d_fine().with_model_file("x-obj2coco.onnx") + } +} diff --git a/src/models/d_fine/mod.rs b/src/models/d_fine/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/d_fine/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/db/README.md b/src/models/db/README.md new file mode 100644 index 0000000..9d1d2b5 --- /dev/null +++ b/src/models/db/README.md @@ -0,0 +1,14 @@ +# DB: Real-time Scene Text Detection with Differentiable Binarization + +It presents a real-time arbitrary-shape scene text detector, achieving the state-of-the-art performance on standard benchmarks. 
+ +## Official Repository + +The official repository can be found on: + +- [DB](https://github.com/MhLiao/DB) +- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) + +## Example + +Refer to the [example](../../../examples/db) diff --git a/src/models/db/config.rs b/src/models/db/config.rs new file mode 100644 index 0000000..9d5b960 --- /dev/null +++ b/src/models/db/config.rs @@ -0,0 +1,66 @@ +/// Model configuration for [DB](https://github.com/MhLiao/DB) and [PaddleOCR-Det](https://github.com/PaddlePaddle/PaddleOCR) +impl crate::Options { + pub fn db() -> Self { + Self::default() + .with_model_name("db") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, (608, 960, 1600).into()) + .with_model_ixx(0, 3, (608, 960, 1600).into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_normalize(true) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_binary_thresh(0.2) + .with_class_confs(&[0.35]) + .with_min_width(5.0) + .with_min_height(12.0) + } + + pub fn ppocr_det_v3_ch() -> Self { + Self::db().with_model_file("ppocr-v3-ch.onnx") + } + + pub fn ppocr_det_v4_ch() -> Self { + Self::db().with_model_file("ppocr-v4-ch.onnx") + } + + pub fn ppocr_det_v4_server_ch() -> Self { + Self::db().with_model_file("ppocr-v4-server-ch.onnx") + } + + pub fn db2() -> Self { + Self::db() + .with_image_mean(&[0.798, 0.785, 0.772]) + .with_image_std(&[0.264, 0.2749, 0.287]) + // .with_binary_thresh(0.3) + // .with_class_confs(&[0.1]) + } + + pub fn db_mobilenet_v3_large() -> Self { + Self::db2().with_model_file("felixdittrich92-mobilenet-v3.onnx") + } + + pub fn db_mobilenet_v3_large_u8() -> Self { + Self::db2() + .with_model_file("https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx") + } + + pub fn db_resnet34() -> Self { + Self::db2().with_model_file("felixdittrich92-r34.onnx") + } + + pub fn db_resnet34_u8() -> Self { + 
Self::db2() + .with_model_file("https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet34_static_8_bit-027e2c7f.onnx") + } + + pub fn db_resnet50() -> Self { + Self::db2().with_model_file("felixdittrich92-r50.onnx") + } + + pub fn db_resnet50_u8() -> Self { + Self::db2() + .with_model_file("https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet50_static_8_bit-09a6104f.onnx") + } +} diff --git a/src/models/db.rs b/src/models/db/impl.rs similarity index 64% rename from src/models/db.rs rename to src/models/db/impl.rs index aefa620..de1e793 100644 --- a/src/models/db.rs +++ b/src/models/db/impl.rs @@ -1,36 +1,45 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::Axis; -use crate::{DynConf, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Mbr, Ops, Options, Polygon, Processor, Ts, Xs, Ys, Y}; -#[derive(Debug)] +#[derive(Debug, Builder)] pub struct DB { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, + engine: Engine, + height: usize, + width: usize, + batch: usize, confs: DynConf, unclip_ratio: f32, binary_thresh: f32, min_width: f32, min_height: f32, + spec: String, + ts: Ts, + processor: Processor, } impl DB { pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), + let engine = options.to_engine()?; + let (batch, height, width, ts, spec) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&960.into()).opt(), + engine.try_width().unwrap_or(&960.into()).opt(), + engine.ts.clone(), + engine.spec().to_owned(), ); - let confs = DynConf::new(&options.confs, 1); - let unclip_ratio = options.unclip_ratio; - let binary_thresh = 0.2; - let min_width = options.min_width.unwrap_or(0.); - let min_height = options.min_height.unwrap_or(0.); - engine.dry_run()?; + let 
processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let confs = DynConf::new(options.class_confs(), 1); + let binary_thresh = options.binary_thresh().unwrap_or(0.2); + let unclip_ratio = options.unclip_ratio().unwrap_or(1.5); + let min_width = options.min_width().unwrap_or(12.0); + let min_height = options.min_height().unwrap_or(5.0); Ok(Self { engine, @@ -42,29 +51,33 @@ impl DB { min_height, unclip_ratio, binary_thresh, + processor, + spec, + ts, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "Bilinear", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) } - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn postprocess(&mut self, xs: Xs) -> Result { let mut ys = Vec::new(); for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { let mut y_bbox = Vec::new(); @@ -72,13 +85,10 @@ impl DB { let mut y_mbrs: Vec = Vec::new(); // input image - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; + let (image_height, image_width) = self.processor.image0s_size[idx]; // reshape - let h = luma.dim()[1]; - let w = luma.dim()[2]; - let (ratio, _, _) = Ops::scale_wh(image_width, image_height, w as f32, h as f32); + let ratio = self.processor.scale_factors_hw[idx][0]; let v = luma .into_owned() .into_raw_vec_and_offset() @@ -95,8 +105,8 @@ impl DB { let luma = Ops::resize_luma8_u8( &v, - self.width() as _, - self.height() as _, + self.width as _, + self.height as _, image_width as _, image_height as _, true, @@ -158,18 +168,7 @@ impl DB { .with_mbrs(&y_mbrs), ); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } } diff --git a/src/models/db/mod.rs b/src/models/db/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/db/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/deim/README.md b/src/models/deim/README.md new file mode 100644 index 0000000..a225e3a --- /dev/null +++ b/src/models/deim/README.md @@ -0,0 +1,9 @@ +# DEIM: DETR with Improved Matching for Fast Convergence + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/ShihuaHuang95/DEIM) + +## Example + +Refer to the [example](../../../examples/deim) diff --git a/src/models/deim/config.rs b/src/models/deim/config.rs new file mode 100644 index 0000000..10c4a0a 
--- /dev/null +++ b/src/models/deim/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `DEIM` +impl crate::Options { + pub fn deim() -> Self { + Self::d_fine().with_model_name("deim") + } + + pub fn deim_dfine_s_coco() -> Self { + Self::deim().with_model_file("dfine-s-coco.onnx") + } + + pub fn deim_dfine_m_coco() -> Self { + Self::deim().with_model_file("dfine-m-coco.onnx") + } + + pub fn deim_dfine_l_coco() -> Self { + Self::deim().with_model_file("dfine-l-coco.onnx") + } + + pub fn deim_dfine_x_coco() -> Self { + Self::deim().with_model_file("dfine-x-coco.onnx") + } +} diff --git a/src/models/deim/mod.rs b/src/models/deim/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/deim/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/deit/README.md b/src/models/deit/README.md new file mode 100644 index 0000000..4af37ec --- /dev/null +++ b/src/models/deit/README.md @@ -0,0 +1,9 @@ +# DeiT: Data-Efficient architectures and training for Image classification + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/facebookresearch/deit) + +## Example + +Refer to the [example](../../../examples/deit) diff --git a/src/models/deit/config.rs b/src/models/deit/config.rs new file mode 100644 index 0000000..73734eb --- /dev/null +++ b/src/models/deit/config.rs @@ -0,0 +1,30 @@ +use crate::IMAGENET_NAMES_1K; + +/// Model configuration for `DeiT` +impl crate::Options { + pub fn deit() -> Self { + Self::default() + .with_model_name("deit") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn deit_tiny_distill() -> Self { + Self::deit().with_model_file("t-distill.onnx") + } + + pub fn deit_small_distill() -> Self { + 
Self::deit().with_model_file("s-distill.onnx") + } + + pub fn deit_base_distill() -> Self { + Self::deit().with_model_file("b-distill.onnx") + } +} diff --git a/src/models/deit/mod.rs b/src/models/deit/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/deit/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/depth_anything.rs b/src/models/depth_anything.rs deleted file mode 100644 index 4573dfb..0000000 --- a/src/models/depth_anything.rs +++ /dev/null @@ -1,90 +0,0 @@ -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -#[derive(Debug)] -pub struct DepthAnything { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, -} - -impl DepthAnything { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), - ); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Lanczos3", - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) - } - - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - let mut ys: Vec = Vec::new(); - for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - let v = luma.into_owned().into_raw_vec_and_offset().0; - let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); - let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); - let v = v - .iter() - .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) 
as u8) - .collect::>(); - - let luma = Ops::resize_luma8_u8( - &v, - self.width() as _, - self.height() as _, - w1 as _, - h1 as _, - false, - "Bilinear", - )?; - let luma: image::ImageBuffer, Vec<_>> = - match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { - None => continue, - Some(x) => x, - }; - ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); - } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ - } -} diff --git a/src/models/depth_anything/README.md b/src/models/depth_anything/README.md new file mode 100644 index 0000000..47dd60b --- /dev/null +++ b/src/models/depth_anything/README.md @@ -0,0 +1,12 @@ +# Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data + +## Official Repository + +The official repository can be found on: + +- [v1](https://github.com/LiheYoung/Depth-Anything) +- [v2](https://github.com/DepthAnything/Depth-Anything-V2) + +## Example + +Refer to the [example](../../../examples/depth-anything) diff --git a/src/models/depth_anything/config.rs b/src/models/depth_anything/config.rs new file mode 100644 index 0000000..6133876 --- /dev/null +++ b/src/models/depth_anything/config.rs @@ -0,0 +1,40 @@ +/// Model configuration for `DepthAnything` +impl crate::Options { + pub fn depth_anything() -> Self { + Self::default() + .with_model_name("depth-anything") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, (384, 518, 1024).into()) + .with_model_ixx(0, 3, (384, 518, 1024).into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_resize_filter("Lanczos3") + .with_normalize(true) + } + + pub fn depth_anything_s() -> Self { + Self::depth_anything().with_model_scale(crate::Scale::S) + } + + pub fn depth_anything_v1() -> Self { + Self::depth_anything().with_model_version(1.0.into()) 
+ } + + pub fn depth_anything_v2() -> Self { + Self::depth_anything().with_model_version(2.0.into()) + } + + pub fn depth_anything_v1_small() -> Self { + Self::depth_anything_v1() + .with_model_scale(crate::Scale::S) + .with_model_file("v1-s.onnx") + } + + pub fn depth_anything_v2_small() -> Self { + Self::depth_anything_v2() + .with_model_scale(crate::Scale::S) + .with_model_file("v2-s.onnx") + } + // TODO +} diff --git a/src/models/depth_anything/impl.rs b/src/models/depth_anything/impl.rs new file mode 100644 index 0000000..7a9b61c --- /dev/null +++ b/src/models/depth_anything/impl.rs @@ -0,0 +1,98 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; + +use crate::{elapsed, Engine, Mask, Ops, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Debug, Builder)] +pub struct DepthAnything { + engine: Engine, + height: usize, + width: usize, + batch: usize, + spec: String, + ts: Ts, + processor: Processor, +} + +impl DepthAnything { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&518.into()).opt(), + engine.try_width().unwrap_or(&518.into()).opt(), + engine.ts().clone(), + ); + + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + spec, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for (idx, luma) in xs[0].axis_iter(ndarray::Axis(0)).enumerate() { + // image size + let (h1, w1) = self.processor.image0s_size[idx]; + let v = luma.into_owned().into_raw_vec_and_offset().0; + let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); + let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); + let v = v + .iter() + .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) + .collect::>(); + + let luma = Ops::resize_luma8_u8( + &v, + self.width() as _, + self.height() as _, + w1 as _, + h1 as _, + false, + "Bilinear", + )?; + let luma: image::ImageBuffer, Vec<_>> = + match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { + None => continue, + Some(x) => x, + }; + ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); + } + + Ok(ys.into()) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/depth_anything/mod.rs b/src/models/depth_anything/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/depth_anything/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/depth_pro.rs b/src/models/depth_pro.rs deleted file mode 100644 index 26938f7..0000000 --- a/src/models/depth_pro.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -#[derive(Debug)] -pub struct DepthPro { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, -} - -impl DepthPro { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().clone(), - engine.height().clone(), - engine.width().clone(), - ); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - - self.postprocess(ys, xs) - } - - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]); - let predicted_depth = predicted_depth.mapv(|x| 1. 
/ x); - - let mut ys: Vec = Vec::new(); - for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - let v = luma.into_owned().into_raw_vec_and_offset().0; - let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); - let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); - let v = v - .iter() - .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) - .collect::>(); - - let luma = Ops::resize_luma8_u8( - &v, - self.width.opt() as _, - self.height.opt() as _, - w1 as _, - h1 as _, - false, - "Bilinear", - )?; - let luma: image::ImageBuffer, Vec<_>> = - match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { - None => continue, - Some(x) => x, - }; - ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); - } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } -} diff --git a/src/models/depth_pro/README.md b/src/models/depth_pro/README.md new file mode 100644 index 0000000..d3da12a --- /dev/null +++ b/src/models/depth_pro/README.md @@ -0,0 +1,9 @@ +# Depth Pro: Sharp Monocular Metric Depth in Less Than a Second + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/apple/ml-depth-pro) + +## Example + +Refer to the [example](../../../examples/depth-pro) diff --git a/src/models/depth_pro/config.rs b/src/models/depth_pro/config.rs new file mode 100644 index 0000000..451682e --- /dev/null +++ b/src/models/depth_pro/config.rs @@ -0,0 +1,27 @@ +/// Model configuration for `DepthPro` +impl crate::Options { + pub fn depth_pro() -> Self { + Self::default() + .with_model_name("depth-pro") + .with_model_ixx(0, 0, 1.into()) // batch. 
Note: now only support batch_size = 1 + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 1536.into()) + .with_model_ixx(0, 3, 1536.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_normalize(true) + } + + // pub fn depth_pro_q4f16() -> Self { + // Self::depth_pro().with_model_file("q4f16.onnx") + // } + + // pub fn depth_pro_fp16() -> Self { + // Self::depth_pro().with_model_file("fp16.onnx") + // } + + // pub fn depth_pro_bnb4() -> Self { + // Self::depth_pro().with_model_file("bnb4.onnx") + // } +} diff --git a/src/models/depth_pro/impl.rs b/src/models/depth_pro/impl.rs new file mode 100644 index 0000000..49518d3 --- /dev/null +++ b/src/models/depth_pro/impl.rs @@ -0,0 +1,99 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; + +use crate::{elapsed, Engine, Mask, Ops, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct DepthPro { + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, +} + +impl DepthPro { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + ts, + spec, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]); + let predicted_depth = predicted_depth.mapv(|x| 1. / x); + + let mut ys: Vec = Vec::new(); + for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() { + let (h1, w1) = self.processor.image0s_size[idx]; + let v = luma.into_owned().into_raw_vec_and_offset().0; + let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); + let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); + let v = v + .iter() + .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) + .collect::>(); + + let luma = Ops::resize_luma8_u8( + &v, + self.width as _, + self.height as _, + w1 as _, + h1 as _, + false, + "Bilinear", + )?; + let luma: image::ImageBuffer, Vec<_>> = + match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { + None => continue, + Some(x) => x, + }; + ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); + } + + Ok(ys.into()) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/depth_pro/mod.rs b/src/models/depth_pro/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/depth_pro/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/dinov2.rs b/src/models/dinov2.rs deleted file mode 100644 index 1facfa9..0000000 --- a/src/models/dinov2.rs +++ /dev/null @@ -1,161 +0,0 @@ -use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -// use std::path::PathBuf; -// use usearch::ffi::{IndexOptions, MetricKind, ScalarKind}; - -#[derive(Debug)] -pub enum Model { - S, - B, -} - -#[derive(Debug)] -pub struct Dinov2 { - engine: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch: MinOptMax, - pub hidden_size: usize, -} - -impl Dinov2 { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.inputs_minoptmax()[0][0].to_owned(), - engine.inputs_minoptmax()[0][2].to_owned(), - engine.inputs_minoptmax()[0][3].to_owned(), - ); - let which = match options.onnx_path { - s if s.contains('b') => Model::B, - s if s.contains('s') => Model::S, - _ => todo!(), - }; - let hidden_size = match which { - Model::S => 384, - Model::B => 768, - }; - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - hidden_size, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Lanczos3", - ), - Ops::Normalize(0., 255.), - Ops::Standardize( - &[0.48145466, 0.4578275, 0.40821073], - &[0.26862954, 0.2613026, 0.2757771], - 3, - ), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned()))) - } - - // pub fn build_index(&self, metric: Metric) -> Result 
{ - // let metric = match metric { - // Metric::IP => MetricKind::IP, - // Metric::L2 => MetricKind::L2sq, - // Metric::Cos => MetricKind::Cos, - // }; - // let options = IndexOptions { - // metric, - // dimensions: self.hidden_size, - // quantization: ScalarKind::F16, - // ..Default::default() - // }; - // Ok(usearch::new_index(&options)?) - // } - - // pub fn query_from_folder( - // &mut self, - // qurey: &str, - // gallery: &str, - // metric: Metric, - // ) -> Result> { - // // load query - // let query = DataLoader::try_read(qurey)?; - // let query = self.run(&[query])?; - - // // build index & gallery - // let index = self.build_index(metric)?; - // let dl = DataLoader::default() - // .with_batch(self.batch.opt as usize) - // .load(gallery)?; - // let paths = dl.paths().to_owned(); - // index.reserve(paths.len())?; - - // // load feats - // for (idx, (x, _path)) in dl.enumerate() { - // let y = self.run(&x)?; - // index.add(idx as u64, &y.into_raw_vec())?; - // } - - // // output - // let matches = index.search(&query.into_raw_vec(), index.size())?; - // let mut results: Vec<(usize, f32, PathBuf)> = Vec::new(); - // matches - // .keys - // .into_iter() - // .zip(matches.distances) - // .for_each(|(k, score)| { - // results.push((k as usize, score, paths[k as usize].to_owned())); - // }); - - // Ok(results) - // } - - // pub fn query_from_vec( - // &mut self, - // qurey: &str, - // gallery: &[&str], - // metric: Metric, - // ) -> Result> { - // // load query - // let query = DataLoader::try_read(qurey)?; - // let query = self.run(&[query])?; - - // // build index & gallery - // let index = self.build_index(metric)?; - // index.reserve(gallery.len())?; - // let mut dl = DataLoader::default().with_batch(self.batch.opt as usize); - // gallery.iter().for_each(|x| { - // dl.load(x).unwrap(); - // }); - - // // load feats - // let paths = dl.paths().to_owned(); - // for (idx, (x, _path)) in dl.enumerate() { - // let y = self.run(&x)?; - // index.add(idx as u64, 
&y.into_raw_vec())?; - // } - - // // output - // let matches = index.search(&query.into_raw_vec(), index.size())?; - // let mut results: Vec<(usize, f32, PathBuf)> = Vec::new(); - // matches - // .keys - // .into_iter() - // .zip(matches.distances) - // .for_each(|(k, score)| { - // results.push((k as usize, score, paths[k as usize].to_owned())); - // }); - - // Ok(results) - // } -} diff --git a/src/models/dinov2/README.md b/src/models/dinov2/README.md new file mode 100644 index 0000000..9cb3c0d --- /dev/null +++ b/src/models/dinov2/README.md @@ -0,0 +1,9 @@ +# DINOv2: Learning Robust Visual Features without Supervision + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/facebookresearch/dinov2) + +## Example + +Refer to the [example](../../../examples/dinov2) diff --git a/src/models/dinov2/config.rs b/src/models/dinov2/config.rs new file mode 100644 index 0000000..abf7696 --- /dev/null +++ b/src/models/dinov2/config.rs @@ -0,0 +1,28 @@ +/// Model configuration for `DINOv2` +impl crate::Options { + pub fn dinov2() -> Self { + Self::default() + .with_model_name("dinov2") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_resize_filter("Lanczos3") + .with_normalize(true) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_image_mean(&[0.485, 0.456, 0.406]) + } + + pub fn dinov2_small() -> Self { + Self::dinov2() + .with_model_scale(crate::Scale::S) + .with_model_file("s.onnx") + } + + pub fn dinov2_base() -> Self { + Self::dinov2() + .with_model_scale(crate::Scale::B) + .with_model_file("b.onnx") + } +} diff --git a/src/models/dinov2/impl.rs b/src/models/dinov2/impl.rs new file mode 100644 index 0000000..de0897e --- /dev/null +++ b/src/models/dinov2/impl.rs @@ -0,0 +1,68 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; + +use crate::{elapsed, Engine, 
Options, Processor, Scale, Ts, Xs, X}; + +#[derive(Builder, Debug)] +pub struct DINOv2 { + engine: Engine, + height: usize, + width: usize, + batch: usize, + dim: usize, + ts: Ts, + processor: Processor, +} + +impl DINOv2 { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&384.into()).opt(), + engine.try_width().unwrap_or(&384.into()).opt(), + engine.ts.clone(), + ); + let dim = match options.model_scale() { + Some(Scale::S) => 384, + Some(Scale::B) => 768, + Some(Scale::L) => 1024, + Some(Scale::G) => 1536, + Some(x) => anyhow::bail!("Unsupported scale: {:?}", x), + None => anyhow::bail!("No model scale specified"), + }; + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + dim, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? 
}); + let x = elapsed!("visual-postprocess", self.ts, { xs[0].to_owned() }); + + Ok(x) + } +} diff --git a/src/models/dinov2/mod.rs b/src/models/dinov2/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/dinov2/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/fast/README.md b/src/models/fast/README.md new file mode 100644 index 0000000..6816a51 --- /dev/null +++ b/src/models/fast/README.md @@ -0,0 +1,9 @@ +# FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation + +## Official Repository + +The official repository can be found on: [FAST](https://github.com/czczup/FAST) + +## Example + +Refer to the [example](../../../examples/fast) diff --git a/src/models/fast/config.rs b/src/models/fast/config.rs new file mode 100644 index 0000000..6ec39c4 --- /dev/null +++ b/src/models/fast/config.rs @@ -0,0 +1,21 @@ +/// Model configuration for [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://github.com/czczup/FAST) +impl crate::Options { + pub fn fast() -> Self { + Self::db() + .with_model_name("fast") + .with_image_mean(&[0.798, 0.785, 0.772]) + .with_image_std(&[0.264, 0.2749, 0.287]) + } + + pub fn fast_tiny() -> Self { + Self::fast().with_model_file("felixdittrich92-rep-tiny.onnx") + } + + pub fn fast_small() -> Self { + Self::fast().with_model_file("felixdittrich92-rep-small.onnx") + } + + pub fn fast_base() -> Self { + Self::fast().with_model_file("felixdittrich92-rep-base.onnx") + } +} diff --git a/src/models/fast/mod.rs b/src/models/fast/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/fast/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/fastvit/README.md b/src/models/fastvit/README.md new file mode 100644 index 0000000..2b4f235 --- /dev/null +++ b/src/models/fastvit/README.md @@ -0,0 +1,9 @@ +# FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization + +## 
Official Repository + +The official repository can be found on: [GitHub](https://github.com/apple/ml-fastvit) + +## Example + +Refer to the [example](../../../examples/fastvit) diff --git a/src/models/fastvit/config.rs b/src/models/fastvit/config.rs new file mode 100644 index 0000000..39eb817 --- /dev/null +++ b/src/models/fastvit/config.rs @@ -0,0 +1,74 @@ +use crate::IMAGENET_NAMES_1K; + +/// Model configuration for `FastViT` +impl crate::Options { + pub fn fastvit() -> Self { + Self::default() + .with_model_name("fastvit") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn fastvit_t8() -> Self { + Self::fastvit().with_model_file("t8.onnx") + } + + pub fn fastvit_t8_distill() -> Self { + Self::fastvit().with_model_file("t8-distill.onnx") + } + + pub fn fastvit_t12() -> Self { + Self::fastvit().with_model_file("t12.onnx") + } + + pub fn fastvit_t12_distill() -> Self { + Self::fastvit().with_model_file("t12-distill.onnx") + } + + pub fn fastvit_s12() -> Self { + Self::fastvit().with_model_file("s12.onnx") + } + + pub fn fastvit_s12_distill() -> Self { + Self::fastvit().with_model_file("s12-distill.onnx") + } + + pub fn fastvit_sa12() -> Self { + Self::fastvit().with_model_file("sa12.onnx") + } + + pub fn fastvit_sa12_distill() -> Self { + Self::fastvit().with_model_file("sa12-distill.onnx") + } + + pub fn fastvit_sa24() -> Self { + Self::fastvit().with_model_file("sa24.onnx") + } + + pub fn fastvit_sa24_distill() -> Self { + Self::fastvit().with_model_file("sa24-distill.onnx") + } + + pub fn fastvit_sa36() -> Self { + Self::fastvit().with_model_file("sa36.onnx") + } + + pub fn fastvit_sa36_distill() -> Self { + Self::fastvit().with_model_file("sa36-distill.onnx") + } + + pub fn 
fastvit_ma36() -> Self { + Self::fastvit().with_model_file("ma36.onnx") + } + + pub fn fastvit_ma36_distill() -> Self { + Self::fastvit().with_model_file("ma36-distill.onnx") + } +} diff --git a/src/models/fastvit/mod.rs b/src/models/fastvit/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/fastvit/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/florence2.rs b/src/models/florence2.rs deleted file mode 100644 index 06d1674..0000000 --- a/src/models/florence2.rs +++ /dev/null @@ -1,459 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::{s, Axis}; -use rayon::prelude::*; -use std::collections::BTreeMap; -use tokenizers::Tokenizer; - -use crate::{ - build_progress_bar, Bbox, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, Polygon, - Quantizer, Task, Xs, X, Y, -}; - -#[derive(Debug)] -pub struct Florence2 { - pub vision_encoder: OrtEngine, - pub text_embed: OrtEngine, - pub encoder: OrtEngine, - pub decoder: OrtEngine, - pub decoder_merged: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - tokenizer: Tokenizer, - max_length: usize, - quantizer: Quantizer, -} - -impl Florence2 { - pub fn new( - options_vision_encoder: Options, - options_text_embed: Options, - options_encoder: Options, - options_decoder: Options, - options_decoder_merged: Options, - ) -> Result { - let mut vision_encoder = OrtEngine::new(&options_vision_encoder)?; - let mut text_embed = OrtEngine::new(&options_text_embed)?; - let mut encoder = OrtEngine::new(&options_encoder)?; - let mut decoder = OrtEngine::new(&options_decoder)?; - let mut decoder_merged = OrtEngine::new(&options_decoder_merged)?; - let (batch, height, width) = ( - vision_encoder.batch().to_owned(), - vision_encoder.height().to_owned(), - vision_encoder.width().to_owned(), - ); - let tokenizer = options_text_embed - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - let tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => 
anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - - let quantizer = Quantizer::default(); - let max_length = 1024; - - // dry run - vision_encoder.dry_run()?; - text_embed.dry_run()?; - encoder.dry_run()?; - decoder.dry_run()?; - decoder_merged.dry_run()?; - - Ok(Self { - vision_encoder, - text_embed, - encoder, - decoder, - decoder_merged, - height, - width, - batch, - tokenizer, - max_length, - quantizer, - }) - } - - pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.vision_encoder.run(Xs::from(xs_))?[0].to_owned(); - Ok(ys) - } - - pub fn run_with_tasks( - &mut self, - xs: &[DynamicImage], - tasks: &[Task], - ) -> Result>> { - let mut ys: BTreeMap> = BTreeMap::new(); - - // encode images - let image_embeddings = self.encode_images(xs)?; - - // note: the length of xs is not always equal to batch size - self.batch.update_opt(xs.len() as _); - - // build pb - let pb = build_progress_bar( - tasks.len() as u64, - " Working On", - None, - crate::PROGRESS_BAR_STYLE_CYAN_2, - )?; - - // tasks - for task in tasks.iter() { - pb.inc(1); - pb.set_message(format!("{:?}", task)); - - // construct prompt and encode - let input_ids = self - .encode_prompt(task)? - .insert_axis(0)? 
- .repeat(0, self.batch())?; - let text_embeddings = self.text_embed.run(Xs::from(input_ids))?[0].clone(); - - // run - let texts = self.run_batch(&image_embeddings, &text_embeddings)?; - - // tasks iteration - let ys_task = (0..self.batch()) - .into_par_iter() - .map(|batch| { - // image size - let image_width = xs[batch].width() as usize; - let image_height = xs[batch].height() as usize; - - // texts cleanup - let text = texts[batch] - .as_str() - .replace("", "") - .replace("", "") - .replace("", ""); - - // postprocess - let mut y = Y::default(); - if let Task::Caption(_) | Task::Ocr = task { - y = y.with_texts(&[text]); - } else { - let elems = Self::loc_parse(&text)?; - match task { - Task::RegionToCategory(..) | Task::RegionToDescription(..) => { - let text = elems[0][0].clone(); - y = y.with_texts(&[text]); - } - Task::ObjectDetection - | Task::OpenSetDetection(_) - | Task::DenseRegionCaption - | Task::CaptionToPhraseGrounding(_) => { - let y_bboxes: Vec = elems - .par_iter() - .enumerate() - .flat_map(|(i, elem)| { - Self::process_bboxes( - &elem[1..], - &self.quantizer, - image_width, - image_height, - Some((&elem[0], i)), - ) - }) - .collect(); - y = y.with_bboxes(&y_bboxes); - } - Task::RegionProposal => { - let y_bboxes: Vec = Self::process_bboxes( - &elems[0], - &self.quantizer, - image_width, - image_height, - None, - ); - y = y.with_bboxes(&y_bboxes); - } - Task::ReferringExpressionSegmentation(_) - | Task::RegionToSegmentation(..) 
=> { - let points = Self::process_polygons( - &elems[0], - &self.quantizer, - image_width, - image_height, - ); - y = y.with_polygons(&[Polygon::default() - .with_points(&points) - .with_id(0)]); - } - Task::OcrWithRegion => { - let y_polygons: Vec = elems - .par_iter() - .enumerate() - .map(|(i, elem)| { - let points = Self::process_polygons( - &elem[1..], - &self.quantizer, - image_width, - image_height, - ); - Polygon::default() - .with_name(&elem[0]) - .with_points(&points) - .with_id(i as _) - }) - .collect(); - y = y.with_polygons(&y_polygons); - } - _ => anyhow::bail!("Unsupported Florence2 task."), - }; - } - Ok(y) - }) - .collect::>>()?; - - ys.insert(task.clone(), ys_task); - } - - // update pb - pb.set_prefix(" Completed"); - pb.set_message("Florence2 tasks"); - pb.set_style(indicatif::ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_FINISH_2, - )?); - pb.finish(); - - Ok(ys) - } - - fn run_batch(&mut self, image_embeddings: &X, text_embeddings: &X) -> Result> { - // concate image_embeddings and prompt embeddings - let inputs_embeds = image_embeddings.clone().concatenate(text_embeddings, 1)?; - let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]); - - // encoder - let last_hidden_state = self.encoder.run(Xs::from(vec![ - attention_mask.clone(), - inputs_embeds.clone(), - ]))?[0] - .clone(); - - // decoder - let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]); - let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn()); - let mut decoder_outputs = self.decoder.run(Xs::from(vec![ - attention_mask.clone(), - last_hidden_state.clone(), - inputs_embeds, - ]))?; - - let encoder_k0 = decoder_outputs[3].clone(); - let encoder_v0 = decoder_outputs[4].clone(); - let encoder_k1 = decoder_outputs[7].clone(); - let encoder_v1 = decoder_outputs[8].clone(); - let encoder_k2 = decoder_outputs[11].clone(); - let encoder_v2 = decoder_outputs[12].clone(); - let encoder_k3 = decoder_outputs[15].clone(); - let encoder_v3 = 
decoder_outputs[16].clone(); - let encoder_k4 = decoder_outputs[19].clone(); - let encoder_v4 = decoder_outputs[20].clone(); - let encoder_k5 = decoder_outputs[23].clone(); - let encoder_v5 = decoder_outputs[24].clone(); - - let mut generated_tokens: Vec> = vec![vec![]; self.batch()]; - let mut finished = vec![false; self.batch()]; - - // save last batch tokens - let mut last_tokens: Vec = vec![0.; self.batch()]; - let mut logits_sampler = LogitsSampler::new(); - - // generate - for _ in 0..self.max_length { - let logits = &decoder_outputs["logits"]; - let decoder_k0 = &decoder_outputs[1]; - let decoder_v0 = &decoder_outputs[2]; - let decoder_k1 = &decoder_outputs[5]; - let decoder_v1 = &decoder_outputs[6]; - let decoder_k2 = &decoder_outputs[9]; - let decoder_v2 = &decoder_outputs[10]; - let decoder_k3 = &decoder_outputs[13]; - let decoder_v3 = &decoder_outputs[14]; - let decoder_k4 = &decoder_outputs[17]; - let decoder_v4 = &decoder_outputs[18]; - let decoder_k5 = &decoder_outputs[21]; - let decoder_v5 = &decoder_outputs[22]; - - // decode each token for each batch - for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { - if !finished[i] { - let token_id = logits_sampler.decode( - &logit - .slice(s![-1, ..]) - .into_owned() - .into_raw_vec_and_offset() - .0, - )?; // - generated_tokens[i].push(token_id); - - // update last_tokens - last_tokens[i] = token_id as f32; - - if token_id == 2 { - finished[i] = true; - } - } - } - - // all finished? 
- if finished.iter().all(|&x| x) { - break; - } - - // next input text embedding - let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?; - - // decode - let inputs_embeds = &self.text_embed.run(Xs::from(next_tokens))?[0].clone(); - let use_cache = X::ones(&[1]); - decoder_outputs = self.decoder_merged.run(Xs::from(vec![ - attention_mask.clone(), - last_hidden_state.clone(), - inputs_embeds.clone(), - decoder_k0.clone(), - decoder_v0.clone(), - encoder_k0.clone(), - encoder_v0.clone(), - decoder_k1.clone(), - decoder_v1.clone(), - encoder_k1.clone(), - encoder_v1.clone(), - decoder_k2.clone(), - decoder_v2.clone(), - encoder_k2.clone(), - encoder_v2.clone(), - decoder_k3.clone(), - decoder_v3.clone(), - encoder_k3.clone(), - encoder_v3.clone(), - decoder_k4.clone(), - decoder_v4.clone(), - encoder_k4.clone(), - encoder_v4.clone(), - decoder_k5.clone(), - decoder_v5.clone(), - encoder_k5.clone(), - encoder_v5.clone(), - use_cache, - ]))?; - } - - // batch decode - let texts = match self.tokenizer.decode_batch( - &generated_tokens - .iter() - .map(|tokens| tokens.as_slice()) - .collect::>(), - false, - ) { - Err(err) => anyhow::bail!("{:?}", err), - Ok(xs) => xs, - }; - - Ok(texts) - } - - pub fn encode_prompt(&self, task: &Task) -> Result { - let prompt = task.prompt_for_florence2()?; - let encodings = match self.tokenizer.encode(prompt, true) { - Err(err) => anyhow::bail!("{}", err), - Ok(x) => x, - }; - let ids: Vec = encodings.get_ids().iter().map(|x| *x as f32).collect(); - - Ok(X::from(ids)) - } - - fn process_polygons( - elems: &[String], - quantizer: &Quantizer, - image_width: usize, - image_height: usize, - ) -> Vec> { - elems - .par_chunks(2) - .map(|chunk| { - let coord: Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); - quantizer.dequantize(&coord, (image_width, image_height)) - }) - .collect() - } - - fn process_bboxes( - elems: &[String], - quantizer: &Quantizer, - image_width: usize, - image_height: usize, - class_name: 
Option<(&str, usize)>, - ) -> Vec { - elems - .par_chunks(4) - .enumerate() - .map(|(i, chunk)| { - let bbox: Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); - let dequantized_bbox = quantizer.dequantize(&bbox, (image_width, image_height)); - - let mut bbox = Bbox::default().with_xyxy( - dequantized_bbox[0].max(0.0f32).min(image_width as f32), - dequantized_bbox[1].max(0.0f32).min(image_height as f32), - dequantized_bbox[2], - dequantized_bbox[3], - ); - if let Some((class_name, i)) = class_name { - bbox = bbox.with_name(class_name).with_id(i as _); - } else { - bbox = bbox.with_id(i as _); - } - - bbox - }) - .collect() - } - - fn loc_parse(hay: &str) -> Result>> { - let pattern = r"(?i)(\d+)>)|(?[^<]+)"; - let re = regex::Regex::new(pattern)?; - let mut ys: Vec> = Vec::new(); - let mut y = Vec::new(); - - for cap in re.captures_iter(hay) { - if let Some(loc) = cap.name("coord") { - y.push(loc.as_str().to_string()); - } else if let Some(text) = cap.name("name") { - if !text.as_str().is_empty() { - if !y.is_empty() { - ys.push(y); - y = Vec::new(); - } - y.push(text.as_str().to_string()); - } - } - } - if !y.is_empty() { - ys.push(y); - } - Ok(ys) - } - - pub fn batch(&self) -> usize { - self.batch.opt() - } -} diff --git a/src/models/florence2/README.md b/src/models/florence2/README.md new file mode 100644 index 0000000..7930baa --- /dev/null +++ b/src/models/florence2/README.md @@ -0,0 +1,9 @@ +# Florence-2: Advancing a Unified Representation for a Variety of Vision Tasks + +## Official Repository + +The official repository can be found on: [Hugging Face](https://huggingface.co/microsoft/Florence-2-base) + +## Example + +Refer to the [example](../../../examples/florence2) diff --git a/src/models/florence2/config.rs b/src/models/florence2/config.rs new file mode 100644 index 0000000..8ef74ac --- /dev/null +++ b/src/models/florence2/config.rs @@ -0,0 +1,59 @@ +/// Model configuration for `Florence2` +impl crate::Options { + pub fn florence2() -> Self 
{ + Self::default() + .with_model_name("florence2") + .with_batch_size(1) + } + + pub fn florence2_visual() -> Self { + Self::florence2() + .with_model_kind(crate::Kind::Vision) + .with_model_ixx(0, 2, 768.into()) + .with_model_ixx(0, 3, 768.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_resize_filter("Bilinear") + .with_normalize(true) + } + + pub fn florence2_textual() -> Self { + Self::florence2().with_model_kind(crate::Kind::Language) + } + + pub fn florence2_visual_base() -> Self { + Self::florence2_visual().with_model_scale(crate::Scale::B) + } + + pub fn florence2_textual_base() -> Self { + Self::florence2_textual().with_model_scale(crate::Scale::B) + } + + pub fn florence2_visual_large() -> Self { + Self::florence2_visual().with_model_scale(crate::Scale::L) + } + + pub fn florence2_textual_large() -> Self { + Self::florence2_textual().with_model_scale(crate::Scale::L) + } + + pub fn florence2_visual_encoder_base() -> Self { + Self::florence2_visual_base().with_model_file("base-vision-encoder.onnx") + } + + pub fn florence2_textual_embed_base() -> Self { + Self::florence2_textual_base().with_model_file("base-embed-tokens.onnx") + } + + pub fn florence2_texual_encoder_base() -> Self { + Self::florence2_textual_base().with_model_file("base-encoder.onnx") + } + + pub fn florence2_texual_decoder_base() -> Self { + Self::florence2_textual_base().with_model_file("base-decoder.onnx") + } + + pub fn florence2_texual_decoder_merged_base() -> Self { + Self::florence2_textual_base().with_model_file("base-decoder-merged.onnx") + } +} diff --git a/src/models/florence2/impl.rs b/src/models/florence2/impl.rs new file mode 100644 index 0000000..b4094e2 --- /dev/null +++ b/src/models/florence2/impl.rs @@ -0,0 +1,427 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; + +use crate::{ + elapsed, + models::{BaseModelTextual, BaseModelVisual, Quantizer}, + 
Bbox, LogitsSampler, Options, Polygon, Scale, Task, Ts, Xs, Ys, X, Y, +}; + +#[derive(Debug, Builder)] +pub struct Florence2 { + pub vision_encoder: BaseModelVisual, + pub text_embed: BaseModelTextual, + pub encoder: BaseModelTextual, + pub decoder: BaseModelTextual, + pub decoder_merged: BaseModelTextual, + ts: Ts, + quantizer: Quantizer, + max_length: usize, + eos_token_id: u32, + decoder_start_token_id: u32, + n_kvs: usize, +} + +impl Florence2 { + pub fn new( + options_vision_encoder: Options, + options_text_embed: Options, + options_encoder: Options, + options_decoder: Options, + options_decoder_merged: Options, + ) -> Result { + let vision_encoder = BaseModelVisual::new(options_vision_encoder)?; + let text_embed = BaseModelTextual::new(options_text_embed)?; + let encoder = BaseModelTextual::new(options_encoder)?; + let decoder = BaseModelTextual::new(options_decoder)?; + let decoder_merged = BaseModelTextual::new(options_decoder_merged)?; + let quantizer = Quantizer::default(); + let ts = Ts::merge(&[ + vision_encoder.engine().ts(), + text_embed.engine().ts(), + encoder.engine().ts(), + decoder.engine().ts(), + decoder_merged.engine().ts(), + ]); + let max_length = 1024; + let eos_token_id = 2; + let decoder_start_token_id = 2; + let n_kvs = match decoder.scale() { + Some(Scale::B) => 6, + Some(Scale::L) => 12, + _ => unimplemented!(), + }; + + Ok(Self { + vision_encoder, + text_embed, + encoder, + decoder, + decoder_merged, + max_length, + quantizer, + ts, + eos_token_id, + decoder_start_token_id, + n_kvs, + }) + } + + fn process_task(task: &Task, image_height: usize, image_width: usize) -> Task { + // region-related tasks + match task { + Task::RegionToSegmentation(x0, y0, x1, y1) => { + let xyxy = Quantizer::default() + .quantize(&[*x0, *y0, *x1, *y1], (image_width, image_height)); + Task::RegionToSegmentation(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) + } + Task::RegionToCategory(x0, y0, x1, y1) => { + let xyxy = Quantizer::default() + .quantize(&[*x0, *y0, *x1, 
*y1], (image_width, image_height)); + Task::RegionToCategory(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) + } + Task::RegionToDescription(x0, y0, x1, y1) => { + let xyxy = Quantizer::default() + .quantize(&[*x0, *y0, *x1, *y1], (image_width, image_height)); + Task::RegionToDescription(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) + } + _ => *task, + } + } + + fn encode_text(&mut self, task: &Task, images: &[DynamicImage]) -> Result { + let xs = images + .par_iter() + .map(|im| { + let text = Self::process_task(task, im.height() as _, im.width() as _) + .prompt_for_florence2()?; + let ids = self.text_embed.processor().encode_text_ids(&text, true)?; + X::from(ids).insert_axis(0) + }) + .collect::, _>>()?; + let x = X::concat(&xs, 0)?; + let xs = self.text_embed.inference(x.into())?; + let x = xs[0].to_owned(); + + Ok(x) + } + + pub fn forward(&mut self, xs_visual: &[DynamicImage], x_textual: &Task) -> Result { + let visual_embeddings = elapsed!("visual-encode", self.ts, { + self.vision_encoder.encode(xs_visual)? + }); + + let textual_embedding = elapsed!("textual-encode", self.ts, { + self.encode_text(x_textual, xs_visual)? + }); + + let generated = elapsed!("generate-then-decode", self.ts, { + self.generate_then_decode(&visual_embeddings, &textual_embedding)? + }); + + let ys = elapsed!("postprocess", self.ts, { + self.postprocess(&generated, xs_visual, x_textual)? 
+ }); + + Ok(ys) + } + + // decode or postprocess, batch images and one text + fn generate_then_decode( + &mut self, + visual_embeddings: &X, + textual_embedding: &X, + ) -> Result> { + // concate image embeddings and prompt embeddings + let inputs_embeds = visual_embeddings + .clone() + .concatenate(textual_embedding, 1)?; + let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]); + + // encoder + let last_hidden_state = self.encoder.inference(Xs::from(vec![ + attention_mask.clone(), + inputs_embeds.clone(), + ]))?[0] + .clone(); + + // decoder + let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]); + let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn()); + let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + attention_mask.clone(), + last_hidden_state.clone(), + inputs_embeds, + ]))?; + + // encoder kvs + let encoder_kvs: Vec<_> = (3..4 * self.n_kvs) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // token ids + let mut token_ids: Vec> = vec![vec![]; self.batch()]; + let mut finished = vec![false; self.batch()]; + let mut last_tokens: Vec = vec![0.; self.batch()]; + let logits_sampler = LogitsSampler::new(); + + // generate + for _ in 0..self.max_length { + let logits = &decoder_outputs[0]; + let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // decode each token for each batch + // let (finished, last_tokens) = self.decoder_merged.processor().par_generate( + // logits, + // &mut token_ids, + // self.eos_token_id, + // )?; + + // if finished { + // break; + // } + + for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = logits_sampler.decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; + if token_id == self.eos_token_id { + finished[i] = true; + } else { + token_ids[i].push(token_id); + } + 
// update + last_tokens[i] = token_id as f32; + } + } + + // all finished? + if finished.iter().all(|&x| x) { + break; + } + + // decode + let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?; + let inputs_embeds = &self.text_embed.inference(Xs::from(next_tokens))?[0].clone(); + let use_cache = X::ones(&[1]); + let mut xs = vec![ + attention_mask.clone(), + last_hidden_state.clone(), + inputs_embeds.clone(), + ]; + for i in 0..self.n_kvs { + xs.push(decoder_kvs[i * 2].clone()); + xs.push(decoder_kvs[i * 2 + 1].clone()); + xs.push(encoder_kvs[i * 2].clone()); + xs.push(encoder_kvs[i * 2 + 1].clone()); + } + xs.push(use_cache); + decoder_outputs = self.decoder_merged.inference(xs.into())?; + } + + // batch decode + let texts = self + .text_embed + .processor() + .decode_tokens_batch(&token_ids, false)?; + + Ok(texts) + } + + fn postprocess( + &mut self, + generated_text: &[String], + xs_visual: &[DynamicImage], + x_textual: &Task, + ) -> Result { + let mut ys = Vec::new(); + let ys_task = (0..self.batch()) + .into_par_iter() + .map(|batch| { + // image size + let image_width = xs_visual[batch].width() as usize; + let image_height = xs_visual[batch].height() as usize; + + // texts cleanup + let text = generated_text[batch] + .as_str() + .replace("", "") + .replace("", "") + .replace("", ""); + + // postprocess + let mut y = Y::default(); + if let Task::Caption(_) | Task::Ocr = x_textual { + y = y.with_texts(&[text.into()]); + } else { + let elems = Self::loc_parse(&text)?; + match x_textual { + Task::RegionToCategory(..) | Task::RegionToDescription(..) 
=> { + let text = elems[0][0].clone(); + y = y.with_texts(&[text.into()]); + } + Task::ObjectDetection + | Task::OpenSetDetection(_) + | Task::DenseRegionCaption + | Task::CaptionToPhraseGrounding(_) => { + let y_bboxes: Vec = elems + .par_iter() + .enumerate() + .flat_map(|(i, elem)| { + Self::process_bboxes( + &elem[1..], + &self.quantizer, + image_width, + image_height, + Some((&elem[0], i)), + ) + }) + .collect(); + y = y.with_bboxes(&y_bboxes); + } + Task::RegionProposal => { + let y_bboxes: Vec = Self::process_bboxes( + &elems[0], + &self.quantizer, + image_width, + image_height, + None, + ); + y = y.with_bboxes(&y_bboxes); + } + Task::ReferringExpressionSegmentation(_) + | Task::RegionToSegmentation(..) => { + let points = Self::process_polygons( + &elems[0], + &self.quantizer, + image_width, + image_height, + ); + y = y.with_polygons(&[Polygon::default() + .with_points(&points) + .with_id(0)]); + } + Task::OcrWithRegion => { + let y_polygons: Vec = elems + .par_iter() + .enumerate() + .map(|(i, elem)| { + let points = Self::process_polygons( + &elem[1..], + &self.quantizer, + image_width, + image_height, + ); + Polygon::default() + .with_name(&elem[0]) + .with_points(&points) + .with_id(i as _) + }) + .collect(); + y = y.with_polygons(&y_polygons); + } + _ => anyhow::bail!("Unsupported Florence2 task."), + }; + } + Ok(y) + }) + .collect::>>()?; + + ys.extend_from_slice(&ys_task); + + Ok(ys.into()) + } + + fn process_polygons( + elems: &[String], + quantizer: &Quantizer, + image_width: usize, + image_height: usize, + ) -> Vec> { + elems + .par_chunks(2) + .map(|chunk| { + let coord: Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); + quantizer.dequantize(&coord, (image_width, image_height)) + }) + .collect() + } + + fn process_bboxes( + elems: &[String], + quantizer: &Quantizer, + image_width: usize, + image_height: usize, + class_name: Option<(&str, usize)>, + ) -> Vec { + elems + .par_chunks(4) + .enumerate() + .map(|(i, chunk)| { + let bbox: 
Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); + let dequantized_bbox = quantizer.dequantize(&bbox, (image_width, image_height)); + + let mut bbox = Bbox::default().with_xyxy( + dequantized_bbox[0].max(0.0f32).min(image_width as f32), + dequantized_bbox[1].max(0.0f32).min(image_height as f32), + dequantized_bbox[2], + dequantized_bbox[3], + ); + if let Some((class_name, i)) = class_name { + bbox = bbox.with_name(class_name).with_id(i as _); + } else { + bbox = bbox.with_id(i as _); + } + + bbox + }) + .collect() + } + + fn loc_parse(hay: &str) -> Result>> { + let pattern = r"(?i)(\d+)>)|(?[^<]+)"; + let re = regex::Regex::new(pattern)?; + let mut ys: Vec> = Vec::new(); + let mut y = Vec::new(); + + for cap in re.captures_iter(hay) { + if let Some(loc) = cap.name("coord") { + y.push(loc.as_str().to_string()); + } else if let Some(text) = cap.name("name") { + if !text.as_str().is_empty() { + if !y.is_empty() { + ys.push(y); + y = Vec::new(); + } + y.push(text.as_str().to_string()); + } + } + } + if !y.is_empty() { + ys.push(y); + } + Ok(ys) + } + + pub fn batch(&self) -> usize { + self.vision_encoder.batch() as _ + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/florence2/mod.rs b/src/models/florence2/mod.rs new file mode 100644 index 0000000..5405ab1 --- /dev/null +++ b/src/models/florence2/mod.rs @@ -0,0 +1,6 @@ +mod config; +mod r#impl; +mod quantizer; + +pub use quantizer::Quantizer; +pub use r#impl::*; diff --git a/src/utils/quantizer.rs b/src/models/florence2/quantizer.rs similarity index 76% rename from src/utils/quantizer.rs rename to src/models/florence2/quantizer.rs index 1a3a6ac..615247d 100644 --- a/src/utils/quantizer.rs +++ b/src/models/florence2/quantizer.rs @@ -22,7 +22,7 @@ impl Quantizer { ((val as f64 + 0.5) * bin_size) as f32 } - fn quantize_internal(&self, input: &[f32], size: (usize, usize)) -> Vec { + fn quantize_internal(&self, input: &[usize], size: (usize, usize)) -> Vec { let (bins_w, 
bins_h) = self.bins; let (size_w, size_h) = size; @@ -31,14 +31,14 @@ impl Quantizer { match input.len() { 4 => vec![ - self.quantize_value(input[0], size_per_bin_w, bins_w), - self.quantize_value(input[1], size_per_bin_h, bins_h), - self.quantize_value(input[2], size_per_bin_w, bins_w), - self.quantize_value(input[3], size_per_bin_h, bins_h), + self.quantize_value(input[0] as f32, size_per_bin_w, bins_w), + self.quantize_value(input[1] as f32, size_per_bin_h, bins_h), + self.quantize_value(input[2] as f32, size_per_bin_w, bins_w), + self.quantize_value(input[3] as f32, size_per_bin_h, bins_h), ], 2 => vec![ - self.quantize_value(input[0], size_per_bin_w, bins_w), - self.quantize_value(input[1], size_per_bin_h, bins_h), + self.quantize_value(input[0] as f32, size_per_bin_w, bins_w), + self.quantize_value(input[1] as f32, size_per_bin_h, bins_h), ], _ => panic!( "Error: Unsupported input length: {} in Quantizer.", @@ -72,7 +72,7 @@ impl Quantizer { } } - pub fn quantize(&self, input: &[f32], size: (usize, usize)) -> Vec { + pub fn quantize(&self, input: &[usize], size: (usize, usize)) -> Vec { self.quantize_internal(input, size) } diff --git a/src/models/grounding_dino.rs b/src/models/grounding_dino.rs deleted file mode 100644 index 8700c91..0000000 --- a/src/models/grounding_dino.rs +++ /dev/null @@ -1,245 +0,0 @@ -use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -use ndarray::{s, Array, Axis}; -use rayon::prelude::*; -use tokenizers::{Encoding, Tokenizer}; - -#[derive(Debug)] -pub struct GroundingDINO { - pub engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - tokenizer: Tokenizer, - pub context_length: usize, - confs_visual: DynConf, - confs_textual: DynConf, -} - -impl GroundingDINO { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.inputs_minoptmax()[0][0].to_owned(), - 
engine.inputs_minoptmax()[0][2].to_owned(), - engine.inputs_minoptmax()[0][3].to_owned(), - ); - let context_length = options.context_length.unwrap_or(256); - // let special_tokens = ["[CLS]", "[SEP]", ".", "?"]; - let tokenizer = options - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - let tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - let confs_visual = DynConf::new(&options.confs, 1); - let confs_textual = DynConf::new(&options.confs, 1); - - engine.dry_run()?; - - Ok(Self { - engine, - batch, - height, - width, - tokenizer, - context_length, - confs_visual, - confs_textual, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage], texts: &[&str]) -> Result> { - // image embeddings - let image_embeddings = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "CatmullRom", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - - // encoding - let text = Self::parse_texts(texts); - let encoding = match self.tokenizer.encode(text, true) { - Err(err) => anyhow::bail!("{}", err), - Ok(x) => x, - }; - let tokens = encoding.get_tokens(); - - // input_ids - let input_ids = X::from( - encoding - .get_ids() - .iter() - .map(|&x| x as f32) - .collect::>(), - ) - .insert_axis(0)? - .repeat(0, self.batch() as usize)?; - - // token_type_ids - let token_type_ids = X::zeros(&[self.batch() as usize, tokens.len()]); - - // attention_mask - let attention_mask = X::ones(&[self.batch() as usize, tokens.len()]); - - // position_ids - let position_ids = X::from( - encoding - .get_tokens() - .iter() - .map(|x| if x == "." { 1. } else { 0. }) - .collect::>(), - ) - .insert_axis(0)? - .repeat(0, self.batch() as usize)?; - - // text_self_attention_masks - let text_self_attention_masks = Self::gen_text_self_attention_masks(&encoding)? 
- .insert_axis(0)? - .repeat(0, self.batch() as usize)?; - - // run - let ys = self.engine.run(Xs::from(vec![ - image_embeddings, - input_ids, - attention_mask, - position_ids, - token_type_ids, - text_self_attention_masks, - ]))?; - - // post-process - self.postprocess(ys, xs, tokens) - } - - fn postprocess(&self, xs: Xs, xs0: &[DynamicImage], tokens: &[String]) -> Result> { - let ys: Vec = xs["logits"] - .axis_iter(Axis(0)) - .into_par_iter() - .enumerate() - .filter_map(|(idx, logits)| { - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / image_width).min(self.height() as f32 / image_height); - - let y_bboxes: Vec = logits - .axis_iter(Axis(0)) - .into_par_iter() - .enumerate() - .filter_map(|(i, clss)| { - let (class_id, &conf) = clss - .mapv(|x| 1. / ((-x).exp() + 1.)) - .iter() - .enumerate() - .max_by(|a, b| a.1.total_cmp(b.1))?; - - if conf < self.conf_visual() { - return None; - } - - let bbox = xs["boxes"].slice(s![idx, i, ..]).mapv(|x| x / ratio); - let cx = bbox[0] * self.width() as f32; - let cy = bbox[1] * self.height() as f32; - let w = bbox[2] * self.width() as f32; - let h = bbox[3] * self.height() as f32; - let x = cx - w / 2.; - let y = cy - h / 2.; - let x = x.max(0.0).min(image_width); - let y = y.max(0.0).min(image_height); - - Some( - Bbox::default() - .with_xywh(x, y, w, h) - .with_id(class_id as _) - .with_name(&tokens[class_id]) - .with_confidence(conf), - ) - }) - .collect(); - - if !y_bboxes.is_empty() { - Some(Y::default().with_bboxes(&y_bboxes)) - } else { - None - } - }) - .collect(); - Ok(ys) - } - - fn parse_texts(texts: &[&str]) -> String { - let mut y = String::new(); - for text in texts.iter() { - if !text.is_empty() { - y.push_str(&format!("{} . ", text)); - } - } - y - } - - fn gen_text_self_attention_masks(encoding: &Encoding) -> Result { - let mut vs = encoding - .get_tokens() - .iter() - .map(|x| if x == "." { 1. } else { 0. 
}) - .collect::>(); - - let n = vs.len(); - vs[0] = 1.; - vs[n - 1] = 1.; - let mut ys = Array::zeros((n, n)).into_dyn(); - let mut i_last = -1; - for (i, &v) in vs.iter().enumerate() { - if v == 0. { - if i_last == -1 { - i_last = i as isize; - } else { - i_last = -1; - } - } else if v == 1. { - if i_last == -1 { - ys.slice_mut(s![i, i]).fill(1.); - } else { - ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1]) - .fill(1.); - } - i_last = -1; - } else { - continue; - } - } - Ok(X::from(ys)) - } - - pub fn conf_visual(&self) -> f32 { - self.confs_visual[0] - } - - pub fn conf_textual(&self) -> f32 { - self.confs_textual[0] - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ - } -} diff --git a/src/models/grounding_dino/README.md b/src/models/grounding_dino/README.md new file mode 100644 index 0000000..e19d6ea --- /dev/null +++ b/src/models/grounding_dino/README.md @@ -0,0 +1,9 @@ +# Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/IDEA-Research/GroundingDINO) + +## Example + +Refer to the [example](../../../examples/grounding-dino) diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs new file mode 100644 index 0000000..4c54ee0 --- /dev/null +++ b/src/models/grounding_dino/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `GroundingDino` +impl crate::Options { + pub fn grounding_dino() -> Self { + Self::default() + .with_model_name("grounding-dino") + .with_model_kind(crate::Kind::VisionLanguage) + .with_model_ixx(0, 0, 1.into()) // TODO: current onnx model does not support bs > 1 + .with_model_ixx(0, 2, 800.into()) // TODO: matters + .with_model_ixx(0, 3, 1200.into()) // TODO: matters + .with_resize_mode(crate::ResizeMode::FitAdaptive) + 
.with_resize_filter("CatmullRom") + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_class_confs(&[0.4]) + .with_text_confs(&[0.3]) + } + + pub fn grounding_dino_tiny() -> Self { + Self::grounding_dino().with_model_file("swint-ogc.onnx") + } +} diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs new file mode 100644 index 0000000..46e8fc8 --- /dev/null +++ b/src/models/grounding_dino/impl.rs @@ -0,0 +1,223 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Array, Axis}; +use rayon::prelude::*; + +use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; + +#[derive(Builder, Debug)] +pub struct GroundingDINO { + pub engine: Engine, + height: usize, + width: usize, + batch: usize, + confs_visual: DynConf, + confs_textual: DynConf, + class_names: Vec, + tokens: Vec, + token_ids: Vec, + ts: Ts, + processor: Processor, + spec: String, +} + +impl GroundingDINO { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&800.into()).opt(), + engine.try_width().unwrap_or(&1200.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + let confs_visual = DynConf::new(options.class_confs(), 1); + let confs_textual = DynConf::new(options.text_confs(), 1); + + let class_names = Self::parse_texts( + &options + .text_names + .expect("No class names specified!") + .iter() + .map(|x| x.as_str()) + .collect::>(), + ); + let token_ids = processor.encode_text_ids(&class_names, true)?; + let tokens = processor.encode_text_tokens(&class_names, true)?; + let class_names = tokens.clone(); + + Ok(Self { + engine, + batch, + height, + width, + confs_visual, + confs_textual, + class_names, + token_ids, + tokens, + ts, + processor, + spec, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + // encode images + let image_embeddings = self.processor.process_images(xs)?; + + // encode texts + let tokens_f32 = self + .tokens + .iter() + .map(|x| if x == "." { 1. } else { 0. }) + .collect::>(); + + // input_ids + let input_ids = X::from(self.token_ids.clone()) + .insert_axis(0)? + .repeat(0, self.batch)?; + + // token_type_ids + let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]); + + // attention_mask + let attention_mask = X::ones(&[self.batch, tokens_f32.len()]); + + // text_self_attention_masks + let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)? + .insert_axis(0)? + .repeat(0, self.batch)?; + + // position_ids + let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?; + + // inputs + let xs = Xs::from(vec![ + image_embeddings, + input_ids, + attention_mask, + position_ids, + token_type_ids, + text_self_attention_masks, + ]); + + Ok(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + fn postprocess(&self, xs: Xs) -> Result { + let ys: Vec = xs["logits"] + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(idx, logits)| { + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; + + let y_bboxes: Vec = logits + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(i, clss)| { + let (class_id, &conf) = clss + .mapv(|x| 1. / ((-x).exp() + 1.)) + .iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(b.1))?; + + if conf < self.confs_visual[0] { + return None; + } + + let bbox = xs["boxes"].slice(s![idx, i, ..]).mapv(|x| x / ratio); + let cx = bbox[0] * self.width as f32; + let cy = bbox[1] * self.height as f32; + let w = bbox[2] * self.width as f32; + let h = bbox[3] * self.height as f32; + let x = cx - w / 2.; + let y = cy - h / 2.; + let x = x.max(0.0).min(image_width as _); + let y = y.max(0.0).min(image_height as _); + + Some( + Bbox::default() + .with_xywh(x, y, w, h) + .with_id(class_id as _) + .with_name(&self.class_names[class_id]) + .with_confidence(conf), + ) + }) + .collect(); + + if !y_bboxes.is_empty() { + Some(Y::default().with_bboxes(&y_bboxes)) + } else { + None + } + }) + .collect(); + + Ok(ys.into()) + } + + fn parse_texts(texts: &[&str]) -> String { + let mut y = String::new(); + for text in texts.iter() { + if !text.is_empty() { + y.push_str(&format!("{} . ", text)); + } + } + y + } + + fn gen_text_self_attention_masks(tokens: &[f32]) -> Result { + let mut vs = tokens.to_vec(); + let n = vs.len(); + vs[0] = 1.; + vs[n - 1] = 1.; + let mut ys = Array::zeros((n, n)).into_dyn(); + let mut i_last = -1; + for (i, &v) in vs.iter().enumerate() { + if v == 0. { + if i_last == -1 { + i_last = i as isize; + } else { + i_last = -1; + } + } else if v == 1. 
{ + if i_last == -1 { + ys.slice_mut(s![i, i]).fill(1.); + } else { + ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1]) + .fill(1.); + } + i_last = -1; + } else { + continue; + } + } + Ok(X::from(ys)) + } +} diff --git a/src/models/grounding_dino/mod.rs b/src/models/grounding_dino/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/grounding_dino/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/linknet/README.md b/src/models/linknet/README.md new file mode 100644 index 0000000..2a19412 --- /dev/null +++ b/src/models/linknet/README.md @@ -0,0 +1,11 @@ +# LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation + +Pixel-wise semantic segmentation for visual scene understanding not only needs to be accurate, but also efficient in order to find any use in real-time application. Existing algorithms even though are accurate but they do not focus on utilizing the parameters of neural network efficiently. As a result they are huge in terms of parameters and number of operations; hence slow too. In this paper, we propose a novel deep neural network architecture which allows it to learn without any significant increase in number of parameters. Our network uses only 11.5 million parameters and 21.2 GFLOPs for processing an image of resolution 3x640x360. It gives state-of-the-art performance on CamVid and comparable results on Cityscapes dataset. We also compare our networks processing time on NVIDIA GPU and embedded system device with existing state-of-the-art architectures for different image resolutions. 
+ +## Official Repository + +The official paper can be found on: [LinkNet](https://arxiv.org/abs/1707.03718) + +## Example + +Refer to the [example](../../../examples/linknet) diff --git a/src/models/linknet/config.rs b/src/models/linknet/config.rs new file mode 100644 index 0000000..c1c666a --- /dev/null +++ b/src/models/linknet/config.rs @@ -0,0 +1,21 @@ +/// Model configuration for [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/abs/1707.03718) +impl crate::Options { + pub fn linknet() -> Self { + Self::fast() + .with_model_name("linknet") + .with_image_mean(&[0.798, 0.785, 0.772]) + .with_image_std(&[0.264, 0.2749, 0.287]) + } + + pub fn linknet_r18() -> Self { + Self::linknet().with_model_file("felixdittrich92-r18.onnx") + } + + pub fn linknet_r34() -> Self { + Self::linknet().with_model_file("felixdittrich92-r34.onnx") + } + + pub fn linknet_r50() -> Self { + Self::linknet().with_model_file("felixdittrich92-r50.onnx") + } +} diff --git a/src/models/linknet/mod.rs b/src/models/linknet/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/linknet/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/mobileone/README.md b/src/models/mobileone/README.md new file mode 100644 index 0000000..796057a --- /dev/null +++ b/src/models/mobileone/README.md @@ -0,0 +1,9 @@ +# MobileOne: An Improved One millisecond Mobile Backbone + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/apple/ml-mobileone) + +## Example + +Refer to the [example](../../../examples/mobileone) diff --git a/src/models/mobileone/config.rs b/src/models/mobileone/config.rs new file mode 100644 index 0000000..56c0a8e --- /dev/null +++ b/src/models/mobileone/config.rs @@ -0,0 +1,50 @@ +use crate::IMAGENET_NAMES_1K; + +/// Model configuration for `MobileOne` +impl crate::Options { + pub fn mobileone() -> Self { + Self::default() + .with_model_name("mobileone") + 
.with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_apply_softmax(true) + .with_normalize(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn mobileone_s0() -> Self { + Self::mobileone().with_model_file("s0.onnx") + } + + pub fn mobileone_s1() -> Self { + Self::mobileone().with_model_file("s1.onnx") + } + + pub fn mobileone_s2() -> Self { + Self::mobileone().with_model_file("s2.onnx") + } + + pub fn mobileone_s3() -> Self { + Self::mobileone().with_model_file("s3.onnx") + } + + pub fn mobileone_s4_224x224() -> Self { + Self::mobileone().with_model_file("s4-224x224.onnx") + } + + pub fn mobileone_s4_256x256() -> Self { + Self::mobileone().with_model_file("s4-256x256.onnx") + } + + pub fn mobileone_s4_384x384() -> Self { + Self::mobileone().with_model_file("s4-384x384.onnx") + } + + pub fn mobileone_s4_512x512() -> Self { + Self::mobileone().with_model_file("s4-512x512.onnx") + } +} diff --git a/src/models/mobileone/mod.rs b/src/models/mobileone/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/mobileone/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/mod.rs b/src/models/mod.rs index 6df4c8c..79db3c5 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,35 +1,50 @@ -//! Models provided: [`Blip`], [`Clip`], [`YOLO`], [`DepthAnything`], ... 
- +mod beit; mod blip; mod clip; +mod convnext; +mod d_fine; mod db; +mod deim; +mod deit; mod depth_anything; mod depth_pro; mod dinov2; +mod fast; +mod fastvit; mod florence2; mod grounding_dino; +mod linknet; +mod mobileone; mod modnet; +mod picodet; +mod pipeline; +mod rtdetr; mod rtmo; mod sam; mod sapiens; +mod slanet; mod svtr; +mod trocr; mod yolo; -mod yolo_; mod yolop; -pub use blip::Blip; -pub use clip::Clip; -pub use db::DB; -pub use depth_anything::DepthAnything; -pub use depth_pro::DepthPro; -pub use dinov2::Dinov2; -pub use florence2::Florence2; -pub use grounding_dino::GroundingDINO; -pub use modnet::MODNet; -pub use rtmo::RTMO; -pub use sam::{SamKind, SamPrompt, SAM}; -pub use sapiens::{Sapiens, SapiensTask}; -pub use svtr::SVTR; -pub use yolo::YOLO; -pub use yolo_::*; -pub use yolop::YOLOPv2; +pub use blip::*; +pub use clip::*; +pub use db::*; +pub use depth_anything::*; +pub use depth_pro::*; +pub use dinov2::*; +pub use florence2::*; +pub use grounding_dino::*; +pub use modnet::*; +pub use picodet::*; +pub use pipeline::*; +pub use rtdetr::*; +pub use rtmo::*; +pub use sam::*; +pub use sapiens::*; +pub use slanet::*; +pub use svtr::*; +pub use trocr::*; +pub use yolo::*; +pub use yolop::*; diff --git a/src/models/modnet.rs b/src/models/modnet.rs deleted file mode 100644 index 4f87cbd..0000000 --- a/src/models/modnet.rs +++ /dev/null @@ -1,84 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; - -#[derive(Debug)] -pub struct MODNet { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, -} - -impl MODNet { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), - ); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - }) - } - - pub fn run(&mut self, xs: 
&[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Lanczos3", - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?; - - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) - } - - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - let mut ys: Vec = Vec::new(); - for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - let luma = luma.mapv(|x| (x * 255.0) as u8); - let luma = Ops::resize_luma8_u8( - &luma.into_raw_vec_and_offset().0, - self.width() as _, - self.height() as _, - w1 as _, - h1 as _, - false, - "Bilinear", - )?; - let luma: image::ImageBuffer, Vec<_>> = - match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { - None => continue, - Some(x) => x, - }; - ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); - } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ - } -} diff --git a/src/models/modnet/README.md b/src/models/modnet/README.md new file mode 100644 index 0000000..19939cb --- /dev/null +++ b/src/models/modnet/README.md @@ -0,0 +1,9 @@ +# MODNet: Trimap-Free Portrait Matting in Real Time + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/ZHKKKe/MODNet) + +## Example + +Refer to the [example](../../../examples/modnet) diff --git a/src/models/modnet/config.rs b/src/models/modnet/config.rs new file mode 100644 index 0000000..05174d2 --- /dev/null +++ b/src/models/modnet/config.rs @@ -0,0 +1,17 @@ +/// Model configuration for `MODNet` +impl crate::Options { + pub fn modnet() -> Self { + Self::default() + .with_model_name("modnet") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, (416, 512, 800).into()) + .with_model_ixx(0, 3, (416, 512, 800).into()) 
+ .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_normalize(true) + } + + pub fn modnet_photographic() -> Self { + Self::modnet().with_model_file("photographic-portrait-matting.onnx") + } +} diff --git a/src/models/modnet/impl.rs b/src/models/modnet/impl.rs new file mode 100644 index 0000000..f21d446 --- /dev/null +++ b/src/models/modnet/impl.rs @@ -0,0 +1,90 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; + +use crate::{elapsed, Engine, Mask, Ops, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct MODNet { + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, +} + +impl MODNet { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + ts, + spec, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { + let (h1, w1) = self.processor.image0s_size[idx]; + + let luma = luma.mapv(|x| (x * 255.0) as u8); + let luma = Ops::resize_luma8_u8( + &luma.into_raw_vec_and_offset().0, + self.width as _, + self.height as _, + w1 as _, + h1 as _, + false, + "Bilinear", + )?; + let luma: image::ImageBuffer, Vec<_>> = + match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { + None => continue, + Some(x) => x, + }; + ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); + } + + Ok(ys.into()) + } +} diff --git a/src/models/modnet/mod.rs b/src/models/modnet/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/modnet/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/picodet/README.md b/src/models/picodet/README.md new file mode 100644 index 0000000..fa98440 --- /dev/null +++ b/src/models/picodet/README.md @@ -0,0 +1,9 @@ +# PP-PicoDet: A Better Real-Time Object Detector on Mobile Devices + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) + +## Example + +Refer to the [example](../../../examples/picodet-layout) diff --git a/src/models/picodet/config.rs b/src/models/picodet/config.rs new file mode 100644 index 0000000..ebebdc0 --- /dev/null +++ b/src/models/picodet/config.rs @@ -0,0 +1,61 @@ +use crate::{ResizeMode, COCO_CLASS_NAMES_80}; + +/// Model configuration for `PicoDet` +impl crate::Options { + pub fn picodet() -> Self { + Self::default() + .with_model_name("picodet") + .with_batch_size(1) // TODO: ONNX model's batch size seems always = 1 + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_model_ixx(1, 0, (1, 1, 8).into()) + 
.with_model_ixx(1, 1, 2.into()) + .with_resize_mode(ResizeMode::FitAdaptive) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_class_confs(&[0.5]) + } + + pub fn picodet_l_coco() -> Self { + Self::picodet() + .with_model_file("l-coco.onnx") + .with_class_names(&COCO_CLASS_NAMES_80) + } + + pub fn picodet_layout_1x() -> Self { + Self::picodet() + .with_model_file("layout-1x.onnx") + .with_class_names(&["Text", "Title", "List", "Table", "Figure"]) + } + + pub fn picodet_l_layout_3cls() -> Self { + Self::picodet() + .with_model_file("l-layout-3cls.onnx") + .with_class_names(&["image", "table", "seal"]) + } + + pub fn picodet_l_layout_17cls() -> Self { + Self::picodet() + .with_model_file("l-layout-17cls.onnx") + .with_class_names(&[ + "paragraph_title", + "image", + "text", + "number", + "abstract", + "content", + "figure_title", + "formula", + "table", + "table_title", + "reference", + "doc_title", + "footnote", + "header", + "algorithm", + "footer", + "seal", + ]) + } +} diff --git a/src/models/picodet/impl.rs b/src/models/picodet/impl.rs new file mode 100644 index 0000000..5bada86 --- /dev/null +++ b/src/models/picodet/impl.rs @@ -0,0 +1,111 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; +use rayon::prelude::*; + +use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; + +#[derive(Debug, Builder)] +pub struct PicoDet { + engine: Engine, + height: usize, + width: usize, + batch: usize, + spec: String, + names: Vec, + confs: DynConf, + ts: Ts, + processor: Processor, +} + +impl PicoDet { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&640.into()).opt(), + engine.try_width().unwrap_or(&640.into()).opt(), + engine.ts.clone(), + ); + let spec = engine.spec().to_owned(); + let processor = options + 
.to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let names = options + .class_names() + .expect("No class names are specified.") + .to_vec(); + let confs = DynConf::new(options.class_confs(), names.len()); + + Ok(Self { + engine, + height, + width, + batch, + spec, + names, + confs, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x1 = self.processor.process_images(xs)?; + let x2: X = self.processor.scale_factors_hw.clone().try_into()?; + + Ok(Xs::from(vec![x1, x2])) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { + // ONNX models exported by paddle2onnx + // TODO: ONNX model's batch size seems always = 1 + // xs[0] : n, 6 + // xs[1] : n + let y_bboxes: Vec = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(_i, pred)| { + let (class_id, confidence) = (pred[0] as usize, pred[1]); + if confidence < self.confs[class_id] { + return None; + } + let (x1, y1, x2, y2) = (pred[2], pred[3], pred[4], pred[5]); + + Some( + Bbox::default() + .with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2) + .with_confidence(confidence) + .with_id(class_id as isize) + .with_name(&self.names[class_id]), + ) + }) + .collect(); + + let mut y = Y::default(); + if !y_bboxes.is_empty() { + y = y.with_bboxes(&y_bboxes); + } + + Ok(vec![y].into()) + } +} diff --git a/src/models/picodet/mod.rs b/src/models/picodet/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/picodet/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; 
diff --git a/src/models/pipeline/basemodel.rs b/src/models/pipeline/basemodel.rs new file mode 100644 index 0000000..52b73fc --- /dev/null +++ b/src/models/pipeline/basemodel.rs @@ -0,0 +1,148 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; + +use crate::{ + elapsed, DType, Device, Engine, Kind, Options, Processor, Scale, Task, Ts, Version, Xs, X, +}; + +#[derive(Debug, Builder)] +pub struct BaseModelVisual { + engine: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + ts: Ts, + spec: String, + name: &'static str, + device: Device, + dtype: DType, + task: Option, + scale: Option, + kind: Option, + version: Option, +} + +impl BaseModelVisual { + pub fn summary(&self) { + self.ts.summary(); + } + + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let err_msg = "You need to specify the image height and image width for visual model."; + let (batch, height, width, ts, spec) = ( + engine.batch().opt(), + engine.try_height().expect(err_msg).opt(), + engine.try_width().expect(err_msg).opt(), + engine.ts.clone(), + engine.spec().to_owned(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + let device = options.model_device; + let task = options.model_task; + let scale = options.model_scale; + let dtype = options.model_dtype; + let kind = options.model_kind; + let name = options.model_name; + let version = options.model_version; + + Ok(Self { + engine, + height, + width, + batch, + processor, + ts, + spec, + dtype, + task, + scale, + kind, + device, + version, + name, + }) + } + + pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + self.batch = xs.len(); // update + + Ok(x.into()) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode(&mut self, xs: &[DynamicImage]) -> Result { + let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? }); + + Ok(xs[0].to_owned()) + } +} + +#[derive(Debug, Builder)] +pub struct BaseModelTextual { + engine: Engine, + batch: usize, + processor: Processor, + ts: Ts, + spec: String, + name: &'static str, + device: Device, + dtype: DType, + task: Option, + scale: Option, + kind: Option, + version: Option, +} + +impl BaseModelTextual { + pub fn summary(&self) { + self.ts.summary(); + } + + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, ts, spec) = ( + engine.batch().opt(), + engine.ts.clone(), + engine.spec().to_owned(), + ); + let processor = options.to_processor()?; + let device = options.model_device; + let task = options.model_task; + let scale = options.model_scale; + let dtype = options.model_dtype; + let kind = options.model_kind; + let name = options.model_name; + let version = options.model_version; + + Ok(Self { + engine, + batch, + processor, + ts, + spec, + dtype, + task, + scale, + kind, + device, + version, + name, + }) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } +} diff 
--git a/src/models/pipeline/image_classifier.rs b/src/models/pipeline/image_classifier.rs new file mode 100644 index 0000000..25ccfaa --- /dev/null +++ b/src/models/pipeline/image_classifier.rs @@ -0,0 +1,125 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; +use rayon::prelude::*; + +use crate::{elapsed, DynConf, Engine, Options, Prob, Processor, Ts, Xs, Ys, Y}; + +#[derive(Debug, Builder)] +pub struct ImageClassifier { + engine: Engine, + height: usize, + width: usize, + batch: usize, + apply_softmax: bool, + ts: Ts, + processor: Processor, + confs: DynConf, + nc: usize, + names: Vec, + spec: String, +} + +impl TryFrom for ImageClassifier { + type Error = anyhow::Error; + + fn try_from(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&224.into()).opt(), + engine.try_width().unwrap_or(&224.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + let (nc, names) = match (options.nc(), options.class_names()) { + (Some(nc), Some(names)) => { + if nc != names.len() { + anyhow::bail!( + "The length of the input class names: {} is inconsistent with the number of classes: {}.", + names.len(), + nc + ); + } + (nc, names.to_vec()) + } + (Some(nc), None) => ( + nc, + (0..nc).map(|x| format!("# {}", x)).collect::>(), + ), + (None, Some(names)) => (names.len(), names.to_vec()), + (None, None) => { + anyhow::bail!("Neither class names nor class numbers were specified."); + } + }; + let confs = DynConf::new(options.class_confs(), nc); + let apply_softmax = options.apply_softmax.unwrap_or_default(); + + Ok(Self { + engine, + height, + width, + batch, + nc, + ts, + spec, + processor, + confs, + names, + apply_softmax, + }) + } +} + +impl ImageClassifier { + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + fn postprocess(&self, xs: Xs) -> Result { + let ys: Ys = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .filter_map(|logits| { + let logits = if self.apply_softmax { + let exps = logits.mapv(|x| x.exp()); + let stds = exps.sum_axis(Axis(0)); + exps / stds + } else { + logits.into_owned() + }; + let probs = Prob::default() + .with_probs(&logits.into_raw_vec_and_offset().0) + .with_names(&self.names.iter().map(|x| x.as_str()).collect::>()); + + Some(Y::default().with_probs(probs)) + }) + .collect::>() + .into(); + + Ok(ys) + } +} diff --git a/src/models/pipeline/mod.rs b/src/models/pipeline/mod.rs new file mode 100644 index 0000000..2ece115 --- /dev/null +++ b/src/models/pipeline/mod.rs @@ -0,0 +1,5 @@ +mod basemodel; +mod image_classifier; + +pub use basemodel::*; +pub use image_classifier::*; diff --git a/src/models/rtdetr/README.md b/src/models/rtdetr/README.md new file mode 100644 index 0000000..99f7ec8 --- /dev/null +++ b/src/models/rtdetr/README.md @@ -0,0 +1,9 @@ +# RT-DETR: DETRs Beat YOLOs on Real-time Object Detection + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/lyuwenyu/RT-DETR) + +## Example + +Refer to the [example](../../../examples/rtdetr) diff --git a/src/models/rtdetr/config.rs b/src/models/rtdetr/config.rs new file mode 100644 index 0000000..2e0668e --- /dev/null +++ b/src/models/rtdetr/config.rs @@ -0,0 +1,40 @@ +use crate::COCO_CLASS_NAMES_80; + +/// Model configuration for `RT-DETR` +impl crate::Options { + pub fn rtdetr() -> Self { + Self::default() + .with_model_name("rtdetr") + .with_batch_size(1) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_normalize(true) + .with_class_confs(&[0.5]) + .with_class_names(&COCO_CLASS_NAMES_80) + } + + pub fn rtdetr_v1_r18vd_coco() -> Self { + Self::rtdetr().with_model_file("v1-r18vd-coco.onnx") + } + + pub fn rtdetr_v2_s_coco() -> Self { + 
Self::rtdetr().with_model_file("v2-s-coco.onnx") + } + + pub fn rtdetr_v2_ms_coco() -> Self { + Self::rtdetr().with_model_file("v2-ms-coco.onnx") + } + + pub fn rtdetr_v2_m_coco() -> Self { + Self::rtdetr().with_model_file("v2-m-coco.onnx") + } + + pub fn rtdetr_v2_l_coco() -> Self { + Self::rtdetr().with_model_file("v2-l-coco.onnx") + } + + pub fn rtdetr_v2_x_coco() -> Self { + Self::rtdetr().with_model_file("v2-x-coco.onnx") + } +} diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs new file mode 100644 index 0000000..70e5262 --- /dev/null +++ b/src/models/rtdetr/impl.rs @@ -0,0 +1,128 @@ +use crate::{elapsed, Bbox, DynConf, Engine, Processor, Ts, Xs, Ys, X, Y}; +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; + +use crate::Options; + +#[derive(Debug, Builder)] +pub struct RTDETR { + engine: Engine, + height: usize, + width: usize, + batch: usize, + names: Vec, + confs: DynConf, + ts: Ts, + processor: Processor, + spec: String, +} + +impl RTDETR { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&640.into()).opt(), + engine.try_width().unwrap_or(&640.into()).opt(), + engine.ts.clone(), + ); + let spec = engine.spec().to_owned(); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let names = options + .class_names() + .expect("No class names specified.") + .to_vec(); + let confs = DynConf::new(options.class_confs(), names.len()); + + Ok(Self { + engine, + height, + width, + batch, + spec, + names, + confs, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x1 = self.processor.process_images(xs)?; + let x2 = X::from(vec![self.height as f32, self.width as f32]) + .insert_axis(0)? 
+ .repeat(0, self.batch)?; + + let xs = Xs::from(vec![x1, x2]); + + Ok(xs) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let ys: Vec = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .zip(xs[1].axis_iter(Axis(0)).into_par_iter()) + .zip(xs[2].axis_iter(Axis(0)).into_par_iter()) + .enumerate() + .filter_map(|(idx, ((labels, boxes), scores))| { + let ratio = self.processor.scale_factors_hw[idx][0]; + + let mut y_bboxes = Vec::new(); + for (i, &score) in scores.iter().enumerate() { + let class_id = labels[i] as usize; + if score < self.confs[class_id] { + continue; + } + + let xyxy = boxes.slice(s![i, ..]); + let (x1, y1, x2, y2) = ( + xyxy[0] / ratio, + xyxy[1] / ratio, + xyxy[2] / ratio, + xyxy[3] / ratio, + ); + + y_bboxes.push( + Bbox::default() + .with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2) + .with_confidence(score) + .with_id(class_id as isize) + .with_name(&self.names[class_id]), + ); + } + + let mut y = Y::default(); + if !y_bboxes.is_empty() { + y = y.with_bboxes(&y_bboxes); + } + + Some(y) + }) + .collect(); + + Ok(ys.into()) + } +} diff --git a/src/models/rtdetr/mod.rs b/src/models/rtdetr/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/rtdetr/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/rtmo/README.md b/src/models/rtmo/README.md new file mode 100644 index 0000000..25afeb3 --- /dev/null +++ b/src/models/rtmo/README.md @@ -0,0 +1,9 @@ +# RTMO: Towards High-Performance One-Stage Real-Time Multi-Person Pose Estimation + +## Official Repository + +The 
official repository can be found on: [GitHub](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) + +## Example + +Refer to the [example](../../../examples/rtmo) diff --git a/src/models/rtmo/config.rs b/src/models/rtmo/config.rs new file mode 100644 index 0000000..d223269 --- /dev/null +++ b/src/models/rtmo/config.rs @@ -0,0 +1,28 @@ +/// Model configuration for `RTMO` +impl crate::Options { + pub fn rtmo() -> Self { + Self::default() + .with_model_name("rtmo") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("CatmullRom") + .with_normalize(false) + .with_nk(17) + .with_class_confs(&[0.3]) + .with_keypoint_confs(&[0.5]) + } + + pub fn rtmo_s() -> Self { + Self::rtmo().with_model_file("s.onnx") + } + + pub fn rtmo_m() -> Self { + Self::rtmo().with_model_file("m.onnx") + } + + pub fn rtmo_l() -> Self { + Self::rtmo().with_model_file("l.onnx") + } +} diff --git a/src/models/rtmo.rs b/src/models/rtmo/impl.rs similarity index 55% rename from src/models/rtmo.rs rename to src/models/rtmo/impl.rs index 1ae4b4d..f23f448 100644 --- a/src/models/rtmo.rs +++ b/src/models/rtmo/impl.rs @@ -1,75 +1,87 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::Axis; -use crate::{Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, Xs, X, Y}; +use crate::{elapsed, Bbox, DynConf, Engine, Keypoint, Options, Processor, Ts, Xs, Ys, Y}; -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct RTMO { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, confs: DynConf, kconfs: DynConf, } impl RTMO { pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - 
engine.width().to_owned(), + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), ); - let nc = 1; - let nk = options.nk.unwrap_or(17); - let confs = DynConf::new(&options.confs, nc); - let kconfs = DynConf::new(&options.kconfs, nk); - engine.dry_run()?; + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + let nk = options.nk().unwrap_or(17); + let confs = DynConf::new(options.class_confs(), 1); + let kconfs = DynConf::new(options.keypoint_confs(), nk); Ok(Self { engine, - confs, - kconfs, height, width, batch, + ts, + spec, + processor, + confs, + kconfs, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::letterbox( - xs, - self.height() as u32, - self.width() as u32, - "CatmullRom", - 114, - "auto", - false, - )? - .nhwc2nchw()?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) } - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { let mut ys: Vec = Vec::new(); - let (preds_bboxes, preds_kpts) = if xs[0].ndim() == 3 { - (&xs[0], &xs[1]) - } else { - (&xs[1], &xs[0]) - }; + // let (preds_bboxes, preds_kpts) = (&xs["dets"], &xs["keypoints"]); + let (preds_bboxes, preds_kpts) = (&xs[0], &xs[1]); for (idx, (batch_bboxes, batch_kpts)) in preds_bboxes .axis_iter(Axis(0)) .zip(preds_kpts.axis_iter(Axis(0))) .enumerate() { - let width_original = xs0[idx].width() as f32; - let height_original = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / width_original).min(self.height() as f32 / height_original); + let (height_original, width_original) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; let mut y_bboxes = Vec::new(); let mut y_kpts: Vec> = Vec::new(); @@ -90,8 +102,8 @@ impl RTMO { y_bboxes.push( Bbox::default() .with_xyxy( - x1.max(0.0f32).min(width_original), - y1.max(0.0f32).min(height_original), + x1.max(0.0f32).min(width_original as _), + y1.max(0.0f32).min(height_original as _), x2, y2, ) @@ -114,8 +126,8 @@ impl RTMO { .with_id(i as isize) .with_confidence(c) .with_xy( - x.max(0.0f32).min(width_original), - y.max(0.0f32).min(height_original), + x.max(0.0f32).min(width_original as _), + y.max(0.0f32).min(height_original as _), ), ); } @@ -124,18 +136,7 @@ impl RTMO { } ys.push(Y::default().with_bboxes(&y_bboxes).with_keypoints(&y_kpts)); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } } diff --git a/src/models/rtmo/mod.rs b/src/models/rtmo/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/rtmo/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/sam/README.md 
b/src/models/sam/README.md new file mode 100644 index 0000000..d2fd9a8 --- /dev/null +++ b/src/models/sam/README.md @@ -0,0 +1,16 @@ +# Segment Anything Model + +## Official Repository + +The official repository can be found on: + +- [sam1](https://github.com/facebookresearch/segment-anything) +- [sam2](https://github.com/facebookresearch/sam2) +- [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) +- [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) +- [sam-hq](https://github.com/SysCV/sam-hq) + + +## Example + +Refer to the [example](../../../examples/sam) diff --git a/src/models/sam/config.rs b/src/models/sam/config.rs new file mode 100644 index 0000000..0e9ce58 --- /dev/null +++ b/src/models/sam/config.rs @@ -0,0 +1,100 @@ +use crate::{models::SamKind, Options}; + +/// Model configuration for `Segment Anything Model` +impl Options { + pub fn sam() -> Self { + Self::default() + .with_model_name("sam") + .with_model_ixx(0, 0, 1.into()) + } + + pub fn sam_encoder() -> Self { + Self::sam() + .with_model_ixx(0, 2, 1024.into()) + .with_model_ixx(0, 3, 1024.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("Bilinear") + .with_image_mean(&[123.5, 116.5, 103.5]) + .with_image_std(&[58.5, 57.0, 57.5]) + .with_normalize(false) + .with_sam_kind(SamKind::Sam) + .with_low_res_mask(false) + .with_find_contours(true) + } + + pub fn sam_decoder() -> Self { + Self::sam() + } + + pub fn sam_v1_base_encoder() -> Self { + Self::sam_encoder().with_model_file("sam-vit-b-encoder.onnx") + } + + pub fn sam_v1_base_decoder() -> Self { + Self::sam_decoder().with_model_file("sam-vit-b-decoder.onnx") + } + + pub fn sam_v1_base_singlemask_decoder() -> Self { + Self::sam_decoder().with_model_file("sam-vit-b-decoder-singlemask.onnx") + } + + pub fn sam2_tiny_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam2-hiera-tiny-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + } + + pub fn sam2_tiny_decoder() -> Self { + 
Self::sam_decoder().with_model_file("sam2-hiera-tiny-decoder.onnx") + } + + pub fn sam2_small_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam2-hiera-small-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + } + + pub fn sam2_small_decoder() -> Self { + Self::sam_decoder().with_model_file("sam2-hiera-small-decoder.onnx") + } + + pub fn sam2_base_plus_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam2-hiera-base-plus-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + } + + pub fn sam2_base_plus_decoder() -> Self { + Self::sam_decoder().with_model_file("sam2-hiera-base-plus-decoder.onnx") + } + + pub fn mobile_sam_tiny_encoder() -> Self { + Self::sam_encoder() + .with_model_file("mobile-sam-vit-t-encoder.onnx") + .with_sam_kind(SamKind::MobileSam) + } + + pub fn mobile_sam_tiny_decoder() -> Self { + Self::sam_decoder().with_model_file("mobile-sam-vit-t-decoder.onnx") + } + + pub fn sam_hq_tiny_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam-hq-vit-t-encoder.onnx") + .with_sam_kind(SamKind::SamHq) + } + + pub fn sam_hq_tiny_decoder() -> Self { + Self::sam_decoder().with_model_file("sam-hq-vit-t-decoder.onnx") + } + + pub fn edge_sam_3x_encoder() -> Self { + Self::sam_encoder() + .with_model_file("edge-sam-3x-encoder.onnx") + .with_sam_kind(SamKind::EdgeSam) + } + + pub fn edge_sam_3x_decoder() -> Self { + Self::sam_decoder().with_model_file("edge-sam-3x-decoder.onnx") + } +} diff --git a/src/models/sam.rs b/src/models/sam/impl.rs similarity index 73% rename from src/models/sam.rs rename to src/models/sam/impl.rs index bcd12bd..41eb7ba 100644 --- a/src/models/sam.rs +++ b/src/models/sam/impl.rs @@ -1,11 +1,12 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, Array, Axis}; use rand::prelude::*; -use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Mask, Ops, Options, Polygon, Processor, Ts, Xs, Ys, X, Y}; -#[derive(Debug, 
Clone, clap::ValueEnum)] +#[derive(Debug, Clone)] pub enum SamKind { Sam, Sam2, @@ -14,6 +15,21 @@ pub enum SamKind { EdgeSam, } +impl TryFrom<&str> for SamKind { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "sam" => Ok(Self::Sam), + "sam2" => Ok(Self::Sam2), + "mobilesam" | "mobile-sam" => Ok(Self::MobileSam), + "samhq" | "sam-hq" => Ok(Self::SamHq), + "edgesam" | "edge-sam" => Ok(Self::EdgeSam), + x => anyhow::bail!("Unsupported SamKind: {}", x), + } + } +} + #[derive(Debug, Default, Clone)] pub struct SamPrompt { points: Vec, @@ -62,87 +78,81 @@ impl SamPrompt { } } -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct SAM { - encoder: OrtEngine, - decoder: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - pub conf: DynConf, + encoder: Engine, + decoder: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + conf: DynConf, find_contours: bool, kind: SamKind, use_low_res_mask: bool, + ts: Ts, + spec: String, } impl SAM { pub fn new(options_encoder: Options, options_decoder: Options) -> Result { - let mut encoder = OrtEngine::new(&options_encoder)?; - let mut decoder = OrtEngine::new(&options_decoder)?; + let encoder = options_encoder.to_engine()?; + let decoder = options_decoder.to_engine()?; let (batch, height, width) = ( - encoder.inputs_minoptmax()[0][0].to_owned(), - encoder.inputs_minoptmax()[0][2].to_owned(), - encoder.inputs_minoptmax()[0][3].to_owned(), + encoder.batch().opt(), + encoder.try_height().unwrap_or(&1024.into()).opt(), + encoder.try_width().unwrap_or(&1024.into()).opt(), ); - let conf = DynConf::new(&options_decoder.confs, 1); + let ts = Ts::merge(&[encoder.ts(), decoder.ts()]); + let spec = encoder.spec().to_owned(); + + let processor = options_encoder + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); - let kind = match options_decoder.sam_kind { + let conf = DynConf::new(options_encoder.class_confs(), 1); + let find_contours = options_encoder.find_contours; + let kind = match options_encoder.sam_kind { Some(x) => x, None => anyhow::bail!("Error: no clear `SamKind` specified."), }; - let find_contours = options_decoder.find_contours; let use_low_res_mask = match kind { SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => { - options_decoder.use_low_res_mask.unwrap_or(false) + options_encoder.low_res_mask.unwrap_or(false) } SamKind::EdgeSam | SamKind::Sam2 => true, }; - encoder.dry_run()?; - decoder.dry_run()?; - Ok(Self { encoder, decoder, + conf, batch, height, width, - conf, + ts, + processor, kind, find_contours, use_low_res_mask, + spec, }) } - pub fn run(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result> { - let ys = self.encode(xs)?; - self.decode(&ys, xs, prompts) + pub fn forward(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result { + let ys = elapsed!("encode", self.ts, { self.encode(xs)? }); + let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? 
}); + + Ok(ys) } pub fn encode(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "Bilinear", - 0, - "auto", - false, - ), - Ops::Standardize(&[123.675, 116.28, 103.53], &[58.395, 57.12, 57.375], 3), - Ops::Nhwc2nchw, - ])?; - + let xs_ = self.processor.process_images(xs)?; self.encoder.run(Xs::from(xs_)) } - pub fn decode( - &mut self, - xs: &Xs, - xs0: &[DynamicImage], - prompts: &[SamPrompt], - ) -> Result> { + pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result { let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind { SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])), _ => (&xs[0], None, None), @@ -150,44 +160,43 @@ impl SAM { let mut ys: Vec = Vec::new(); for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() { - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / image_width).min(self.height() as f32 / image_height); + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; + let args = match self.kind { SamKind::Sam | SamKind::MobileSam => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, // image_embedding + .repeat(0, self.batch)?, // image_embedding prompts[idx].point_coords(ratio)?, // point_coords prompts[idx].point_labels()?, // point_labels X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input, - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height, image_width]), // orig_im_size + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height as _, image_width as _]), // orig_im_size ] } SamKind::SamHq => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? 
- .repeat(0, self.batch() as usize)?, // image_embedding + .repeat(0, self.batch)?, // image_embedding X::from(xs[1].slice(s![idx, .., .., ..]).into_dyn().into_owned()) .insert_axis(0)? .insert_axis(0)? - .repeat(0, self.batch() as usize)?, // intern_embedding + .repeat(0, self.batch)?, // intern_embedding prompts[idx].point_coords(ratio)?, // point_coords prompts[idx].point_labels()?, // point_labels X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height, image_width]), // orig_im_size + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height as _, image_width as _]), // orig_im_size ] } SamKind::EdgeSam => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, prompts[idx].point_coords(ratio)?, prompts[idx].point_labels()?, ] @@ -196,7 +205,7 @@ impl SAM { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, X::from( high_res_features_0 .unwrap() @@ -205,7 +214,7 @@ impl SAM { .into_owned(), ) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, X::from( high_res_features_1 .unwrap() @@ -214,12 +223,12 @@ impl SAM { .into_owned(), ) .insert_axis(0)? 
- .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, prompts[idx].point_coords(ratio)?, prompts[idx].point_labels()?, X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height, image_width]), // orig_im_size + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height as _, image_width as _]), // orig_im_size ] } }; @@ -310,26 +319,14 @@ impl SAM { ys.push(y); } - Ok(ys) + Ok(ys.into()) } pub fn width_low_res(&self) -> usize { - self.width() as usize / 4 + self.width / 4 } pub fn height_low_res(&self) -> usize { - self.height() as usize / 4 - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ + self.height / 4 } } diff --git a/src/models/sam/mod.rs b/src/models/sam/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/sam/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/sapiens/README.md b/src/models/sapiens/README.md new file mode 100644 index 0000000..1fabbaf --- /dev/null +++ b/src/models/sapiens/README.md @@ -0,0 +1,22 @@ +# Sapiens: Foundation for Human Vision Models + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/facebookresearch/sapiens) + +## TODO + +- [x] Body-Part Segmentation +- [ ] Pose Estimation +- [ ] Depth Estimation +- [ ] Surface Normal Estimation + +## Example + +Refer to the [example](../../../examples/sapiens) + + + + + + diff --git a/src/models/sapiens/config.rs b/src/models/sapiens/config.rs new file mode 100644 index 0000000..e36c33e --- /dev/null +++ b/src/models/sapiens/config.rs @@ -0,0 +1,47 @@ +use crate::BODY_PARTS_NAMES_28; + +/// Model configuration for `Sapiens` +impl crate::Options { + pub fn sapiens() -> Self { + Self::default() + .with_model_name("sapiens") + 
.with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, 1024.into()) + .with_model_ixx(0, 3, 768.into()) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_resize_filter("Bilinear") + .with_image_mean(&[123.5, 116.5, 103.5]) + .with_image_std(&[58.5, 57.0, 57.5]) + .with_normalize(false) + } + + pub fn sapiens_body_part_segmentation() -> Self { + Self::sapiens() + .with_model_task(crate::Task::InstanceSegmentation) + .with_class_names(&BODY_PARTS_NAMES_28) + } + + pub fn sapiens_seg_0_3b() -> Self { + Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b.onnx") + } + + // pub fn sapiens_seg_0_3b_uint8() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-uint8.onnx") + // } + + // pub fn sapiens_seg_0_3b_fp16() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-fp16.onnx") + // } + + // pub fn sapiens_seg_0_3b_bnb4() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-bnb4.onnx") + // } + + // pub fn sapiens_seg_0_3b_q4f16() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-q4f16.onnx") + // } + + // pub fn sapiens_seg_0_6b_fp16() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.6b-fp16.onnx") + // } +} diff --git a/src/models/sapiens.rs b/src/models/sapiens/impl.rs similarity index 65% rename from src/models/sapiens.rs rename to src/models/sapiens/impl.rs index c43a6b9..927ca4b 100644 --- a/src/models/sapiens.rs +++ b/src/models/sapiens/impl.rs @@ -1,73 +1,84 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, Array2, Axis}; -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, Engine, Mask, Ops, Options, Polygon, Processor, Task, Ts, Xs, Ys, Y}; -#[derive(Debug, Clone, clap::ValueEnum)] -pub enum SapiensTask { - Seg, - Depth, - Normal, - Pose, -} - -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct Sapiens { - engine_seg: 
OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - task: SapiensTask, + engine: Engine, + height: usize, + width: usize, + batch: usize, + task: Task, names_body: Option>, + ts: Ts, + processor: Processor, + spec: String, } impl Sapiens { - pub fn new(options_seg: Options) -> Result { - let mut engine_seg = OrtEngine::new(&options_seg)?; - let (batch, height, width) = ( - engine_seg.batch().to_owned(), - engine_seg.height().to_owned(), - engine_seg.width().to_owned(), + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&1024.into()).opt(), + engine.try_width().unwrap_or(&768.into()).opt(), + engine.ts().clone(), ); - let task = options_seg - .sapiens_task - .expect("Error: No sapiens task specified."); - let names_body = options_seg.names; - engine_seg.dry_run()?; + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let task = options.model_task.expect("No sapiens task specified."); + let names_body = options.class_names; Ok(Self { - engine_seg, + engine, height, width, batch, task, names_body, + ts, + processor, + spec, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize(xs, self.height() as u32, self.width() as u32, "Bilinear"), - Ops::Standardize(&[123.5, 116.5, 103.5], &[58.5, 57.0, 57.5], 3), - Ops::Nhwc2nchw, - ])?; - - match self.task { - SapiensTask::Seg => { - let ys = self.engine_seg.run(Xs::from(xs_))?; - self.postprocess_seg(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? 
}); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { + if let Task::InstanceSegmentation = self.task { + self.postprocess_seg(ys)? + } else { + unimplemented!() } - _ => todo!(), - } + }); + + Ok(ys) } - pub fn postprocess_seg(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn postprocess_seg(&self, xs: Xs) -> Result { let mut ys: Vec = Vec::new(); for (idx, b) in xs[0].axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - // rescale + let (h1, w1) = self.processor.image0s_size[idx]; let masks = Ops::interpolate_3d(b.to_owned(), w1 as _, h1 as _, "Bilinear")?; // generate mask @@ -131,7 +142,6 @@ impl Sapiens { Some(p) => p, None => continue, }; - y_polygons.push(polygon); let mut mask = Mask::default().with_mask(luma).with_id(*i as _); @@ -142,18 +152,7 @@ impl Sapiens { } ys.push(Y::default().with_masks(&y_masks).with_polygons(&y_polygons)); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } } diff --git a/src/models/sapiens/mod.rs b/src/models/sapiens/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/sapiens/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/slanet/README.md b/src/models/slanet/README.md new file mode 100644 index 0000000..d6e8aa2 --- /dev/null +++ b/src/models/slanet/README.md @@ -0,0 +1,9 @@ +# SLANet-LCNetV2 + +## Official Repository + +The official repository can be found on: [GitHub](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) + +## Example + +Refer to the [example](../../../examples/slanet) diff --git a/src/models/slanet/config.rs b/src/models/slanet/config.rs new file mode 
100644 index 0000000..f29b311 --- /dev/null +++ b/src/models/slanet/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `SLANet` +impl crate::Options { + pub fn slanet() -> Self { + Self::default() + .with_model_name("slanet") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 2, (320, 488, 488).into()) + .with_model_ixx(0, 3, (320, 488, 488).into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_padding_value(0) + .with_unsigned(true) + } + + pub fn slanet_lcnet_v2_mobile_ch() -> Self { + Self::slanet() + .with_model_file("v2-mobile-ch.onnx") + .with_vocab_txt("vocab-sla-v2.txt") + } +} diff --git a/src/models/slanet/impl.rs b/src/models/slanet/impl.rs new file mode 100644 index 0000000..cfbd50c --- /dev/null +++ b/src/models/slanet/impl.rs @@ -0,0 +1,109 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; + +use crate::{elapsed, models::BaseModelVisual, Keypoint, Options, Text, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct SLANet { + base: BaseModelVisual, + td_tokens: Vec<&'static str>, + eos: usize, + sos: usize, + ts: Ts, + spec: String, +} + +impl SLANet { + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn new(options: Options) -> Result { + let base = BaseModelVisual::new(options)?; + let spec = base.engine().spec().to_owned(); + let sos = 0; + let eos = base.processor().vocab().len() - 1; + let td_tokens = vec!["", ""]; + let ts = base.ts().clone(); + + Ok(Self { + base, + td_tokens, + eos, + sos, + ts, + spec, + }) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.base.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.base.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + fn postprocess(&self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for (bid, (bboxes, structures)) in xs[0] + .axis_iter(Axis(0)) + .zip(xs[1].axis_iter(Axis(0))) + .enumerate() + { + let mut y_texts: Vec = vec!["".into(), "".into(), "".into()]; + let mut y_kpts: Vec> = Vec::new(); + let (image_height, image_width) = self.base.processor().image0s_size[bid]; + for (i, structure) in structures.axis_iter(Axis(0)).enumerate() { + let (token_id, &_confidence) = match structure + .into_iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(b.1)) + { + None => continue, + Some((id, conf)) => (id, conf), + }; + if token_id == self.eos { + break; + } + if token_id == self.sos { + continue; + } + + // token + let token = self.base.processor().vocab()[token_id].as_str(); + + // keypoint + if self.td_tokens.contains(&token) { + let slice_bboxes = bboxes.slice(s![i, ..]); + let x14 = slice_bboxes + .slice(s![0..;2]) + .mapv(|x| x * image_width as f32); + let y14 = slice_bboxes + .slice(s![1..;2]) + .mapv(|x| x * image_height as f32); + y_kpts.push( + (0..=3) + .map(|i| (x14[i], y14[i], i as isize).into()) + .collect(), + ); + } + + y_texts.push(token.into()); + } + + // clean up text + if y_texts.len() == 3 { + y_texts.clear(); + } else { + y_texts.extend_from_slice(&["
".into(), "".into(), "".into()]); + } + + ys.push(Y::default().with_keypoints(&y_kpts).with_texts(&y_texts)); + } + + Ok(ys.into()) + } +} diff --git a/src/models/slanet/mod.rs b/src/models/slanet/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/slanet/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/svtr.rs b/src/models/svtr.rs deleted file mode 100644 index 8f25880..0000000 --- a/src/models/svtr.rs +++ /dev/null @@ -1,101 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -use crate::{DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; - -#[derive(Debug)] -pub struct SVTR { - engine: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch: MinOptMax, - confs: DynConf, - vocab: Vec, -} - -impl SVTR { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), - ); - let confs = DynConf::new(&options.confs, 1); - let mut vocab: Vec<_> = - std::fs::read_to_string(options.vocab.expect("No vocabulary found"))? 
- .lines() - .map(|line| line.to_string()) - .collect(); - vocab.push(" ".to_string()); - vocab.insert(0, "Blank".to_string()); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - vocab, - confs, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - 0, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?; - - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys) - } - - pub fn postprocess(&self, xs: Xs) -> Result> { - let mut ys: Vec = Vec::new(); - for batch in xs[0].axis_iter(Axis(0)) { - let preds = batch - .axis_iter(Axis(0)) - .filter_map(|x| { - x.into_iter() - .enumerate() - .max_by(|(_, x), (_, y)| x.total_cmp(y)) - }) - .collect::>(); - - let text = preds - .iter() - .enumerate() - .fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| { - if *text_id == 0 || confidence < self.confs[0] { - return text_ids; - } - - if idx == 0 || idx == self.vocab.len() - 1 { - return text_ids; - } - - if *text_id != preds[idx - 1].0 { - text_ids.push(*text_id); - } - text_ids - }) - .into_iter() - .map(|idx| self.vocab[idx].to_owned()) - .collect::(); - - ys.push(Y::default().with_texts(&[text])) - } - Ok(ys) - } -} diff --git a/src/models/svtr/README.md b/src/models/svtr/README.md new file mode 100644 index 0000000..08e45be --- /dev/null +++ b/src/models/svtr/README.md @@ -0,0 +1,9 @@ +# SVTR: Scene Text Recognition with a Single Visual Model + +## Official Repository + +The official repository can be found on: [SVTRv2](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html) + +## Example + +Refer to the [example](../../../examples/svtr) \ No newline at end of file diff --git a/src/models/svtr/config.rs b/src/models/svtr/config.rs new file mode 100644 index 0000000..93fc38e --- /dev/null +++ b/src/models/svtr/config.rs @@ -0,0 
+1,43 @@ +/// Model configuration for `SVTR` +impl crate::Options { + pub fn svtr() -> Self { + Self::default() + .with_model_name("svtr") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 2, 48.into()) + .with_model_ixx(0, 3, (320, 960, 1600).into()) + .with_resize_mode(crate::ResizeMode::FitHeight) + .with_padding_value(0) + .with_normalize(true) + .with_class_confs(&[0.2]) + .with_vocab_txt("vocab-v1-ppocr-rec-ch.txt") + } + + pub fn ppocr_rec_v3_ch() -> Self { + Self::svtr().with_model_file("ppocr-v3-ch.onnx") + } + + pub fn ppocr_rec_v4_ch() -> Self { + Self::svtr().with_model_file("ppocr-v4-ch.onnx") + } + + pub fn ppocr_rec_v4_server_ch() -> Self { + Self::svtr().with_model_file("ppocr-v4-server-ch.onnx") + } + + pub fn svtr_v2_server_ch() -> Self { + Self::svtr().with_model_file("v2-server-ch.onnx") + } + + pub fn repsvtr_ch() -> Self { + Self::svtr().with_model_file("repsvtr-ch.onnx") + } + + pub fn svtr_v2_teacher_ch() -> Self { + Self::svtr().with_model_file("v2-distill-teacher-ch.onnx") + } + + pub fn svtr_v2_student_ch() -> Self { + Self::svtr().with_model_file("v2-distill-student-ch.onnx") + } +} diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs new file mode 100644 index 0000000..f14f5f8 --- /dev/null +++ b/src/models/svtr/impl.rs @@ -0,0 +1,109 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; + +use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct SVTR { + engine: Engine, + height: usize, + width: usize, + batch: usize, + confs: DynConf, + spec: String, + ts: Ts, + processor: Processor, +} + +impl SVTR { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&960.into()).opt(), + engine.try_width().unwrap_or(&960.into()).opt(), + engine.ts.clone(), + ); + let spec = options.model_spec().to_string(); + let 
confs = DynConf::new(options.class_confs(), 1); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + if processor.vocab().is_empty() { + anyhow::bail!("No vocab file found") + } + + Ok(Self { + engine, + height, + width, + batch, + confs, + processor, + spec, + ts, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn postprocess(&self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for batch in xs[0].axis_iter(Axis(0)) { + let preds = batch + .axis_iter(Axis(0)) + .filter_map(|x| { + x.into_iter() + .enumerate() + .max_by(|(_, x), (_, y)| x.total_cmp(y)) + }) + .collect::>(); + + let text = preds + .iter() + .enumerate() + .fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| { + if *text_id == 0 || confidence < self.confs[0] { + return text_ids; + } + + if idx == 0 || idx == self.processor.vocab().len() - 1 { + return text_ids; + } + + if *text_id != preds[idx - 1].0 { + text_ids.push(*text_id); + } + text_ids + }) + .into_iter() + .map(|idx| self.processor.vocab()[idx].to_owned()) + .collect::(); + + ys.push(Y::default().with_texts(&[text.into()])) + } + + Ok(ys.into()) + } +} diff --git a/src/models/svtr/mod.rs b/src/models/svtr/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/svtr/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/trocr/README.md b/src/models/trocr/README.md new file mode 100644 index 0000000..d6e7870 
--- /dev/null +++ b/src/models/trocr/README.md @@ -0,0 +1,9 @@ +# TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models + +## Official Repository + +The official repository can be found on: [Hugging Face](https://huggingface.co/microsoft/trocr-base-printed) + +## Example + +Refer to the [example](../../../examples/trocr) diff --git a/src/models/trocr/config.rs b/src/models/trocr/config.rs new file mode 100644 index 0000000..8343434 --- /dev/null +++ b/src/models/trocr/config.rs @@ -0,0 +1,92 @@ +use crate::Scale; + +/// Model configuration for `TrOCR` +impl crate::Options { + pub fn trocr() -> Self { + Self::default().with_model_name("trocr").with_batch_size(1) + } + + pub fn trocr_visual() -> Self { + Self::trocr() + .with_model_kind(crate::Kind::Vision) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 384.into()) + .with_model_ixx(0, 3, 384.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_resize_filter("Bilinear") + .with_normalize(true) + } + + pub fn trocr_textual() -> Self { + Self::trocr().with_model_kind(crate::Kind::Language) + } + + pub fn trocr_visual_small() -> Self { + Self::trocr_visual().with_model_scale(Scale::S) + } + + pub fn trocr_textual_small() -> Self { + Self::trocr_textual() + .with_model_scale(Scale::S) + .with_tokenizer_file("trocr/tokenizer-small.json") + } + + pub fn trocr_visual_base() -> Self { + Self::trocr_visual().with_model_scale(Scale::B) + } + + pub fn trocr_textual_base() -> Self { + Self::trocr_textual() + .with_model_scale(Scale::B) + .with_tokenizer_file("trocr/tokenizer-base.json") + } + + pub fn trocr_encoder_small_printed() -> Self { + Self::trocr_visual_small().with_model_file("s-encoder-printed.onnx") + } + + pub fn trocr_decoder_small_printed() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-printed.onnx") + } + + pub fn trocr_decoder_merged_small_printed() -> Self { + 
Self::trocr_textual_small().with_model_file("s-decoder-merged-printed.onnx") + } + + pub fn trocr_encoder_small_handwritten() -> Self { + Self::trocr_visual_small().with_model_file("s-encoder-handwritten.onnx") + } + + pub fn trocr_decoder_small_handwritten() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-handwritten.onnx") + } + + pub fn trocr_decoder_merged_small_handwritten() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-merged-handwritten.onnx") + } + + pub fn trocr_encoder_base_printed() -> Self { + Self::trocr_visual_base().with_model_file("b-encoder-printed.onnx") + } + + pub fn trocr_decoder_base_printed() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-printed.onnx") + } + + pub fn trocr_decoder_merged_base_printed() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-merged-printed.onnx") + } + + pub fn trocr_encoder_base_handwritten() -> Self { + Self::trocr_visual_base().with_model_file("b-encoder-handwritten.onnx") + } + + pub fn trocr_decoder_base_handwritten() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-handwritten.onnx") + } + + pub fn trocr_decoder_merged_base_handwritten() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-merged-handwritten.onnx") + } +} diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs new file mode 100644 index 0000000..c6894fd --- /dev/null +++ b/src/models/trocr/impl.rs @@ -0,0 +1,351 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use rayon::prelude::*; + +use crate::{ + elapsed, + models::{BaseModelTextual, BaseModelVisual}, + Options, Scale, Ts, Xs, Ys, X, Y, +}; + +#[derive(Debug, Copy, Clone)] +pub enum TrOCRKind { + Printed, + HandWritten, +} + +impl TryFrom<&str> for TrOCRKind { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "printed" => Ok(Self::Printed), + "handwritten" | "hand-written" => Ok(Self::HandWritten), + 
x => anyhow::bail!("Unsupported TrOCRKind: {}", x), + } + } +} + +#[derive(Debug, Builder)] +pub struct TrOCR { + encoder: BaseModelVisual, + decoder: BaseModelTextual, + decoder_merged: BaseModelTextual, + max_length: u32, + eos_token_id: u32, + decoder_start_token_id: u32, + ts: Ts, + n_kvs: usize, +} + +impl TrOCR { + pub fn summary(&self) { + self.ts.summary(); + } + + pub fn new( + options_encoder: Options, + options_decoder: Options, + options_decoder_merged: Options, + ) -> Result { + let encoder = BaseModelVisual::new(options_encoder)?; + let decoder = BaseModelTextual::new(options_decoder)?; + let decoder_merged = BaseModelTextual::new(options_decoder_merged)?; + let ts = Ts::merge(&[ + encoder.engine().ts(), + decoder.engine().ts(), + decoder_merged.engine().ts(), + ]); + + // "bos_token": "", "eos_token": "", "sep_token": "", + // "model_max_length": 1000000000000000019884624838656, + // let bos_token = ""; + // let eos_token = ""; + // let sep_token = ""; + // let bos_token_id = 0; + // let pad_token_id = 1; + let max_length = 1024; // TODO + let eos_token_id = 2; + let decoder_start_token_id = 2; + let n_kvs = match decoder.scale() { + Some(Scale::S) => 6, + Some(Scale::B) => 12, + _ => unimplemented!(), + }; + + Ok(Self { + encoder, + decoder, + decoder_merged, + max_length, + ts, + eos_token_id, + decoder_start_token_id, + n_kvs, + }) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let encoder_hidden_states = elapsed!("encode", self.ts, { self.encoder.encode(xs)? }); + let generated = elapsed!("generate", self.ts, { + self.generate(&encoder_hidden_states)? + }); + let ys = elapsed!("decode", self.ts, { self.decode(generated)? }); + + Ok(ys) + } + + // fn generate(&mut self, encoder_hidden_states: &X) -> Result>> { + // // input_ids + // let input_ids = X::from(vec![self.decoder_start_token_id as f32]) + // .insert_axis(0)? 
+ // .repeat(0, self.encoder.batch())?; + + // // decoder + // let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + // input_ids.clone(), + // encoder_hidden_states.clone(), + // ]))?; + + // // encoder kvs + // let encoder_kvs: Vec<_> = (3..4 * self.n_kvs) + // .step_by(4) + // .flat_map(|i| [i, i + 1]) + // .map(|i| decoder_outputs[i].clone()) + // .collect(); + + // // token ids + // let mut token_ids: Vec> = vec![vec![]; self.encoder.batch()]; + // let mut finished = vec![false; self.encoder.batch()]; + // let mut last_tokens: Vec = vec![0.; self.encoder.batch()]; + // let mut logits_sampler = LogitsSampler::new(); + + // // generate + // for _ in 0..self.max_length { + // let logits = &decoder_outputs[0]; + // let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2) + // .step_by(4) + // .flat_map(|i| [i, i + 1]) + // .map(|i| decoder_outputs[i].clone()) + // .collect(); + + // // decode each token for each batch + // for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { + // if !finished[i] { + // let token_id = logits_sampler.decode( + // &logit + // .slice(s![-1, ..]) + // .into_owned() + // .into_raw_vec_and_offset() + // .0, + // )?; + + // if token_id == self.eos_token_id { + // finished[i] = true; + // } else { + // token_ids[i].push(token_id); + // } + + // // update + // last_tokens[i] = token_id as f32; + // } + // } + + // // all finished? 
+ // if finished.iter().all(|&x| x) { + // break; + // } + + // // build inputs + // let input_ids = X::from(last_tokens.clone()).insert_axis(1)?; + // let mut xs = vec![input_ids, encoder_hidden_states.clone()]; + // for i in 0..self.n_kvs { + // xs.push(decoder_kvs[i * 2].clone()); + // xs.push(decoder_kvs[i * 2 + 1].clone()); + // xs.push(encoder_kvs[i * 2].clone()); + // xs.push(encoder_kvs[i * 2 + 1].clone()); + // } + // xs.push(X::ones(&[1])); // use_cache + + // // generate + // decoder_outputs = self.decoder_merged.inference(xs.into())?; + // } + + // Ok(token_ids) + // } + + fn generate(&mut self, encoder_hidden_states: &X) -> Result>> { + // input_ids + let input_ids = X::from(vec![self.decoder_start_token_id as f32]) + .insert_axis(0)? + .repeat(0, self.encoder.batch())?; + + // decoder + let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + input_ids.clone(), + encoder_hidden_states.clone(), + ]))?; + + // encoder kvs + let encoder_kvs: Vec<_> = (3..4 * self.n_kvs) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // token ids + let mut token_ids: Vec> = vec![vec![]; self.encoder.batch()]; + + // generate + for _ in 0..self.max_length { + let logits = &decoder_outputs[0]; + let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // decode each token for each batch + let (finished, last_tokens) = self.decoder_merged.processor().par_generate( + logits, + &mut token_ids, + self.eos_token_id, + )?; + + if finished { + break; + } + + // build inputs + let input_ids = X::from(last_tokens).insert_axis(1)?; + let mut xs = vec![input_ids, encoder_hidden_states.clone()]; + for i in 0..self.n_kvs { + xs.push(decoder_kvs[i * 2].clone()); + xs.push(decoder_kvs[i * 2 + 1].clone()); + xs.push(encoder_kvs[i * 2].clone()); + xs.push(encoder_kvs[i * 2 + 1].clone()); + } + xs.push(X::ones(&[1])); // use_cache + + // 
generate + decoder_outputs = self.decoder_merged.inference(xs.into())?; + } + + Ok(token_ids) + } + + pub fn decode(&self, token_ids: Vec>) -> Result { + // decode + let texts = self + .decoder_merged + .processor() + .decode_tokens_batch(&token_ids, false)?; + + // to texts + let texts = texts + .into_par_iter() + .map(|x| Y::default().with_texts(&[x.into()])) + .collect::>() + .into(); + + Ok(texts) + } +} + +// #[derive(Debug, Builder)] +// pub struct TrOCREncoder { +// // TODO: `BaseVisualEncoder`, `BaseVisualModel` struct? +// engine: Engine, +// height: usize, +// width: usize, +// batch: usize, +// processor: Processor, +// } + +// impl TrOCREncoder { +// pub fn new(options: Options) -> Result { +// let engine = options.to_engine()?; +// let (batch, height, width) = ( +// engine.batch().opt(), +// engine.try_height().unwrap_or(&384.into()).opt(), +// engine.try_width().unwrap_or(&384.into()).opt(), +// ); +// let processor = options +// .to_processor()? +// .with_image_width(width as _) +// .with_image_height(height as _); + +// Ok(Self { +// engine, +// height, +// width, +// batch, +// processor, +// }) +// } + +// pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { +// self.batch = xs.len(); // TODO +// let x = self.processor.process_images(xs)?; + +// Ok(x.into()) +// } + +// pub fn inference(&mut self, xs: Xs) -> Result { +// self.engine.run(xs) +// } + +// fn encode(&mut self, xs: &[DynamicImage]) -> Result { +// // encode a batch of images into one embedding, that's `X` +// let xs = self.preprocess(xs)?; +// let xs = self.inference(xs)?; +// let x = xs[0].to_owned(); + +// Ok(x) +// } +// } + +// #[derive(Debug, Builder)] +// pub struct TrOCRDecoder { +// engine: Engine, +// batch: usize, +// } + +// impl TrOCRDecoder { +// pub fn new(options: Options) -> Result { +// let engine = options.to_engine()?; +// let batch = engine.batch().opt(); + +// Ok(Self { engine, batch }) +// } + +// pub fn inference(&mut self, xs: Xs) -> Result { +// 
self.engine.run(xs) +// } +// } + +// #[derive(Debug, Builder)] +// pub struct TrOCRDecoderMerged { +// engine: Engine, +// batch: usize, +// processor: Processor, +// } + +// impl TrOCRDecoderMerged { +// pub fn new(options: Options) -> Result { +// let engine = options.to_engine()?; +// let batch = engine.batch().opt(); +// let processor = options.to_processor()?; + +// Ok(Self { +// engine, +// batch, +// processor, +// }) +// } + +// pub fn inference(&mut self, xs: Xs) -> Result { +// self.engine.run(xs) +// } +// } diff --git a/src/models/trocr/mod.rs b/src/models/trocr/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/trocr/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/yolo/README.md b/src/models/yolo/README.md new file mode 100644 index 0000000..576f14b --- /dev/null +++ b/src/models/yolo/README.md @@ -0,0 +1,41 @@ +# YOLO: You Only Look Once + +## Official Repository + +The official repository can be found on: +- [YOLO Series Intro](https://docs.ultralytics.com/models/) +- [YOLOv5](https://github.com/ultralytics/yolov5) +- [YOLOv6](https://github.com/meituan/YOLOv6) +- [YOLOv7](https://github.com/WongKinYiu/yolov7) +- [YOLOv8, YOLO11](https://github.com/ultralytics/ultralytics) +- [YOLOv9](https://github.com/WongKinYiu/yolov9) +- [YOLOv10](https://github.com/THU-MIG/yolov10) + + +## Example + +Refer to the [example](../../../examples/yolo) + + +## TODO +- [x] YOLOv5-det +- [x] YOLOv5-cls +- [x] YOLOv5-seg +- [x] YOLOv6 +- [x] YOLOv7 +- [x] YOLOv8-det +- [x] YOLOv8-cls +- [x] YOLOv8-pose +- [x] YOLOv8-seg +- [x] YOLOv8-obb +- [x] YOLOv8-world +- [x] YOLOv8-rtdetr +- [x] YOLOv9 +- [x] YOLOv10 +- [x] YOLO11-det +- [x] YOLO11-cls +- [x] YOLO11-pose +- [x] YOLO11-seg +- [x] YOLO11-obb +- [x] FastSam +- [ ] YOLO-NAS diff --git a/src/models/yolo/config.rs b/src/models/yolo/config.rs new file mode 100644 index 0000000..d50dcc4 --- /dev/null +++ b/src/models/yolo/config.rs @@ -0,0 
+1,196 @@ +use crate::{models::YOLOPredsFormat, Options, ResizeMode, Scale, Task, COCO_KEYPOINTS_NAMES_17}; + +impl Options { + pub fn yolo() -> Self { + Self::default() + .with_model_name("yolo") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(ResizeMode::FitAdaptive) + .with_resize_filter("CatmullRom") + .with_find_contours(true) + } + + pub fn doclayout_yolo_docstructbench() -> Self { + Self::yolo_v10() + .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1 + .with_model_ixx(0, 2, (640, 1024, 1024).into()) + .with_model_ixx(0, 3, (640, 1024, 1024).into()) + .with_class_confs(&[0.4]) + .with_class_names(&[ + "title", + "plain text", + "abandon", + "figure", + "figure_caption", + "table", + "table_caption", + "table_footnote", + "isolate_formula", + "formula_caption", + ]) + } + + pub fn yolo_classify() -> Self { + Self::yolo() + .with_model_task(Task::ImageClassification) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_resize_mode(ResizeMode::FitExact) + .with_resize_filter("Bilinear") + } + + pub fn yolo_detect() -> Self { + Self::yolo().with_model_task(Task::ObjectDetection) + } + + pub fn yolo_pose() -> Self { + Self::yolo() + .with_model_task(Task::KeypointsDetection) + .with_keypoint_names(&COCO_KEYPOINTS_NAMES_17) + } + + pub fn yolo_segment() -> Self { + Self::yolo().with_model_task(Task::InstanceSegmentation) + } + + pub fn yolo_obb() -> Self { + Self::yolo().with_model_task(Task::OrientedObjectDetection) + } + + pub fn fastsam_s() -> Self { + Self::yolo_segment() + .with_model_scale(Scale::S) + .with_model_version(8.0.into()) + .with_model_file("FastSAM-s.onnx") + } + + pub fn yolo_v8_rtdetr() -> Self { + Self::yolo() + .with_model_version(7.0.into()) + .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + } + + pub fn yolo_v8_rtdetr_l() -> Self { + Self::yolo_v8_rtdetr() + 
.with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + .with_model_scale(Scale::L) + .with_model_file("rtdetr-l-det.onnx") + } + + pub fn yolo_v8_rtdetr_x() -> Self { + Self::yolo_v8_rtdetr() + .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + .with_model_scale(Scale::X) + } + + pub fn yolo_n() -> Self { + Self::yolo().with_model_scale(Scale::N) + } + + pub fn yolo_s() -> Self { + Self::yolo().with_model_scale(Scale::S) + } + + pub fn yolo_m() -> Self { + Self::yolo().with_model_scale(Scale::M) + } + + pub fn yolo_l() -> Self { + Self::yolo().with_model_scale(Scale::L) + } + + pub fn yolo_x() -> Self { + Self::yolo().with_model_scale(Scale::X) + } + + pub fn yolo_v5() -> Self { + Self::yolo().with_model_version(5.0.into()) + } + + pub fn yolo_v6() -> Self { + Self::yolo().with_model_version(6.0.into()) + } + + pub fn yolo_v7() -> Self { + Self::yolo().with_model_version(7.0.into()) + } + + pub fn yolo_v8() -> Self { + Self::yolo().with_model_version(8.0.into()) + } + + pub fn yolo_v9() -> Self { + Self::yolo().with_model_version(9.0.into()) + } + + pub fn yolo_v10() -> Self { + Self::yolo().with_model_version(10.0.into()) + } + + pub fn yolo_v11() -> Self { + Self::yolo().with_model_version(11.0.into()) + } + + pub fn yolo_v8_n() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::N) + } + + pub fn yolo_v8_s() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::S) + } + + pub fn yolo_v8_m() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::M) + } + + pub fn yolo_v8_l() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::L) + } + + pub fn yolo_v8_x() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::X) + } + + pub fn yolo_v11_n() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::N) + } + + pub fn yolo_v11_s() -> Self { + Self::yolo() + 
.with_model_version(11.0.into()) + .with_model_scale(Scale::S) + } + + pub fn yolo_v11_m() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::M) + } + + pub fn yolo_v11_l() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::L) + } + + pub fn yolo_v11_x() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::X) + } +} diff --git a/src/models/yolo.rs b/src/models/yolo/impl.rs similarity index 58% rename from src/models/yolo.rs rename to src/models/yolo/impl.rs index 3288950..396b602 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo/impl.rs @@ -1,224 +1,296 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; +use log::{error, info}; use ndarray::{s, Array, Axis}; use rayon::prelude::*; use regex::Regex; use crate::{ - Bbox, BoxType, DynConf, Keypoint, Mask, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Prob, - Vision, Xs, YOLOPreds, YOLOTask, YOLOVersion, X, Y, + elapsed, + models::{BoxType, YOLOPredsFormat}, + Bbox, DynConf, Engine, Keypoint, Mask, Mbr, Ops, Options, Polygon, Prob, Processor, Task, Ts, + Version, Xs, Ys, Y, }; -#[derive(Debug)] +#[derive(Debug, Builder)] pub struct YOLO { - engine: OrtEngine, + engine: Engine, + height: usize, + width: usize, + batch: usize, + layout: YOLOPredsFormat, + task: Task, + version: Option, + names: Vec, + names_kpt: Vec, nc: usize, nk: usize, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, confs: DynConf, kconfs: DynConf, iou: f32, - names: Vec, - names_kpt: Vec, - task: YOLOTask, - layout: YOLOPreds, find_contours: bool, - version: Option, - classes_excluded: Vec, - classes_retained: Vec, + processor: Processor, + ts: Ts, + spec: String, + classes_excluded: Vec, + classes_retained: Vec, } -impl Vision for YOLO { - type Input = DynamicImage; +impl TryFrom for YOLO { + type Error = anyhow::Error; - fn new(options: Options) -> Result { - let span = tracing::span!(tracing::Level::INFO, 
"YOLO-new"); - let _guard = span.enter(); + fn try_from(options: Options) -> Result { + Self::new(options) + } +} - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), +impl YOLO { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts, spec) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&640.into()).opt(), + engine.try_width().unwrap_or(&640.into()).opt(), + engine.ts.clone(), + engine.spec().to_owned(), ); - - // YOLO Task - let task = options - .yolo_task - .or(engine.try_fetch("task").and_then(|x| match x.as_str() { - "classify" => Some(YOLOTask::Classify), - "detect" => Some(YOLOTask::Detect), - "pose" => Some(YOLOTask::Pose), - "segment" => Some(YOLOTask::Segment), - "obb" => Some(YOLOTask::Obb), - s => { - tracing::error!("YOLO Task: {s:?} is unsupported"); - None - } - })); - - // YOLO Outputs Format - let (version, layout) = match options.yolo_version { - Some(ver) => match &task { - None => anyhow::bail!("No clear YOLO Task specified for Version: {ver:?}."), - Some(task) => match task { - YOLOTask::Classify => match ver { - YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)), - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()), - x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. 
Try using `.with_yolo_preds()` for customization.") - } - YOLOTask::Detect => match ver { - YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()), - YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()), - YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)), - YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)), - } - YOLOTask::Pose => match ver { - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()), - x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let task: Option = match options.model_task { + Some(task) => Some(task), + None => match engine.try_fetch("task") { + Some(x) => match x.as_str() { + "classify" => Some(Task::ImageClassification), + "detect" => Some(Task::ObjectDetection), + "pose" => Some(Task::KeypointsDetection), + "segment" => Some(Task::InstanceSegmentation), + "obb" => Some(Task::OrientedObjectDetection), + x => { + error!("Unsupported YOLO Task: {}", x); + None } - YOLOTask::Segment => match ver { - YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()), - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()), - x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") - } - YOLOTask::Obb => match ver { - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()), - x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. 
Try using `.with_yolo_preds()` for customization.") + }, + None => None, + }, + }; + + // Task & layout + let version = options.model_version; + let (layout, task) = match &options.yolo_preds_format { + // customized + Some(layout) => { + // check task + let task_parsed = layout.task(); + let task = match task { + Some(task) => { + if task_parsed != task { + anyhow::bail!( + "Task specified: {:?} is inconsistent with parsed from yolo_preds_format: {:?}", + task, + task_parsed + ); + } + task_parsed } - } - } - None => match options.yolo_preds { - None => anyhow::bail!("No clear YOLO version or YOLO Format specified."), - Some(fmt) => (None, fmt) + None => task_parsed, + }; + + (layout.clone(), task) } - }; - let task = task.unwrap_or(layout.task()); + // version + task + None => match (task, version) { + (Some(task), Some(version)) => { + let layout = match (task, version) { + (Task::ImageClassification, Version(5, 0)) => { + YOLOPredsFormat::n_clss().apply_softmax(true) + } + (Task::ImageClassification, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_clss() + } + (Task::ObjectDetection, Version(5, 0) | Version(6, 0) | Version(7, 0)) => { + YOLOPredsFormat::n_a_cxcywh_confclss() + } + (Task::ObjectDetection, Version(8, 0) | Version(9, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_a() + } + (Task::ObjectDetection, Version(10, 0)) => { + YOLOPredsFormat::n_a_xyxy_confcls().apply_nms(false) + } + (Task::KeypointsDetection, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_xycs_a() + } + (Task::InstanceSegmentation, Version(5, 0)) => { + YOLOPredsFormat::n_a_cxcywh_confclss_coefs() + } + (Task::InstanceSegmentation, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_coefs_a() + } + (Task::OrientedObjectDetection, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_r_a() + } + (task, version) => { + anyhow::bail!("Task: {:?} is unsupported for Version: {:?}. 
Try using `.with_yolo_preds()` for customization.", task, version) + } + }; + + (layout, task) + } + (None, Some(version)) => { + let layout = match version { + // single task, no need to specified task + Version(6, 0) | Version(7, 0) => YOLOPredsFormat::n_a_cxcywh_confclss(), + Version(9, 0) => YOLOPredsFormat::n_cxcywh_clss_a(), + Version(10, 0) => YOLOPredsFormat::n_a_xyxy_confcls().apply_nms(false), + _ => { + anyhow::bail!( + "No clear YOLO Task specified for Version: {:?}.", + version + ) + } + }; + + (layout, Task::ObjectDetection) + } + (Some(task), None) => { + anyhow::bail!("No clear YOLO Version specified for Task: {:?}.", task) + } + (None, None) => { + anyhow::bail!("No clear YOLO Task and Version specified.") + } + }, + }; - // Class names: user-defined.or(parsed) - let names_parsed = Self::fetch_names(&engine); - let names = match names_parsed { - Some(names_parsed) => match options.names { + // Class names + let names: Option> = match Self::fetch_names_from_onnx(&engine) { + Some(names_parsed) => match &options.class_names { Some(names) => { if names.len() == names_parsed.len() { - Some(names) + // prioritize user-defined + Some(names.clone()) } else { + // Fail to override anyhow::bail!( "The lengths of parsed class names: {} and user-defined class names: {} do not match.", names_parsed.len(), names.len(), - ); + ) } } None => Some(names_parsed), }, - None => options.names, + None => options.class_names.clone(), }; - // nc: names.len().or(options.nc) - let nc = match &names { - Some(names) => names.len(), - None => match options.nc { - Some(nc) => nc, - None => anyhow::bail!( - "Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`." 
- ), + // Class names & Number of class + let (nc, names) = match (options.nc(), names) { + (_, Some(names)) => (names.len(), names.to_vec()), + (Some(nc), None) => (nc, Self::n2s(nc)), + (None, None) => { + anyhow::bail!( + "Neither class names nor the number of classes were specified. \ + \nConsider specify them with `Options::default().with_nc()` or `Options::default().with_class_names()`" + ); } }; - // Class names - let names = match names { - None => Self::n2s(nc), - Some(names) => names, - }; - - // Keypoint names & nk - let (nk, names_kpt) = match Self::fetch_kpts(&engine) { - None => (0, vec![]), - Some(nk) => match options.names2 { - Some(names) => { - if names.len() == nk { - (nk, names) - } else { + // Keypoint names & Number of keypoints + let (nk, names_kpt) = if let Task::KeypointsDetection = task { + let nk = Self::fetch_nk_from_onnx(&engine).or(options.nk()); + match (&options.keypoint_names, nk) { + (Some(names), Some(nk)) => { + if names.len() != nk { anyhow::bail!( - "The lengths of user-defined keypoint names: {} and nk: {} do not match.", + "The lengths of user-defined keypoint names: {} and nk parsed: {} do not match.", names.len(), nk, ); } + (nk, names.clone()) } - None => (nk, Self::n2s(nk)), - }, + (Some(names), None) => (names.len(), names.clone()), + (None, Some(nk)) => (nk, Self::n2s(nk)), + (None, None) => anyhow::bail!( + "Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. 
\ + \nConsider specify them with `Options::default().with_nk()` or `Options::default().with_keypoint_names()`" + ), + } + } else { + (0, vec![]) }; - // Confs & Iou - let confs = DynConf::new(&options.confs, nc); - let kconfs = DynConf::new(&options.kconfs, nk); - let iou = options.iou.unwrap_or(0.45); - - // Classes excluded and retained - let classes_excluded = options.classes_excluded; - let classes_retained = options.classes_retained; - - // Summary - tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version); - - // dry run - engine.dry_run()?; + // Attributes + let confs = DynConf::new(options.class_confs(), nc); + let kconfs = DynConf::new(options.keypoint_confs(), nk); + let iou = options.iou().unwrap_or(0.45); + let classes_excluded = options.classes_excluded().to_vec(); + let classes_retained = options.classes_retained().to_vec(); + let find_contours = options.find_contours(); + let mut info = format!( + "YOLO Version: {}, Task: {:?}, Category Count: {}, Keypoint Count: {}", + version.map_or("Unknown".into(), |x| x.to_string()), + task, + nc, + nk, + ); + if !classes_excluded.is_empty() { + info = format!("{}, classes_excluded: {:?}", info, classes_excluded); + } + if !classes_retained.is_empty() { + info = format!("{}, classes_retained: {:?}", info, classes_retained); + } + info!("{}", info); Ok(Self { engine, - confs, - kconfs, - iou, - nc, - nk, height, width, batch, task, + version, + spec, + layout, names, names_kpt, - layout, - version, - find_contours: options.find_contours, + confs, + kconfs, + iou, + nc, + nk, + find_contours, classes_excluded, classes_retained, + processor, + ts, }) } - fn preprocess(&self, xs: &[Self::Input]) -> Result { - let xs_ = match self.task { - YOLOTask::Classify => { - X::resize(xs, self.height() as u32, self.width() as u32, "Bilinear")? - .normalize(0., 255.)? - .nhwc2nchw()? 
- } - _ => X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "CatmullRom", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?, - }; - Ok(Xs::from(xs_)) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) } fn inference(&mut self, xs: Xs) -> Result { self.engine.run(xs) } - fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> Result> { + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&self, xs: Xs) -> Result { let protos = if xs.len() == 2 { Some(&xs[1]) } else { None }; let ys: Vec = xs[0] .axis_iter(Axis(0)) @@ -227,7 +299,7 @@ impl Vision for YOLO { .filter_map(|(idx, preds)| { let mut y = Y::default(); - // parse preditions + // Parse predictions let ( slice_bboxes, slice_id, @@ -238,8 +310,8 @@ impl Vision for YOLO { slice_radians, ) = self.layout.parse_preds(preds, self.nc); - // Classifcation - if let YOLOTask::Classify = self.task { + // ImageClassifcation + if let Task::ImageClassification = self.task { let x = if self.layout.apply_softmax { let exps = slice_clss.mapv(|x| x.exp()); let stds = exps.sum_axis(Axis(0)); @@ -247,17 +319,16 @@ impl Vision for YOLO { } else { slice_clss.into_owned() }; - let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0); - probs = probs + let probs = Prob::default() + .with_probs(&x.into_raw_vec_and_offset().0) .with_names(&self.names.iter().map(|x| x.as_str()).collect::>()); - return Some(y.with_probs(&probs)); + return Some(y.with_probs(probs)); } - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let ratio = - 
(self.width() as f32 / image_width).min(self.height() as f32 / image_height); + // Original image size + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; // Other tasks let (y_bboxes, y_mbrs) = slice_bboxes? @@ -284,19 +355,21 @@ impl Vision for YOLO { } }; - // filtering by class id + // filter out class id if !self.classes_excluded.is_empty() - && self.classes_excluded.contains(&(class_id as isize)) + && self.classes_excluded.contains(&class_id) { return None; } + + // filter by class id if !self.classes_retained.is_empty() - && !self.classes_retained.contains(&(class_id as isize)) + && !self.classes_retained.contains(&class_id) { return None; } - // filtering by conf + // filter by conf if confidence < self.confs[class_id] { return None; } @@ -354,8 +427,7 @@ impl Vision for YOLO { (h, w, radians + std::f32::consts::PI / 2.) }; let radians = radians % std::f32::consts::PI; - - let mut mbr = Mbr::from_cxcywhr( + let mbr = Mbr::from_cxcywhr( cx as f64, cy as f64, w as f64, @@ -363,18 +435,18 @@ impl Vision for YOLO { radians as f64, ) .with_confidence(confidence) - .with_id(class_id as isize); - mbr = mbr.with_name(&self.names[class_id]); + .with_id(class_id as isize) + .with_name(&self.names[class_id]); (None, Some(mbr)) } None => { - let mut bbox = Bbox::default() + let bbox = Bbox::default() .with_xywh(x, y, w, h) .with_confidence(confidence) .with_id(class_id as isize) - .with_id_born(i as isize); - bbox = bbox.with_name(&self.names[class_id]); + .with_id_born(i as isize) + .with_name(&self.names[class_id]); (Some(bbox), None) } @@ -404,7 +476,7 @@ impl Vision for YOLO { } } - // Pose + // KeypointsDetection if let Some(pred_kpts) = slice_kpts { let kpt_step = self.layout.kpt_step().unwrap_or(3); if let Some(bboxes) = y.bboxes() { @@ -421,16 +493,14 @@ impl Vision for YOLO { if kconf < self.kconfs[i] { Keypoint::default() } else { - let mut kpt = Keypoint::default() + Keypoint::default() 
.with_id(i as isize) .with_confidence(kconf) .with_xy( - kx.max(0.0f32).min(image_width), - ky.max(0.0f32).min(image_height), - ); - - kpt = kpt.with_name(&self.names_kpt[i]); - kpt + kx.max(0.0f32).min(image_width as f32), + ky.max(0.0f32).min(image_height as f32), + ) + .with_name(&self.names_kpt[i]) } }) .collect::>(); @@ -441,7 +511,7 @@ impl Vision for YOLO { } } - // Segment + // InstanceSegmentation if let Some(coefs) = slice_coefs { if let Some(bboxes) = y.bboxes() { let (y_polygons, y_masks) = bboxes @@ -533,54 +603,26 @@ impl Vision for YOLO { }) .collect(); - Ok(ys) - } -} - -impl YOLO { - pub fn batch(&self) -> usize { - self.batch.opt() - } - - pub fn width(&self) -> usize { - self.width.opt() - } - - pub fn height(&self) -> usize { - self.height.opt() + Ok(ys.into()) } - pub fn version(&self) -> Option<&YOLOVersion> { - self.version.as_ref() - } - - pub fn task(&self) -> &YOLOTask { - &self.task - } - - pub fn layout(&self) -> &YOLOPreds { - &self.layout - } - - fn fetch_names(engine: &OrtEngine) -> Option> { + fn fetch_names_from_onnx(engine: &Engine) -> Option> { // fetch class names from onnx metadata // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}` - engine.try_fetch("names").map(|names| { - let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap(); - let mut names_ = vec![]; - for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) { - names_.push(name.to_string()); - } - names_ - }) + Regex::new(r#"(['"])([-()\w '"]+)(['"])"#) + .ok()? + .captures_iter(&engine.try_fetch("names")?) 
+ .filter_map(|caps| caps.get(2).map(|m| m.as_str().to_string())) + .collect::>() + .into() } - fn fetch_kpts(engine: &OrtEngine) -> Option { - engine.try_fetch("kpt_shape").map(|s| { - let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap(); - let caps = re.captures(&s).unwrap(); - caps.get(1).unwrap().as_str().parse::().unwrap() - }) + fn fetch_nk_from_onnx(engine: &Engine) -> Option { + Regex::new(r"(\d+), \d+") + .ok()? + .captures(&engine.try_fetch("kpt_shape")?) + .and_then(|caps| caps.get(1)) + .and_then(|m| m.as_str().parse::().ok()) } fn n2s(n: usize) -> Vec { diff --git a/src/models/yolo/mod.rs b/src/models/yolo/mod.rs new file mode 100644 index 0000000..a7ca057 --- /dev/null +++ b/src/models/yolo/mod.rs @@ -0,0 +1,6 @@ +mod config; +mod r#impl; +mod preds; + +pub use preds::*; +pub use r#impl::*; diff --git a/src/models/yolo_.rs b/src/models/yolo/preds.rs similarity index 70% rename from src/models/yolo_.rs rename to src/models/yolo/preds.rs index 994dd3a..12029bd 100644 --- a/src/models/yolo_.rs +++ b/src/models/yolo/preds.rs @@ -1,107 +1,13 @@ use ndarray::{ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr}; -#[derive(Debug, Clone, clap::ValueEnum)] -pub enum YOLOTask { - Classify, - Detect, - Pose, - Segment, - Obb, -} - -impl YOLOTask { - pub fn name(&self) -> String { - match self { - Self::Classify => "cls".to_string(), - Self::Detect => "det".to_string(), - Self::Pose => "pose".to_string(), - Self::Segment => "seg".to_string(), - Self::Obb => "obb".to_string(), - } - } - - pub fn name_detailed(&self) -> String { - match self { - Self::Classify => "image-classification".to_string(), - Self::Detect => "object-detection".to_string(), - Self::Pose => "pose-estimation".to_string(), - Self::Segment => "instance-segment".to_string(), - Self::Obb => "oriented-object-detection".to_string(), - } - } -} - -#[derive(Debug, Copy, Clone, clap::ValueEnum)] -pub enum YOLOVersion { - V5, - V6, - V7, - V8, - V9, - V10, - V11, - RTDETR, -} - -impl YOLOVersion 
{ - pub fn name(&self) -> String { - match self { - Self::V5 => "v5".to_string(), - Self::V6 => "v6".to_string(), - Self::V7 => "v7".to_string(), - Self::V8 => "v8".to_string(), - Self::V9 => "v9".to_string(), - Self::V10 => "v10".to_string(), - Self::V11 => "v11".to_string(), - Self::RTDETR => "rtdetr".to_string(), - } - } -} - -#[derive(Debug, Copy, Clone, clap::ValueEnum)] -pub enum YOLOScale { - N, - T, - B, - S, - M, - L, - C, - E, - X, -} - -impl YOLOScale { - pub fn name(&self) -> String { - match self { - Self::N => "n".to_string(), - Self::T => "t".to_string(), - Self::S => "s".to_string(), - Self::B => "b".to_string(), - Self::M => "m".to_string(), - Self::L => "l".to_string(), - Self::C => "c".to_string(), - Self::E => "e".to_string(), - Self::X => "x".to_string(), - } - } -} +use crate::Task; #[derive(Debug, Clone, PartialEq)] pub enum BoxType { - /// 1 Cxcywh, - - /// 2 Cxcybr Cxcyxy, - - /// 3 Tlbr Xyxy, - - /// 4 Tlwh Xywh, - - /// 5 Tlcxcy XyCxcy, } @@ -127,7 +33,7 @@ pub enum AnchorsPosition { } #[derive(Debug, Clone, PartialEq)] -pub struct YOLOPreds { +pub struct YOLOPredsFormat { pub clss: ClssType, pub bbox: Option, pub kpts: Option, @@ -137,9 +43,11 @@ pub struct YOLOPreds { pub is_bbox_normalized: bool, pub apply_nms: bool, pub apply_softmax: bool, + // ------------------------------------------------ + // pub is_concatenated: bool, // TODO: how to tell which parts? 
} -impl Default for YOLOPreds { +impl Default for YOLOPredsFormat { fn default() -> Self { Self { clss: ClssType::Clss, @@ -151,11 +59,12 @@ impl Default for YOLOPreds { is_bbox_normalized: false, apply_nms: true, apply_softmax: false, + // is_concatenated: true, } } } -impl YOLOPreds { +impl YOLOPredsFormat { pub fn apply_nms(mut self, x: bool) -> Self { self.apply_nms = x; self @@ -259,16 +168,16 @@ impl YOLOPreds { } } - pub fn task(&self) -> YOLOTask { + pub fn task(&self) -> Task { match self.obb { - Some(_) => YOLOTask::Obb, + Some(_) => Task::OrientedObjectDetection, None => match self.coefs { - Some(_) => YOLOTask::Segment, + Some(_) => Task::InstanceSegmentation, None => match self.kpts { - Some(_) => YOLOTask::Pose, + Some(_) => Task::KeypointsDetection, None => match self.bbox { - Some(_) => YOLOTask::Detect, - None => YOLOTask::Classify, + Some(_) => Task::ObjectDetection, + None => Task::ImageClassification, }, }, }, @@ -318,16 +227,16 @@ impl YOLOPreds { x: ArrayBase, Dim>, nc: usize, ) -> ( - Option>, - Option>, - ArrayView, - Option>, - Option>, - Option>, - Option>, + Option>, + Option>, + ArrayView<'a, f32, IxDyn>, + Option>, + Option>, + Option>, + Option>, ) { match self.task() { - YOLOTask::Classify => (None, None, x, None, None, None, None), + Task::ImageClassification => (None, None, x, None, None, None, None), _ => { let x = if self.is_anchors_first() { x @@ -335,7 +244,7 @@ impl YOLOPreds { x.reversed_axes() }; - // get each tasks slices + // each tasks slices let (slice_bboxes, xs) = x.split_at(Axis(1), 4); let (slice_id, slice_clss, slice_confs, xs) = match self.clss { ClssType::ConfClss => { @@ -364,9 +273,9 @@ impl YOLOPreds { } }; let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - YOLOTask::Pose => (Some(xs), None, None), - YOLOTask::Segment => (None, Some(xs), None), - YOLOTask::Obb => (None, None, Some(xs)), + Task::Pose | Task::KeypointsDetection => (Some(xs), None, None), + Task::InstanceSegmentation => (None, 
Some(xs), None), + Task::Obb | Task::OrientedObjectDetection => (None, None, Some(xs)), _ => (None, None, None), }; diff --git a/src/models/yolop/README.md b/src/models/yolop/README.md new file mode 100644 index 0000000..b2074a7 --- /dev/null +++ b/src/models/yolop/README.md @@ -0,0 +1,9 @@ +# YOLOPv2:rocket:: Better, Faster, Stronger for Panoptic driving Perception + +## Official Repository + +The official repository can be found on: [GitHub](https://github.com/CAIC-AD/YOLOPv2) + +## Example + +Refer to the [example](../../../examples/yolop) diff --git a/src/models/yolop/config.rs b/src/models/yolop/config.rs new file mode 100644 index 0000000..6e1564e --- /dev/null +++ b/src/models/yolop/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `YOLOP` +impl crate::Options { + pub fn yolop() -> Self { + Self::default() + .with_model_name("yolop") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("Bilinear") + .with_normalize(true) + .with_class_confs(&[0.3]) + } + + pub fn yolop_v2_480x800() -> Self { + Self::yolop().with_model_file("v2-480x800.onnx") + } + + pub fn yolop_v2_736x1280() -> Self { + Self::yolop().with_model_file("v2-736x1280.onnx") + } +} diff --git a/src/models/yolop.rs b/src/models/yolop/impl.rs similarity index 69% rename from src/models/yolop.rs rename to src/models/yolop/impl.rs index 05adbba..0e734ae 100644 --- a/src/models/yolop.rs +++ b/src/models/yolop/impl.rs @@ -1,61 +1,75 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, Array, Axis, IxDyn}; -use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, Bbox, DynConf, Engine, Ops, Options, Polygon, Processor, Ts, Xs, Ys, Y}; -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct YOLOPv2 { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, + engine: 
Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, confs: DynConf, iou: f32, } impl YOLOPv2 { pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), ); - let confs = DynConf::new(&options.kconfs, 80); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + let confs = DynConf::new(options.class_confs(), 80); let iou = options.iou.unwrap_or(0.45f32); - engine.dry_run()?; Ok(Self { engine, - confs, height, width, batch, + confs, iou, + ts, + processor, + spec, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "Bilinear", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) } - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - // pub fn postprocess(&self, xs: Vec, xs0: &[DynamicImage]) -> Result> { + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { let mut ys: Vec = Vec::new(); let (xs_da, xs_ll, xs_det) = (&xs[0], &xs[1], &xs[2]); for (idx, ((x_det, x_ll), x_da)) in xs_det @@ -64,14 +78,8 @@ impl YOLOPv2 { .zip(xs_da.axis_iter(Axis(0))) .enumerate() { - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let (ratio, _, _) = Ops::scale_wh( - image_width, - image_height, - self.width() as f32, - self.height() as f32, - ); + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; // Vehicle let mut y_bboxes = Vec::new(); @@ -94,8 +102,8 @@ impl YOLOPv2 { let h = bbox[3] / ratio; let x = cx - w / 2.; let y = cy - h / 2.; - let x = x.max(0.0).min(image_width); - let y = y.max(0.0).min(image_height); + let x = x.max(0.0).min(image_width as _); + let y = y.max(0.0).min(image_height as _); y_bboxes.push( Bbox::default() .with_xywh(x, y, w, h) @@ -112,10 +120,10 @@ impl YOLOPv2 { let contours = match self.get_contours_from_mask( x_da.into_dyn(), 0.0, - self.width() as _, - self.height() as _, - image_width, - image_height, + self.width as _, + self.height as _, + image_width as _, + image_height as _, ) { Err(_) => continue, Ok(x) => x, @@ -138,10 +146,10 @@ impl YOLOPv2 { let contours = match self.get_contours_from_mask( x_ll.to_owned(), 0.5, - self.width() as _, - self.height() as _, - image_width, - image_height, + self.width as _, + self.height as _, + image_width as _, + image_height as _, ) { Err(_) => continue, Ok(x) => x, @@ -168,19 +176,7 @@ impl YOLOPv2 { .apply_nms(self.iou), ); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } fn get_contours_from_mask( diff --git a/src/models/yolop/mod.rs 
b/src/models/yolop/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/yolop/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/utils/colormap256.rs b/src/utils/colormap256.rs deleted file mode 100644 index c9a5c28..0000000 --- a/src/utils/colormap256.rs +++ /dev/null @@ -1,2591 +0,0 @@ -//! Some colormap: [`TURBO`], [`INFERNO`], [`PLASMA`], [`VIRIDIS`], [`MAGMA`], [`BENTCOOLWARM`], [`BLACKBODY`], [`EXTENDEDKINDLMANN`], [`KINDLMANN`], [`SMOOTHCOOLWARM`]. - -pub const TURBO: [[u8; 3]; 256] = [ - [48, 18, 59], - [50, 21, 67], - [51, 24, 74], - [52, 27, 81], - [53, 30, 88], - [54, 33, 95], - [55, 36, 102], - [56, 39, 109], - [57, 42, 115], - [58, 45, 121], - [59, 47, 128], - [60, 50, 134], - [61, 53, 139], - [62, 56, 145], - [63, 59, 151], - [63, 62, 156], - [64, 64, 162], - [65, 67, 167], - [65, 70, 172], - [66, 73, 177], - [66, 75, 181], - [67, 78, 186], - [68, 81, 191], - [68, 84, 195], - [68, 86, 199], - [69, 89, 203], - [69, 92, 207], - [69, 94, 211], - [70, 97, 214], - [70, 100, 218], - [70, 102, 221], - [70, 105, 224], - [70, 107, 227], - [71, 110, 230], - [71, 113, 233], - [71, 115, 235], - [71, 118, 238], - [71, 120, 240], - [71, 123, 242], - [70, 125, 244], - [70, 128, 246], - [70, 130, 248], - [70, 133, 250], - [70, 135, 251], - [69, 138, 252], - [69, 140, 253], - [68, 143, 254], - [67, 145, 254], - [66, 148, 255], - [65, 150, 255], - [64, 153, 255], - [62, 155, 254], - [61, 158, 254], - [59, 160, 253], - [58, 163, 252], - [56, 165, 251], - [55, 168, 250], - [53, 171, 248], - [51, 173, 247], - [49, 175, 245], - [47, 178, 244], - [46, 180, 242], - [44, 183, 240], - [42, 185, 238], - [40, 188, 235], - [39, 190, 233], - [37, 192, 231], - [35, 195, 228], - [34, 197, 226], - [32, 199, 223], - [31, 201, 221], - [30, 203, 218], - [28, 205, 216], - [27, 208, 213], - [26, 210, 210], - [26, 212, 208], - [25, 213, 205], - [24, 215, 202], - [24, 217, 200], - [24, 219, 197], - [24, 221, 194], - [24, 222, 
192], - [24, 224, 189], - [25, 226, 187], - [25, 227, 185], - [26, 228, 182], - [28, 230, 180], - [29, 231, 178], - [31, 233, 175], - [32, 234, 172], - [34, 235, 170], - [37, 236, 167], - [39, 238, 164], - [42, 239, 161], - [44, 240, 158], - [47, 241, 155], - [50, 242, 152], - [53, 243, 148], - [56, 244, 145], - [60, 245, 142], - [63, 246, 138], - [67, 247, 135], - [70, 248, 132], - [74, 248, 128], - [78, 249, 125], - [82, 250, 122], - [85, 250, 118], - [89, 251, 115], - [93, 252, 111], - [97, 252, 108], - [101, 253, 105], - [105, 253, 102], - [109, 254, 98], - [113, 254, 95], - [117, 254, 92], - [121, 254, 89], - [125, 255, 86], - [128, 255, 83], - [132, 255, 81], - [136, 255, 78], - [139, 255, 75], - [143, 255, 73], - [146, 255, 71], - [150, 254, 68], - [153, 254, 66], - [156, 254, 64], - [159, 253, 63], - [161, 253, 61], - [164, 252, 60], - [167, 252, 58], - [169, 251, 57], - [172, 251, 56], - [175, 250, 55], - [177, 249, 54], - [180, 248, 54], - [183, 247, 53], - [185, 246, 53], - [188, 245, 52], - [190, 244, 52], - [193, 243, 52], - [195, 241, 52], - [198, 240, 52], - [200, 239, 52], - [203, 237, 52], - [205, 236, 52], - [208, 234, 52], - [210, 233, 53], - [212, 231, 53], - [215, 229, 53], - [217, 228, 54], - [219, 226, 54], - [221, 224, 55], - [223, 223, 55], - [225, 221, 55], - [227, 219, 56], - [229, 217, 56], - [231, 215, 57], - [233, 213, 57], - [235, 211, 57], - [236, 209, 58], - [238, 207, 58], - [239, 205, 58], - [241, 203, 58], - [242, 201, 58], - [244, 199, 58], - [245, 197, 58], - [246, 195, 58], - [247, 193, 58], - [248, 190, 57], - [249, 188, 57], - [250, 186, 57], - [251, 184, 56], - [251, 182, 55], - [252, 179, 54], - [252, 177, 54], - [253, 174, 53], - [253, 172, 52], - [254, 169, 51], - [254, 167, 50], - [254, 164, 49], - [254, 161, 48], - [254, 158, 47], - [254, 155, 45], - [254, 153, 44], - [254, 150, 43], - [254, 147, 42], - [254, 144, 41], - [253, 141, 39], - [253, 138, 38], - [252, 135, 37], - [252, 132, 35], - [251, 129, 34], - [251, 
126, 33], - [250, 123, 31], - [249, 120, 30], - [249, 117, 29], - [248, 114, 28], - [247, 111, 26], - [246, 108, 25], - [245, 105, 24], - [244, 102, 23], - [243, 99, 21], - [242, 96, 20], - [241, 93, 19], - [240, 91, 18], - [239, 88, 17], - [237, 85, 16], - [236, 83, 15], - [235, 80, 14], - [234, 78, 13], - [232, 75, 12], - [231, 73, 12], - [229, 71, 11], - [228, 69, 10], - [226, 67, 10], - [225, 65, 9], - [223, 63, 8], - [221, 61, 8], - [220, 59, 7], - [218, 57, 7], - [216, 55, 6], - [214, 53, 6], - [212, 51, 5], - [210, 49, 5], - [208, 47, 5], - [206, 45, 4], - [204, 43, 4], - [202, 42, 4], - [200, 40, 3], - [197, 38, 3], - [195, 37, 3], - [193, 35, 2], - [190, 33, 2], - [188, 32, 2], - [185, 30, 2], - [183, 29, 2], - [180, 27, 1], - [178, 26, 1], - [175, 24, 1], - [172, 23, 1], - [169, 22, 1], - [167, 20, 1], - [164, 19, 1], - [161, 18, 1], - [158, 16, 1], - [155, 15, 1], - [152, 14, 1], - [149, 13, 1], - [146, 11, 1], - [142, 10, 1], - [139, 9, 2], - [136, 8, 2], - [133, 7, 2], - [129, 6, 2], - [126, 5, 2], - [122, 4, 3], -]; - -pub const INFERNO: [[u8; 3]; 256] = [ - [0, 0, 4], - [1, 0, 5], - [1, 1, 6], - [1, 1, 8], - [2, 1, 10], - [2, 2, 12], - [2, 2, 14], - [3, 2, 16], - [4, 3, 18], - [4, 3, 20], - [5, 4, 23], - [6, 4, 25], - [7, 5, 27], - [8, 5, 29], - [9, 6, 31], - [10, 7, 34], - [11, 7, 36], - [12, 8, 38], - [13, 8, 41], - [14, 9, 43], - [16, 9, 45], - [17, 10, 48], - [18, 10, 50], - [20, 11, 52], - [21, 11, 55], - [22, 11, 57], - [24, 12, 60], - [25, 12, 62], - [27, 12, 65], - [28, 12, 67], - [30, 12, 69], - [31, 12, 72], - [33, 12, 74], - [35, 12, 76], - [36, 12, 79], - [38, 12, 81], - [40, 11, 83], - [41, 11, 85], - [43, 11, 87], - [45, 11, 89], - [47, 10, 91], - [49, 10, 92], - [50, 10, 94], - [52, 10, 95], - [54, 9, 97], - [56, 9, 98], - [57, 9, 99], - [59, 9, 100], - [61, 9, 101], - [62, 9, 102], - [64, 10, 103], - [66, 10, 104], - [68, 10, 104], - [69, 10, 105], - [71, 11, 106], - [73, 11, 106], - [74, 12, 107], - [76, 12, 107], - [77, 13, 108], - 
[79, 13, 108], - [81, 14, 108], - [82, 14, 109], - [84, 15, 109], - [85, 15, 109], - [87, 16, 110], - [89, 16, 110], - [90, 17, 110], - [92, 18, 110], - [93, 18, 110], - [95, 19, 110], - [97, 19, 110], - [98, 20, 110], - [100, 21, 110], - [101, 21, 110], - [103, 22, 110], - [105, 22, 110], - [106, 23, 110], - [108, 24, 110], - [109, 24, 110], - [111, 25, 110], - [113, 25, 110], - [114, 26, 110], - [116, 26, 110], - [117, 27, 110], - [119, 28, 109], - [120, 28, 109], - [122, 29, 109], - [124, 29, 109], - [125, 30, 109], - [127, 30, 108], - [128, 31, 108], - [130, 32, 108], - [132, 32, 107], - [133, 33, 107], - [135, 33, 107], - [136, 34, 106], - [138, 34, 106], - [140, 35, 105], - [141, 35, 105], - [143, 36, 105], - [144, 37, 104], - [146, 37, 104], - [147, 38, 103], - [149, 38, 103], - [151, 39, 102], - [152, 39, 102], - [154, 40, 101], - [155, 41, 100], - [157, 41, 100], - [159, 42, 99], - [160, 42, 99], - [162, 43, 98], - [163, 44, 97], - [165, 44, 96], - [166, 45, 96], - [168, 46, 95], - [169, 46, 94], - [171, 47, 94], - [173, 48, 93], - [174, 48, 92], - [176, 49, 91], - [177, 50, 90], - [179, 50, 90], - [180, 51, 89], - [182, 52, 88], - [183, 53, 87], - [185, 53, 86], - [186, 54, 85], - [188, 55, 84], - [189, 56, 83], - [191, 57, 82], - [192, 58, 81], - [193, 58, 80], - [195, 59, 79], - [196, 60, 78], - [198, 61, 77], - [199, 62, 76], - [200, 63, 75], - [202, 64, 74], - [203, 65, 73], - [204, 66, 72], - [206, 67, 71], - [207, 68, 70], - [208, 69, 69], - [210, 70, 68], - [211, 71, 67], - [212, 72, 66], - [213, 74, 65], - [215, 75, 63], - [216, 76, 62], - [217, 77, 61], - [218, 78, 60], - [219, 80, 59], - [221, 81, 58], - [222, 82, 56], - [223, 83, 55], - [224, 85, 54], - [225, 86, 53], - [226, 87, 52], - [227, 89, 51], - [228, 90, 49], - [229, 92, 48], - [230, 93, 47], - [231, 94, 46], - [232, 96, 45], - [233, 97, 43], - [234, 99, 42], - [235, 100, 41], - [235, 102, 40], - [236, 103, 38], - [237, 105, 37], - [238, 106, 36], - [239, 108, 35], - [239, 110, 33], - 
[240, 111, 32], - [241, 113, 31], - [241, 115, 29], - [242, 116, 28], - [243, 118, 27], - [243, 120, 25], - [244, 121, 24], - [245, 123, 23], - [245, 125, 21], - [246, 126, 20], - [246, 128, 19], - [247, 130, 18], - [247, 132, 16], - [248, 133, 15], - [248, 135, 14], - [248, 137, 12], - [249, 139, 11], - [249, 140, 10], - [249, 142, 9], - [250, 144, 8], - [250, 146, 7], - [250, 148, 7], - [251, 150, 6], - [251, 151, 6], - [251, 153, 6], - [251, 155, 6], - [251, 157, 7], - [252, 159, 7], - [252, 161, 8], - [252, 163, 9], - [252, 165, 10], - [252, 166, 12], - [252, 168, 13], - [252, 170, 15], - [252, 172, 17], - [252, 174, 18], - [252, 176, 20], - [252, 178, 22], - [252, 180, 24], - [251, 182, 26], - [251, 184, 29], - [251, 186, 31], - [251, 188, 33], - [251, 190, 35], - [250, 192, 38], - [250, 194, 40], - [250, 196, 42], - [250, 198, 45], - [249, 199, 47], - [249, 201, 50], - [249, 203, 53], - [248, 205, 55], - [248, 207, 58], - [247, 209, 61], - [247, 211, 64], - [246, 213, 67], - [246, 215, 70], - [245, 217, 73], - [245, 219, 76], - [244, 221, 79], - [244, 223, 83], - [244, 225, 86], - [243, 227, 90], - [243, 229, 93], - [242, 230, 97], - [242, 232, 101], - [242, 234, 105], - [241, 236, 109], - [241, 237, 113], - [241, 239, 117], - [241, 241, 121], - [242, 242, 125], - [242, 244, 130], - [243, 245, 134], - [243, 246, 138], - [244, 248, 142], - [245, 249, 146], - [246, 250, 150], - [248, 251, 154], - [249, 252, 157], - [250, 253, 161], - [252, 255, 164], -]; - -pub const PLASMA: [[u8; 3]; 256] = [ - [13, 8, 135], - [16, 7, 136], - [19, 7, 137], - [22, 7, 138], - [25, 6, 140], - [27, 6, 141], - [29, 6, 142], - [32, 6, 143], - [34, 6, 144], - [36, 6, 145], - [38, 5, 145], - [40, 5, 146], - [42, 5, 147], - [44, 5, 148], - [46, 5, 149], - [47, 5, 150], - [49, 5, 151], - [51, 5, 151], - [53, 4, 152], - [55, 4, 153], - [56, 4, 154], - [58, 4, 154], - [60, 4, 155], - [62, 4, 156], - [63, 4, 156], - [65, 4, 157], - [67, 3, 158], - [68, 3, 158], - [70, 3, 159], - [72, 3, 
159], - [73, 3, 160], - [75, 3, 161], - [76, 2, 161], - [78, 2, 162], - [80, 2, 162], - [81, 2, 163], - [83, 2, 163], - [85, 2, 164], - [86, 1, 164], - [88, 1, 164], - [89, 1, 165], - [91, 1, 165], - [92, 1, 166], - [94, 1, 166], - [96, 1, 166], - [97, 0, 167], - [99, 0, 167], - [100, 0, 167], - [102, 0, 167], - [103, 0, 168], - [105, 0, 168], - [106, 0, 168], - [108, 0, 168], - [110, 0, 168], - [111, 0, 168], - [113, 0, 168], - [114, 1, 168], - [116, 1, 168], - [117, 1, 168], - [119, 1, 168], - [120, 1, 168], - [122, 2, 168], - [123, 2, 168], - [125, 3, 168], - [126, 3, 168], - [128, 4, 168], - [129, 4, 167], - [131, 5, 167], - [132, 5, 167], - [134, 6, 166], - [135, 7, 166], - [136, 8, 166], - [138, 9, 165], - [139, 10, 165], - [141, 11, 165], - [142, 12, 164], - [143, 13, 164], - [145, 14, 163], - [146, 15, 163], - [148, 16, 162], - [149, 17, 161], - [150, 19, 161], - [152, 20, 160], - [153, 21, 159], - [154, 22, 159], - [156, 23, 158], - [157, 24, 157], - [158, 25, 157], - [160, 26, 156], - [161, 27, 155], - [162, 29, 154], - [163, 30, 154], - [165, 31, 153], - [166, 32, 152], - [167, 33, 151], - [168, 34, 150], - [170, 35, 149], - [171, 36, 148], - [172, 38, 148], - [173, 39, 147], - [174, 40, 146], - [176, 41, 145], - [177, 42, 144], - [178, 43, 143], - [179, 44, 142], - [180, 46, 141], - [181, 47, 140], - [182, 48, 139], - [183, 49, 138], - [184, 50, 137], - [186, 51, 136], - [187, 52, 136], - [188, 53, 135], - [189, 55, 134], - [190, 56, 133], - [191, 57, 132], - [192, 58, 131], - [193, 59, 130], - [194, 60, 129], - [195, 61, 128], - [196, 62, 127], - [197, 64, 126], - [198, 65, 125], - [199, 66, 124], - [200, 67, 123], - [201, 68, 122], - [202, 69, 122], - [203, 70, 121], - [204, 71, 120], - [204, 73, 119], - [205, 74, 118], - [206, 75, 117], - [207, 76, 116], - [208, 77, 115], - [209, 78, 114], - [210, 79, 113], - [211, 81, 113], - [212, 82, 112], - [213, 83, 111], - [213, 84, 110], - [214, 85, 109], - [215, 86, 108], - [216, 87, 107], - [217, 88, 106], - 
[218, 90, 106], - [218, 91, 105], - [219, 92, 104], - [220, 93, 103], - [221, 94, 102], - [222, 95, 101], - [222, 97, 100], - [223, 98, 99], - [224, 99, 99], - [225, 100, 98], - [226, 101, 97], - [226, 102, 96], - [227, 104, 95], - [228, 105, 94], - [229, 106, 93], - [229, 107, 93], - [230, 108, 92], - [231, 110, 91], - [231, 111, 90], - [232, 112, 89], - [233, 113, 88], - [233, 114, 87], - [234, 116, 87], - [235, 117, 86], - [235, 118, 85], - [236, 119, 84], - [237, 121, 83], - [237, 122, 82], - [238, 123, 81], - [239, 124, 81], - [239, 126, 80], - [240, 127, 79], - [240, 128, 78], - [241, 129, 77], - [241, 131, 76], - [242, 132, 75], - [243, 133, 75], - [243, 135, 74], - [244, 136, 73], - [244, 137, 72], - [245, 139, 71], - [245, 140, 70], - [246, 141, 69], - [246, 143, 68], - [247, 144, 68], - [247, 145, 67], - [247, 147, 66], - [248, 148, 65], - [248, 149, 64], - [249, 151, 63], - [249, 152, 62], - [249, 154, 62], - [250, 155, 61], - [250, 156, 60], - [250, 158, 59], - [251, 159, 58], - [251, 161, 57], - [251, 162, 56], - [252, 163, 56], - [252, 165, 55], - [252, 166, 54], - [252, 168, 53], - [252, 169, 52], - [253, 171, 51], - [253, 172, 51], - [253, 174, 50], - [253, 175, 49], - [253, 177, 48], - [253, 178, 47], - [253, 180, 47], - [253, 181, 46], - [254, 183, 45], - [254, 184, 44], - [254, 186, 44], - [254, 187, 43], - [254, 189, 42], - [254, 190, 42], - [254, 192, 41], - [253, 194, 41], - [253, 195, 40], - [253, 197, 39], - [253, 198, 39], - [253, 200, 39], - [253, 202, 38], - [253, 203, 38], - [252, 205, 37], - [252, 206, 37], - [252, 208, 37], - [252, 210, 37], - [251, 211, 36], - [251, 213, 36], - [251, 215, 36], - [250, 216, 36], - [250, 218, 36], - [249, 220, 36], - [249, 221, 37], - [248, 223, 37], - [248, 225, 37], - [247, 226, 37], - [247, 228, 37], - [246, 230, 38], - [246, 232, 38], - [245, 233, 38], - [245, 235, 39], - [244, 237, 39], - [243, 238, 39], - [243, 240, 39], - [242, 242, 39], - [241, 244, 38], - [241, 245, 37], - [240, 247, 36], - 
[240, 249, 33], -]; - -pub const VIRIDIS: [[u8; 3]; 256] = [ - [68, 1, 84], - [68, 2, 86], - [69, 4, 87], - [69, 5, 89], - [70, 7, 90], - [70, 8, 92], - [70, 10, 93], - [70, 11, 94], - [71, 13, 96], - [71, 14, 97], - [71, 16, 99], - [71, 17, 100], - [71, 19, 101], - [72, 20, 103], - [72, 22, 104], - [72, 23, 105], - [72, 24, 106], - [72, 26, 108], - [72, 27, 109], - [72, 28, 110], - [72, 29, 111], - [72, 31, 112], - [72, 32, 113], - [72, 33, 115], - [72, 35, 116], - [72, 36, 117], - [72, 37, 118], - [72, 38, 119], - [72, 40, 120], - [72, 41, 121], - [71, 42, 122], - [71, 44, 122], - [71, 45, 123], - [71, 46, 124], - [71, 47, 125], - [70, 48, 126], - [70, 50, 126], - [70, 51, 127], - [70, 52, 128], - [69, 53, 129], - [69, 55, 129], - [69, 56, 130], - [68, 57, 131], - [68, 58, 131], - [68, 59, 132], - [67, 61, 132], - [67, 62, 133], - [66, 63, 133], - [66, 64, 134], - [66, 65, 134], - [65, 66, 135], - [65, 68, 135], - [64, 69, 136], - [64, 70, 136], - [63, 71, 136], - [63, 72, 137], - [62, 73, 137], - [62, 74, 137], - [62, 76, 138], - [61, 77, 138], - [61, 78, 138], - [60, 79, 138], - [60, 80, 139], - [59, 81, 139], - [59, 82, 139], - [58, 83, 139], - [58, 84, 140], - [57, 85, 140], - [57, 86, 140], - [56, 88, 140], - [56, 89, 140], - [55, 90, 140], - [55, 91, 141], - [54, 92, 141], - [54, 93, 141], - [53, 94, 141], - [53, 95, 141], - [52, 96, 141], - [52, 97, 141], - [51, 98, 141], - [51, 99, 141], - [50, 100, 142], - [50, 101, 142], - [49, 102, 142], - [49, 103, 142], - [49, 104, 142], - [48, 105, 142], - [48, 106, 142], - [47, 107, 142], - [47, 108, 142], - [46, 109, 142], - [46, 110, 142], - [46, 111, 142], - [45, 112, 142], - [45, 113, 142], - [44, 113, 142], - [44, 114, 142], - [44, 115, 142], - [43, 116, 142], - [43, 117, 142], - [42, 118, 142], - [42, 119, 142], - [42, 120, 142], - [41, 121, 142], - [41, 122, 142], - [41, 123, 142], - [40, 124, 142], - [40, 125, 142], - [39, 126, 142], - [39, 127, 142], - [39, 128, 142], - [38, 129, 142], - [38, 130, 142], - 
[38, 130, 142], - [37, 131, 142], - [37, 132, 142], - [37, 133, 142], - [36, 134, 142], - [36, 135, 142], - [35, 136, 142], - [35, 137, 142], - [35, 138, 141], - [34, 139, 141], - [34, 140, 141], - [34, 141, 141], - [33, 142, 141], - [33, 143, 141], - [33, 144, 141], - [33, 145, 140], - [32, 146, 140], - [32, 146, 140], - [32, 147, 140], - [31, 148, 140], - [31, 149, 139], - [31, 150, 139], - [31, 151, 139], - [31, 152, 139], - [31, 153, 138], - [31, 154, 138], - [30, 155, 138], - [30, 156, 137], - [30, 157, 137], - [31, 158, 137], - [31, 159, 136], - [31, 160, 136], - [31, 161, 136], - [31, 161, 135], - [31, 162, 135], - [32, 163, 134], - [32, 164, 134], - [33, 165, 133], - [33, 166, 133], - [34, 167, 133], - [34, 168, 132], - [35, 169, 131], - [36, 170, 131], - [37, 171, 130], - [37, 172, 130], - [38, 173, 129], - [39, 173, 129], - [40, 174, 128], - [41, 175, 127], - [42, 176, 127], - [44, 177, 126], - [45, 178, 125], - [46, 179, 124], - [47, 180, 124], - [49, 181, 123], - [50, 182, 122], - [52, 182, 121], - [53, 183, 121], - [55, 184, 120], - [56, 185, 119], - [58, 186, 118], - [59, 187, 117], - [61, 188, 116], - [63, 188, 115], - [64, 189, 114], - [66, 190, 113], - [68, 191, 112], - [70, 192, 111], - [72, 193, 110], - [74, 193, 109], - [76, 194, 108], - [78, 195, 107], - [80, 196, 106], - [82, 197, 105], - [84, 197, 104], - [86, 198, 103], - [88, 199, 101], - [90, 200, 100], - [92, 200, 99], - [94, 201, 98], - [96, 202, 96], - [99, 203, 95], - [101, 203, 94], - [103, 204, 92], - [105, 205, 91], - [108, 205, 90], - [110, 206, 88], - [112, 207, 87], - [115, 208, 86], - [117, 208, 84], - [119, 209, 83], - [122, 209, 81], - [124, 210, 80], - [127, 211, 78], - [129, 211, 77], - [132, 212, 75], - [134, 213, 73], - [137, 213, 72], - [139, 214, 70], - [142, 214, 69], - [144, 215, 67], - [147, 215, 65], - [149, 216, 64], - [152, 216, 62], - [155, 217, 60], - [157, 217, 59], - [160, 218, 57], - [162, 218, 55], - [165, 219, 54], - [168, 219, 52], - [170, 220, 50], - [173, 
220, 48], - [176, 221, 47], - [178, 221, 45], - [181, 222, 43], - [184, 222, 41], - [186, 222, 40], - [189, 223, 38], - [192, 223, 37], - [194, 223, 35], - [197, 224, 33], - [200, 224, 32], - [202, 225, 31], - [205, 225, 29], - [208, 225, 28], - [210, 226, 27], - [213, 226, 26], - [216, 226, 25], - [218, 227, 25], - [221, 227, 24], - [223, 227, 24], - [226, 228, 24], - [229, 228, 25], - [231, 228, 25], - [234, 229, 26], - [236, 229, 27], - [239, 229, 28], - [241, 229, 29], - [244, 230, 30], - [246, 230, 32], - [248, 230, 33], - [251, 231, 35], - [253, 231, 37], -]; - -pub const MAGMA: [[u8; 3]; 256] = [ - [0, 0, 4], - [1, 0, 5], - [1, 1, 6], - [1, 1, 8], - [2, 1, 9], - [2, 2, 11], - [2, 2, 13], - [3, 3, 15], - [3, 3, 18], - [4, 4, 20], - [5, 4, 22], - [6, 5, 24], - [6, 5, 26], - [7, 6, 28], - [8, 7, 30], - [9, 7, 32], - [10, 8, 34], - [11, 9, 36], - [12, 9, 38], - [13, 10, 41], - [14, 11, 43], - [16, 11, 45], - [17, 12, 47], - [18, 13, 49], - [19, 13, 52], - [20, 14, 54], - [21, 14, 56], - [22, 15, 59], - [24, 15, 61], - [25, 16, 63], - [26, 16, 66], - [28, 16, 68], - [29, 17, 71], - [30, 17, 73], - [32, 17, 75], - [33, 17, 78], - [34, 17, 80], - [36, 18, 83], - [37, 18, 85], - [39, 18, 88], - [41, 17, 90], - [42, 17, 92], - [44, 17, 95], - [45, 17, 97], - [47, 17, 99], - [49, 17, 101], - [51, 16, 103], - [52, 16, 105], - [54, 16, 107], - [56, 16, 108], - [57, 15, 110], - [59, 15, 112], - [61, 15, 113], - [63, 15, 114], - [64, 15, 116], - [66, 15, 117], - [68, 15, 118], - [69, 16, 119], - [71, 16, 120], - [73, 16, 120], - [74, 16, 121], - [76, 17, 122], - [78, 17, 123], - [79, 18, 123], - [81, 18, 124], - [82, 19, 124], - [84, 19, 125], - [86, 20, 125], - [87, 21, 126], - [89, 21, 126], - [90, 22, 126], - [92, 22, 127], - [93, 23, 127], - [95, 24, 127], - [96, 24, 128], - [98, 25, 128], - [100, 26, 128], - [101, 26, 128], - [103, 27, 128], - [104, 28, 129], - [106, 28, 129], - [107, 29, 129], - [109, 29, 129], - [110, 30, 129], - [112, 31, 129], - [114, 31, 129], - 
[115, 32, 129], - [117, 33, 129], - [118, 33, 129], - [120, 34, 129], - [121, 34, 130], - [123, 35, 130], - [124, 35, 130], - [126, 36, 130], - [128, 37, 130], - [129, 37, 129], - [131, 38, 129], - [132, 38, 129], - [134, 39, 129], - [136, 39, 129], - [137, 40, 129], - [139, 41, 129], - [140, 41, 129], - [142, 42, 129], - [144, 42, 129], - [145, 43, 129], - [147, 43, 128], - [148, 44, 128], - [150, 44, 128], - [152, 45, 128], - [153, 45, 128], - [155, 46, 127], - [156, 46, 127], - [158, 47, 127], - [160, 47, 127], - [161, 48, 126], - [163, 48, 126], - [165, 49, 126], - [166, 49, 125], - [168, 50, 125], - [170, 51, 125], - [171, 51, 124], - [173, 52, 124], - [174, 52, 123], - [176, 53, 123], - [178, 53, 123], - [179, 54, 122], - [181, 54, 122], - [183, 55, 121], - [184, 55, 121], - [186, 56, 120], - [188, 57, 120], - [189, 57, 119], - [191, 58, 119], - [192, 58, 118], - [194, 59, 117], - [196, 60, 117], - [197, 60, 116], - [199, 61, 115], - [200, 62, 115], - [202, 62, 114], - [204, 63, 113], - [205, 64, 113], - [207, 64, 112], - [208, 65, 111], - [210, 66, 111], - [211, 67, 110], - [213, 68, 109], - [214, 69, 108], - [216, 69, 108], - [217, 70, 107], - [219, 71, 106], - [220, 72, 105], - [222, 73, 104], - [223, 74, 104], - [224, 76, 103], - [226, 77, 102], - [227, 78, 101], - [228, 79, 100], - [229, 80, 100], - [231, 82, 99], - [232, 83, 98], - [233, 84, 98], - [234, 86, 97], - [235, 87, 96], - [236, 88, 96], - [237, 90, 95], - [238, 91, 94], - [239, 93, 94], - [240, 95, 94], - [241, 96, 93], - [242, 98, 93], - [242, 100, 92], - [243, 101, 92], - [244, 103, 92], - [244, 105, 92], - [245, 107, 92], - [246, 108, 92], - [246, 110, 92], - [247, 112, 92], - [247, 114, 92], - [248, 116, 92], - [248, 118, 92], - [249, 120, 93], - [249, 121, 93], - [249, 123, 93], - [250, 125, 94], - [250, 127, 94], - [250, 129, 95], - [251, 131, 95], - [251, 133, 96], - [251, 135, 97], - [252, 137, 97], - [252, 138, 98], - [252, 140, 99], - [252, 142, 100], - [252, 144, 101], - [253, 146, 
102], - [253, 148, 103], - [253, 150, 104], - [253, 152, 105], - [253, 154, 106], - [253, 155, 107], - [254, 157, 108], - [254, 159, 109], - [254, 161, 110], - [254, 163, 111], - [254, 165, 113], - [254, 167, 114], - [254, 169, 115], - [254, 170, 116], - [254, 172, 118], - [254, 174, 119], - [254, 176, 120], - [254, 178, 122], - [254, 180, 123], - [254, 182, 124], - [254, 183, 126], - [254, 185, 127], - [254, 187, 129], - [254, 189, 130], - [254, 191, 132], - [254, 193, 133], - [254, 194, 135], - [254, 196, 136], - [254, 198, 138], - [254, 200, 140], - [254, 202, 141], - [254, 204, 143], - [254, 205, 144], - [254, 207, 146], - [254, 209, 148], - [254, 211, 149], - [254, 213, 151], - [254, 215, 153], - [254, 216, 154], - [253, 218, 156], - [253, 220, 158], - [253, 222, 160], - [253, 224, 161], - [253, 226, 163], - [253, 227, 165], - [253, 229, 167], - [253, 231, 169], - [253, 233, 170], - [253, 235, 172], - [252, 236, 174], - [252, 238, 176], - [252, 240, 178], - [252, 242, 180], - [252, 244, 182], - [252, 246, 184], - [252, 247, 185], - [252, 249, 187], - [252, 251, 189], - [252, 253, 191], -]; - -pub const BENTCOOLWARM: [[u8; 3]; 256] = [ - [59, 76, 192], - [60, 78, 193], - [61, 79, 194], - [62, 80, 194], - [63, 82, 195], - [64, 83, 196], - [65, 85, 196], - [66, 86, 197], - [67, 87, 197], - [68, 89, 198], - [69, 90, 199], - [70, 91, 199], - [71, 93, 200], - [72, 94, 200], - [73, 95, 201], - [74, 97, 202], - [75, 98, 202], - [76, 100, 203], - [77, 101, 203], - [78, 102, 204], - [79, 104, 204], - [81, 105, 205], - [82, 106, 206], - [83, 108, 206], - [84, 109, 207], - [85, 110, 207], - [86, 112, 208], - [87, 113, 208], - [89, 114, 209], - [90, 116, 209], - [91, 117, 210], - [92, 118, 210], - [93, 120, 211], - [95, 121, 211], - [96, 122, 212], - [97, 124, 212], - [98, 125, 213], - [100, 126, 213], - [101, 128, 214], - [102, 129, 214], - [103, 130, 214], - [105, 131, 215], - [106, 133, 215], - [107, 134, 216], - [109, 135, 216], - [110, 137, 217], - [111, 138, 217], - 
[113, 139, 218], - [114, 141, 218], - [115, 142, 218], - [117, 143, 219], - [118, 145, 219], - [119, 146, 220], - [121, 147, 220], - [122, 149, 220], - [123, 150, 221], - [125, 151, 221], - [126, 152, 222], - [128, 154, 222], - [129, 155, 222], - [130, 156, 223], - [132, 158, 223], - [133, 159, 223], - [135, 160, 224], - [136, 162, 224], - [138, 163, 225], - [139, 164, 225], - [140, 165, 225], - [142, 167, 226], - [143, 168, 226], - [145, 169, 226], - [146, 171, 227], - [148, 172, 227], - [149, 173, 227], - [151, 175, 228], - [152, 176, 228], - [154, 177, 228], - [155, 178, 229], - [157, 180, 229], - [159, 181, 229], - [160, 182, 230], - [162, 184, 230], - [163, 185, 230], - [165, 186, 230], - [166, 187, 231], - [168, 189, 231], - [170, 190, 231], - [171, 191, 232], - [173, 193, 232], - [174, 194, 232], - [176, 195, 232], - [178, 196, 233], - [179, 198, 233], - [181, 199, 233], - [183, 200, 234], - [184, 202, 234], - [186, 203, 234], - [188, 204, 234], - [189, 205, 235], - [191, 207, 235], - [193, 208, 235], - [194, 209, 236], - [196, 210, 236], - [198, 212, 236], - [200, 213, 236], - [201, 214, 237], - [203, 215, 237], - [205, 217, 237], - [207, 218, 237], - [208, 219, 238], - [210, 220, 238], - [212, 222, 238], - [214, 223, 238], - [215, 224, 239], - [217, 225, 239], - [219, 227, 239], - [221, 228, 239], - [223, 229, 240], - [225, 230, 240], - [226, 232, 240], - [228, 233, 240], - [230, 234, 241], - [232, 235, 241], - [234, 237, 241], - [236, 238, 241], - [238, 239, 242], - [239, 240, 242], - [241, 242, 242], - [242, 241, 241], - [242, 240, 239], - [241, 238, 237], - [241, 237, 235], - [241, 235, 232], - [241, 234, 230], - [240, 232, 228], - [240, 231, 226], - [240, 229, 224], - [239, 228, 222], - [239, 226, 219], - [239, 225, 217], - [238, 223, 215], - [238, 222, 213], - [238, 220, 211], - [237, 219, 209], - [237, 217, 207], - [237, 216, 205], - [236, 214, 203], - [236, 213, 201], - [236, 211, 199], - [235, 210, 196], - [235, 208, 194], - [235, 207, 192], - 
[234, 205, 190], - [234, 204, 188], - [233, 202, 186], - [233, 201, 184], - [233, 199, 182], - [232, 197, 180], - [232, 196, 179], - [232, 194, 177], - [231, 193, 175], - [231, 191, 173], - [230, 190, 171], - [230, 188, 169], - [230, 187, 167], - [229, 185, 165], - [229, 184, 163], - [228, 182, 161], - [228, 181, 159], - [228, 179, 158], - [227, 177, 156], - [227, 176, 154], - [226, 174, 152], - [226, 173, 150], - [226, 171, 148], - [225, 170, 147], - [225, 168, 145], - [224, 167, 143], - [224, 165, 141], - [223, 163, 140], - [223, 162, 138], - [223, 160, 136], - [222, 159, 134], - [222, 157, 133], - [221, 156, 131], - [221, 154, 129], - [220, 152, 128], - [220, 151, 126], - [219, 149, 124], - [219, 148, 123], - [218, 146, 121], - [218, 144, 119], - [217, 143, 118], - [217, 141, 116], - [217, 140, 114], - [216, 138, 113], - [216, 136, 111], - [215, 135, 110], - [215, 133, 108], - [214, 132, 107], - [214, 130, 105], - [213, 128, 104], - [212, 127, 102], - [212, 125, 101], - [211, 123, 99], - [211, 122, 98], - [210, 120, 96], - [210, 119, 95], - [209, 117, 93], - [209, 115, 92], - [208, 114, 90], - [208, 112, 89], - [207, 110, 88], - [207, 108, 86], - [206, 107, 85], - [205, 105, 83], - [205, 103, 82], - [204, 102, 81], - [204, 100, 79], - [203, 98, 78], - [203, 96, 77], - [202, 95, 75], - [201, 93, 74], - [201, 91, 73], - [200, 89, 72], - [200, 87, 70], - [199, 86, 69], - [198, 84, 68], - [198, 82, 67], - [197, 80, 65], - [197, 78, 64], - [196, 76, 63], - [195, 74, 62], - [195, 72, 61], - [194, 70, 60], - [193, 68, 58], - [193, 66, 57], - [192, 64, 56], - [192, 62, 55], - [191, 60, 54], - [190, 58, 53], - [190, 55, 52], - [189, 53, 51], - [188, 50, 50], - [188, 48, 49], - [187, 45, 48], - [186, 43, 47], - [186, 40, 46], - [185, 37, 45], - [184, 34, 44], - [184, 30, 43], - [183, 27, 42], - [182, 22, 41], - [181, 18, 40], - [181, 12, 39], - [180, 4, 38], -]; - -pub const BLACKBODY: [[u8; 3]; 256] = [ - [0, 0, 0], - [3, 1, 1], - [7, 2, 1], - [10, 3, 2], - [13, 4, 2], - 
[16, 5, 3], - [18, 6, 3], - [20, 7, 4], - [22, 8, 4], - [24, 9, 5], - [26, 10, 5], - [27, 11, 6], - [29, 11, 6], - [30, 12, 7], - [32, 13, 8], - [33, 14, 8], - [34, 15, 9], - [36, 15, 9], - [37, 16, 10], - [38, 16, 10], - [40, 17, 11], - [41, 17, 11], - [43, 18, 12], - [44, 18, 12], - [46, 18, 13], - [47, 19, 13], - [49, 19, 14], - [50, 19, 14], - [52, 20, 15], - [54, 20, 15], - [55, 20, 15], - [57, 21, 16], - [58, 21, 16], - [60, 21, 16], - [62, 22, 17], - [63, 22, 17], - [65, 22, 17], - [66, 23, 18], - [68, 23, 18], - [70, 23, 18], - [71, 24, 19], - [73, 24, 19], - [75, 24, 19], - [76, 25, 20], - [78, 25, 20], - [80, 25, 20], - [81, 25, 20], - [83, 26, 21], - [85, 26, 21], - [86, 26, 21], - [88, 26, 21], - [90, 27, 22], - [91, 27, 22], - [93, 27, 22], - [95, 27, 22], - [97, 28, 23], - [98, 28, 23], - [100, 28, 23], - [102, 28, 23], - [104, 29, 24], - [105, 29, 24], - [107, 29, 24], - [109, 29, 24], - [111, 29, 25], - [112, 30, 25], - [114, 30, 25], - [116, 30, 25], - [118, 30, 26], - [119, 30, 26], - [121, 31, 26], - [123, 31, 26], - [125, 31, 27], - [127, 31, 27], - [128, 31, 27], - [130, 31, 27], - [132, 32, 28], - [134, 32, 28], - [136, 32, 28], - [137, 32, 28], - [139, 32, 29], - [141, 32, 29], - [143, 32, 29], - [145, 33, 29], - [147, 33, 30], - [148, 33, 30], - [150, 33, 30], - [152, 33, 31], - [154, 33, 31], - [156, 33, 31], - [158, 33, 31], - [160, 33, 32], - [161, 34, 32], - [163, 34, 32], - [165, 34, 32], - [167, 34, 33], - [169, 34, 33], - [171, 34, 33], - [173, 34, 33], - [175, 34, 34], - [177, 34, 34], - [178, 34, 34], - [179, 36, 34], - [180, 38, 34], - [181, 40, 33], - [182, 42, 33], - [183, 44, 33], - [184, 45, 33], - [185, 47, 32], - [186, 49, 32], - [187, 50, 32], - [188, 52, 31], - [189, 53, 31], - [190, 55, 31], - [191, 56, 31], - [192, 58, 30], - [193, 59, 30], - [194, 61, 30], - [195, 62, 29], - [196, 64, 29], - [197, 65, 28], - [198, 66, 28], - [199, 68, 28], - [200, 69, 27], - [201, 71, 27], - [202, 72, 26], - [203, 73, 26], - [204, 75, 
25], - [205, 76, 25], - [206, 77, 24], - [207, 79, 24], - [208, 80, 23], - [209, 82, 23], - [210, 83, 22], - [211, 84, 21], - [212, 85, 21], - [213, 87, 20], - [214, 88, 19], - [215, 89, 19], - [216, 91, 18], - [217, 92, 17], - [218, 93, 16], - [219, 95, 15], - [220, 96, 14], - [221, 97, 13], - [222, 98, 12], - [223, 100, 11], - [224, 101, 9], - [225, 102, 8], - [226, 104, 7], - [227, 105, 5], - [227, 107, 5], - [227, 109, 6], - [228, 110, 7], - [228, 112, 7], - [228, 114, 8], - [228, 116, 8], - [229, 118, 9], - [229, 119, 10], - [229, 121, 10], - [229, 123, 11], - [229, 124, 12], - [230, 126, 12], - [230, 128, 13], - [230, 130, 14], - [230, 131, 14], - [230, 133, 15], - [230, 135, 15], - [231, 136, 16], - [231, 138, 17], - [231, 140, 17], - [231, 141, 18], - [231, 143, 19], - [231, 145, 19], - [231, 146, 20], - [232, 148, 21], - [232, 150, 21], - [232, 151, 22], - [232, 153, 22], - [232, 154, 23], - [232, 156, 24], - [232, 158, 24], - [232, 159, 25], - [232, 161, 26], - [232, 162, 26], - [233, 164, 27], - [233, 166, 27], - [233, 167, 28], - [233, 169, 29], - [233, 170, 29], - [233, 172, 30], - [233, 174, 30], - [233, 175, 31], - [233, 177, 32], - [233, 178, 32], - [233, 180, 33], - [233, 181, 34], - [233, 183, 34], - [233, 185, 35], - [233, 186, 35], - [233, 188, 36], - [233, 189, 37], - [233, 191, 37], - [233, 192, 38], - [233, 194, 38], - [233, 195, 39], - [233, 197, 40], - [233, 199, 40], - [233, 200, 41], - [232, 202, 42], - [232, 203, 42], - [232, 205, 43], - [232, 206, 43], - [232, 208, 44], - [232, 209, 45], - [232, 211, 45], - [232, 213, 46], - [232, 214, 47], - [232, 216, 47], - [231, 217, 48], - [231, 219, 48], - [231, 220, 49], - [231, 222, 50], - [231, 223, 50], - [231, 225, 51], - [230, 226, 52], - [230, 228, 52], - [230, 229, 53], - [231, 231, 60], - [233, 231, 69], - [234, 232, 78], - [236, 233, 87], - [237, 234, 94], - [238, 235, 102], - [240, 236, 109], - [241, 236, 117], - [242, 237, 124], - [243, 238, 131], - [245, 239, 137], - [246, 240, 144], 
- [247, 241, 151], - [248, 241, 158], - [249, 242, 164], - [249, 243, 171], - [250, 244, 177], - [251, 245, 184], - [252, 246, 190], - [252, 247, 197], - [253, 248, 203], - [253, 249, 210], - [254, 249, 216], - [254, 250, 223], - [254, 251, 229], - [255, 252, 236], - [255, 253, 242], - [255, 254, 249], - [255, 255, 255], -]; - -pub const EXTENDEDKINDLMANN: [[u8; 3]; 256] = [ - [0, 0, 0], - [5, 0, 4], - [9, 0, 9], - [13, 1, 13], - [16, 1, 17], - [19, 1, 21], - [22, 1, 24], - [24, 1, 27], - [26, 1, 30], - [27, 2, 34], - [28, 2, 38], - [29, 2, 42], - [29, 2, 46], - [30, 2, 50], - [30, 3, 53], - [30, 3, 57], - [30, 3, 61], - [29, 3, 65], - [29, 3, 68], - [28, 3, 72], - [27, 4, 75], - [27, 4, 79], - [26, 4, 82], - [25, 4, 85], - [24, 4, 88], - [23, 4, 92], - [22, 5, 95], - [21, 5, 98], - [20, 5, 101], - [19, 5, 103], - [18, 5, 106], - [18, 5, 109], - [17, 5, 111], - [14, 5, 115], - [8, 6, 119], - [6, 8, 120], - [6, 11, 120], - [6, 15, 119], - [6, 18, 118], - [6, 22, 116], - [5, 25, 114], - [5, 28, 112], - [5, 31, 109], - [5, 33, 107], - [5, 36, 104], - [5, 38, 101], - [5, 40, 99], - [5, 42, 96], - [5, 44, 94], - [4, 46, 91], - [4, 48, 89], - [4, 49, 87], - [4, 51, 85], - [4, 52, 83], - [4, 54, 81], - [4, 55, 79], - [4, 57, 77], - [4, 58, 76], - [4, 59, 74], - [3, 61, 73], - [3, 62, 71], - [3, 63, 70], - [3, 65, 69], - [3, 66, 67], - [3, 67, 66], - [3, 68, 65], - [3, 69, 64], - [3, 71, 63], - [3, 72, 61], - [4, 73, 60], - [4, 74, 58], - [4, 75, 56], - [4, 77, 55], - [4, 78, 53], - [4, 79, 51], - [4, 80, 49], - [4, 81, 47], - [4, 82, 45], - [4, 84, 43], - [4, 85, 41], - [4, 86, 39], - [4, 87, 36], - [4, 88, 34], - [4, 89, 31], - [4, 91, 29], - [4, 92, 26], - [5, 93, 24], - [5, 94, 21], - [5, 95, 18], - [5, 96, 15], - [5, 97, 13], - [5, 98, 10], - [5, 100, 8], - [5, 101, 6], - [6, 102, 5], - [8, 103, 5], - [10, 104, 5], - [11, 105, 5], - [13, 106, 5], - [15, 107, 5], - [17, 108, 5], - [20, 109, 5], - [22, 110, 5], - [26, 111, 5], - [29, 112, 5], - [32, 112, 5], - [36, 113, 
5], - [40, 114, 5], - [43, 115, 6], - [47, 116, 6], - [51, 116, 6], - [55, 117, 6], - [59, 118, 6], - [63, 118, 6], - [67, 119, 6], - [71, 119, 6], - [76, 120, 6], - [80, 120, 6], - [84, 121, 6], - [88, 121, 6], - [92, 122, 6], - [97, 122, 6], - [101, 122, 6], - [105, 123, 6], - [109, 123, 6], - [113, 123, 6], - [118, 123, 6], - [122, 123, 6], - [126, 123, 6], - [130, 123, 6], - [135, 123, 7], - [139, 123, 7], - [144, 123, 7], - [149, 122, 7], - [154, 122, 7], - [160, 121, 8], - [165, 120, 8], - [171, 119, 8], - [177, 118, 8], - [183, 117, 9], - [189, 115, 9], - [196, 113, 9], - [202, 111, 10], - [209, 108, 10], - [216, 105, 10], - [222, 102, 11], - [229, 98, 11], - [236, 94, 11], - [243, 90, 12], - [244, 91, 27], - [245, 92, 37], - [245, 94, 46], - [245, 96, 52], - [246, 98, 58], - [246, 99, 63], - [246, 101, 67], - [246, 103, 71], - [246, 105, 74], - [246, 107, 77], - [247, 109, 79], - [247, 111, 83], - [247, 112, 87], - [247, 114, 91], - [247, 115, 96], - [248, 117, 101], - [248, 118, 106], - [248, 120, 112], - [248, 121, 118], - [249, 122, 123], - [249, 123, 129], - [249, 125, 135], - [249, 126, 141], - [249, 127, 147], - [249, 129, 153], - [249, 130, 158], - [249, 131, 164], - [249, 132, 169], - [249, 133, 175], - [249, 134, 180], - [249, 135, 186], - [249, 137, 191], - [249, 138, 196], - [249, 139, 201], - [250, 140, 206], - [250, 141, 211], - [250, 142, 216], - [250, 143, 221], - [250, 144, 225], - [250, 145, 230], - [250, 146, 235], - [250, 147, 239], - [250, 148, 243], - [250, 149, 248], - [248, 152, 250], - [245, 155, 250], - [242, 159, 250], - [239, 162, 251], - [236, 165, 251], - [234, 168, 251], - [232, 171, 251], - [230, 173, 251], - [229, 176, 251], - [227, 178, 251], - [226, 180, 251], - [225, 182, 251], - [224, 184, 252], - [224, 186, 252], - [223, 188, 252], - [223, 190, 252], - [222, 191, 252], - [222, 193, 252], - [222, 195, 252], - [222, 196, 252], - [223, 198, 252], - [223, 199, 252], - [223, 201, 252], - [224, 202, 252], - [224, 204, 253], - 
[225, 205, 253], - [226, 207, 253], - [226, 208, 253], - [227, 209, 253], - [228, 211, 253], - [229, 212, 253], - [229, 213, 253], - [230, 215, 253], - [231, 216, 253], - [231, 218, 253], - [231, 219, 253], - [232, 221, 253], - [232, 222, 253], - [232, 224, 253], - [232, 225, 254], - [233, 227, 254], - [233, 228, 254], - [233, 230, 254], - [233, 231, 254], - [233, 233, 254], - [233, 234, 254], - [234, 236, 254], - [234, 237, 254], - [234, 239, 254], - [235, 240, 254], - [235, 242, 254], - [236, 243, 254], - [237, 245, 254], - [237, 246, 254], - [238, 247, 254], - [239, 249, 254], - [240, 250, 254], - [242, 251, 254], - [243, 252, 254], - [245, 253, 255], - [248, 254, 255], - [251, 255, 255], - [255, 255, 255], -]; - -pub const KINDLMANN: [[u8; 3]; 256] = [ - [0, 0, 0], - [5, 0, 4], - [9, 0, 8], - [13, 1, 13], - [17, 1, 16], - [20, 1, 20], - [22, 1, 23], - [25, 1, 26], - [27, 1, 29], - [29, 2, 32], - [30, 2, 35], - [31, 2, 38], - [32, 2, 42], - [33, 2, 45], - [34, 2, 48], - [35, 2, 51], - [36, 3, 54], - [37, 3, 57], - [37, 3, 60], - [38, 3, 63], - [38, 3, 66], - [38, 3, 68], - [39, 3, 71], - [39, 4, 74], - [39, 4, 77], - [39, 4, 80], - [39, 4, 83], - [39, 4, 86], - [39, 4, 89], - [39, 4, 93], - [39, 5, 96], - [39, 5, 99], - [38, 5, 102], - [38, 5, 106], - [37, 5, 109], - [37, 5, 112], - [36, 6, 116], - [35, 6, 119], - [34, 6, 123], - [33, 6, 126], - [32, 6, 129], - [32, 6, 132], - [31, 6, 136], - [30, 7, 139], - [29, 7, 142], - [28, 7, 145], - [27, 7, 148], - [26, 7, 151], - [25, 7, 154], - [25, 7, 157], - [24, 8, 160], - [24, 8, 163], - [20, 8, 166], - [15, 8, 170], - [8, 9, 174], - [8, 12, 175], - [8, 15, 175], - [8, 19, 175], - [8, 22, 175], - [8, 25, 175], - [8, 28, 174], - [8, 32, 173], - [8, 35, 172], - [8, 38, 170], - [8, 41, 168], - [8, 43, 166], - [8, 46, 165], - [8, 49, 163], - [8, 51, 160], - [8, 53, 158], - [8, 56, 156], - [7, 58, 154], - [7, 60, 152], - [7, 62, 149], - [7, 64, 147], - [7, 66, 145], - [7, 68, 143], - [7, 70, 141], - [7, 71, 139], - [7, 
73, 137], - [7, 75, 135], - [6, 76, 133], - [6, 78, 132], - [6, 80, 130], - [6, 81, 128], - [6, 83, 127], - [6, 84, 125], - [6, 86, 123], - [6, 87, 122], - [6, 88, 120], - [6, 90, 119], - [6, 91, 118], - [6, 93, 116], - [6, 94, 115], - [5, 95, 114], - [6, 96, 113], - [5, 98, 112], - [5, 99, 110], - [5, 100, 109], - [5, 102, 108], - [5, 103, 107], - [5, 104, 106], - [5, 105, 105], - [5, 107, 104], - [5, 108, 103], - [5, 109, 102], - [5, 110, 101], - [5, 112, 100], - [5, 113, 99], - [5, 114, 98], - [6, 115, 96], - [6, 117, 95], - [6, 118, 94], - [6, 119, 92], - [6, 120, 91], - [6, 122, 89], - [6, 123, 88], - [6, 124, 86], - [6, 125, 84], - [6, 127, 83], - [6, 128, 81], - [6, 129, 79], - [6, 130, 77], - [6, 132, 75], - [6, 133, 74], - [7, 134, 72], - [6, 135, 70], - [7, 137, 67], - [7, 138, 65], - [7, 139, 63], - [7, 140, 61], - [7, 142, 59], - [7, 143, 56], - [7, 144, 54], - [7, 145, 52], - [7, 147, 49], - [7, 148, 47], - [7, 149, 44], - [7, 150, 42], - [7, 151, 39], - [7, 153, 37], - [7, 154, 34], - [8, 155, 31], - [8, 156, 29], - [8, 157, 26], - [8, 159, 23], - [8, 160, 20], - [8, 161, 18], - [8, 162, 15], - [8, 163, 12], - [8, 165, 10], - [8, 166, 8], - [12, 167, 8], - [15, 168, 8], - [17, 169, 8], - [18, 170, 8], - [20, 171, 8], - [22, 173, 8], - [24, 174, 8], - [26, 175, 8], - [29, 176, 8], - [32, 177, 9], - [35, 178, 9], - [38, 179, 9], - [41, 180, 9], - [45, 181, 9], - [48, 182, 9], - [52, 183, 9], - [56, 184, 9], - [59, 185, 9], - [63, 186, 9], - [67, 187, 9], - [71, 188, 9], - [75, 189, 9], - [79, 190, 9], - [83, 190, 9], - [87, 191, 9], - [91, 192, 9], - [95, 193, 9], - [99, 193, 9], - [103, 194, 9], - [107, 195, 9], - [111, 196, 9], - [116, 196, 9], - [120, 197, 9], - [124, 198, 10], - [128, 198, 10], - [133, 199, 10], - [137, 199, 10], - [141, 200, 10], - [145, 200, 10], - [150, 201, 10], - [154, 201, 10], - [158, 202, 10], - [163, 202, 10], - [167, 202, 10], - [171, 203, 10], - [175, 203, 10], - [180, 203, 10], - [184, 204, 10], - [188, 204, 10], - [193, 
204, 10], - [197, 205, 10], - [201, 205, 10], - [205, 205, 10], - [209, 205, 10], - [214, 205, 10], - [218, 205, 11], - [223, 205, 11], - [228, 205, 11], - [233, 205, 11], - [237, 205, 11], - [243, 205, 12], - [246, 204, 57], - [247, 205, 86], - [248, 205, 105], - [249, 206, 119], - [249, 207, 131], - [250, 207, 141], - [250, 208, 149], - [250, 209, 157], - [251, 210, 163], - [251, 211, 169], - [251, 212, 174], - [251, 214, 179], - [252, 215, 184], - [252, 216, 188], - [252, 217, 192], - [252, 218, 195], - [252, 220, 199], - [252, 221, 202], - [253, 222, 205], - [253, 224, 208], - [253, 225, 211], - [253, 226, 213], - [253, 227, 216], - [253, 229, 218], - [253, 230, 221], - [253, 232, 223], - [254, 233, 225], - [254, 234, 227], - [254, 236, 229], - [254, 237, 231], - [254, 238, 233], - [254, 240, 235], - [254, 241, 237], - [254, 242, 239], - [254, 244, 241], - [254, 245, 243], - [254, 247, 245], - [255, 248, 246], - [255, 249, 248], - [255, 251, 250], - [255, 252, 252], - [255, 254, 253], - [255, 255, 255], -]; - -pub const SMOOTHCOOLWARM: [[u8; 3]; 256] = [ - [59, 76, 192], - [60, 78, 194], - [61, 80, 195], - [62, 81, 197], - [63, 83, 199], - [64, 85, 200], - [66, 87, 202], - [67, 88, 203], - [68, 90, 204], - [69, 92, 206], - [70, 94, 207], - [72, 95, 209], - [73, 97, 210], - [74, 99, 212], - [75, 100, 213], - [76, 102, 214], - [78, 104, 216], - [79, 106, 217], - [80, 107, 218], - [81, 109, 219], - [83, 111, 221], - [84, 112, 222], - [85, 114, 223], - [86, 116, 224], - [88, 117, 226], - [89, 119, 227], - [90, 120, 228], - [91, 122, 229], - [93, 124, 230], - [94, 125, 231], - [95, 127, 232], - [97, 129, 233], - [98, 130, 234], - [99, 132, 235], - [101, 133, 236], - [102, 135, 237], - [103, 136, 238], - [105, 138, 239], - [106, 140, 240], - [107, 141, 240], - [109, 143, 241], - [110, 144, 242], - [111, 146, 243], - [113, 147, 244], - [114, 149, 244], - [115, 150, 245], - [117, 152, 246], - [118, 153, 246], - [119, 154, 247], - [121, 156, 248], - [122, 157, 248], - 
[123, 159, 249], - [125, 160, 249], - [126, 162, 250], - [128, 163, 250], - [129, 164, 251], - [130, 166, 251], - [132, 167, 252], - [133, 168, 252], - [134, 170, 252], - [136, 171, 253], - [137, 172, 253], - [139, 174, 253], - [140, 175, 254], - [141, 176, 254], - [143, 177, 254], - [144, 178, 254], - [146, 180, 254], - [147, 181, 255], - [148, 182, 255], - [150, 183, 255], - [151, 184, 255], - [153, 186, 255], - [154, 187, 255], - [155, 188, 255], - [157, 189, 255], - [158, 190, 255], - [159, 191, 255], - [161, 192, 255], - [162, 193, 255], - [163, 194, 254], - [165, 195, 254], - [166, 196, 254], - [168, 197, 254], - [169, 198, 254], - [170, 199, 253], - [172, 200, 253], - [173, 201, 253], - [174, 201, 252], - [176, 202, 252], - [177, 203, 252], - [178, 204, 251], - [180, 205, 251], - [181, 206, 250], - [182, 206, 250], - [183, 207, 249], - [185, 208, 249], - [186, 209, 248], - [187, 209, 248], - [189, 210, 247], - [190, 211, 246], - [191, 211, 246], - [192, 212, 245], - [193, 212, 244], - [195, 213, 244], - [196, 214, 243], - [197, 214, 242], - [198, 215, 241], - [200, 215, 241], - [201, 216, 240], - [202, 216, 239], - [203, 216, 238], - [204, 217, 237], - [205, 217, 236], - [206, 218, 235], - [208, 218, 234], - [209, 218, 233], - [210, 219, 232], - [211, 219, 231], - [212, 219, 230], - [213, 219, 229], - [214, 220, 228], - [215, 220, 227], - [216, 220, 226], - [217, 220, 225], - [218, 220, 224], - [219, 221, 222], - [220, 221, 221], - [221, 220, 220], - [222, 220, 219], - [223, 220, 217], - [225, 219, 216], - [226, 218, 214], - [227, 218, 213], - [228, 217, 211], - [229, 217, 210], - [229, 216, 209], - [230, 216, 207], - [231, 215, 206], - [232, 214, 204], - [233, 214, 203], - [234, 213, 201], - [235, 212, 200], - [235, 211, 198], - [236, 211, 197], - [237, 210, 195], - [238, 209, 194], - [238, 208, 192], - [239, 207, 191], - [239, 206, 189], - [240, 206, 187], - [241, 205, 186], - [241, 204, 184], - [242, 203, 183], - [242, 202, 181], - [243, 201, 180], - 
[243, 200, 178], - [244, 199, 177], - [244, 198, 175], - [244, 197, 173], - [245, 196, 172], - [245, 195, 170], - [245, 193, 169], - [246, 192, 167], - [246, 191, 166], - [246, 190, 164], - [246, 189, 162], - [247, 188, 161], - [247, 186, 159], - [247, 185, 158], - [247, 184, 156], - [247, 183, 155], - [247, 181, 153], - [247, 180, 151], - [247, 179, 150], - [247, 177, 148], - [247, 176, 147], - [247, 175, 145], - [247, 173, 144], - [247, 172, 142], - [247, 171, 140], - [247, 169, 139], - [247, 168, 137], - [247, 166, 136], - [246, 165, 134], - [246, 163, 133], - [246, 162, 131], - [246, 160, 129], - [245, 159, 128], - [245, 157, 126], - [245, 156, 125], - [244, 154, 123], - [244, 153, 122], - [244, 151, 120], - [243, 149, 119], - [243, 148, 117], - [242, 146, 116], - [242, 144, 114], - [241, 143, 113], - [241, 141, 111], - [240, 139, 110], - [240, 138, 108], - [239, 136, 107], - [238, 134, 105], - [238, 133, 104], - [237, 131, 102], - [237, 129, 101], - [236, 127, 99], - [235, 125, 98], - [234, 124, 96], - [234, 122, 95], - [233, 120, 93], - [232, 118, 92], - [231, 116, 91], - [230, 114, 89], - [229, 112, 88], - [229, 111, 86], - [228, 109, 85], - [227, 107, 84], - [226, 105, 82], - [225, 103, 81], - [224, 101, 79], - [223, 99, 78], - [222, 97, 77], - [221, 95, 75], - [220, 93, 74], - [219, 91, 73], - [218, 89, 71], - [216, 86, 70], - [215, 84, 69], - [214, 82, 68], - [213, 80, 66], - [212, 78, 65], - [211, 76, 64], - [209, 73, 62], - [208, 71, 61], - [207, 69, 60], - [206, 67, 59], - [204, 64, 57], - [203, 62, 56], - [202, 59, 55], - [200, 57, 54], - [199, 54, 53], - [198, 52, 52], - [196, 49, 50], - [195, 46, 49], - [193, 43, 48], - [192, 40, 47], - [191, 37, 46], - [189, 34, 45], - [188, 30, 44], - [186, 26, 43], - [185, 22, 41], - [183, 17, 40], - [182, 11, 39], - [180, 4, 38], -]; diff --git a/src/utils/names.rs b/src/utils/names.rs deleted file mode 100644 index ea6b648..0000000 --- a/src/utils/names.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! 
Some constants releated with COCO dataset: [`COCO_SKELETONS_16`], [`COCO_KEYPOINTS_17`], [`COCO_CLASS_NAMES_80`] - -pub const COCO_SKELETONS_16: [(usize, usize); 16] = [ - (0, 1), - (0, 2), - (1, 3), - (2, 4), - (5, 6), - (5, 11), - (6, 12), - (11, 12), - (5, 7), - (6, 8), - (7, 9), - (8, 10), - (11, 13), - (12, 14), - (13, 15), - (14, 16), -]; - -pub const COCO_KEYPOINTS_17: [&str; 17] = [ - "nose", - "left_eye", - "right_eye", - "left_ear", - "right_ear", - "left_shoulder", - "right_shoulder", - "left_elbow", - "right_elbow", - "left_wrist", - "right_wrist", - "left_hip", - "right_hip", - "left_knee", - "right_knee", - "left_ankle", - "right_ankle", -]; - -pub const COCO_CLASS_NAMES_80: [&str; 80] = [ - "person", - "bicycle", - "car", - "motorcycle", - "airplane", - "bus", - "train", - "truck", - "boat", - "traffic light", - "fire hydrant", - "stop sign", - "parking meter", - "bench", - "bird", - "cat", - "dog", - "horse", - "sheep", - "cow", - "elephant", - "bear", - "zebra", - "giraffe", - "backpack", - "umbrella", - "handbag", - "tie", - "suitcase", - "frisbee", - "skis", - "snowboard", - "sports ball", - "kite", - "baseball bat", - "baseball glove", - "skateboard", - "surfboard", - "tennis racket", - "bottle", - "wine glass", - "cup", - "fork", - "knife", - "spoon", - "bowl", - "banana", - "apple", - "sandwich", - "orange", - "broccoli", - "carrot", - "hot dog", - "pizza", - "donut", - "cake", - "chair", - "couch", - "potted plant", - "bed", - "dining table", - "toilet", - "tv", - "laptop", - "mouse", - "remote", - "keyboard", - "cell phone", - "microwave", - "oven", - "toaster", - "sink", - "refrigerator", - "book", - "clock", - "vase", - "scissors", - "teddy bear", - "hair drier", - "toothbrush", -]; - -pub const BODY_PARTS_28: [&str; 28] = [ - "Background", - "Apparel", - "Face Neck", - "Hair", - "Left Foot", - "Left Hand", - "Left Lower Arm", - "Left Lower Leg", - "Left Shoe", - "Left Sock", - "Left Upper Arm", - "Left Upper Leg", - "Lower Clothing", - 
"Right Foot", - "Right Hand", - "Right Lower Arm", - "Right Lower Leg", - "Right Shoe", - "Right Sock", - "Right Upper Arm", - "Right Upper Leg", - "Torso", - "Upper Clothing", - "Lower Lip", - "Upper Lip", - "Lower Teeth", - "Upper Teeth", - "Tongue", -]; diff --git a/src/ys/bbox.rs b/src/xy/bbox.rs similarity index 83% rename from src/ys/bbox.rs rename to src/xy/bbox.rs index 5d068b8..66c30f9 100644 --- a/src/ys/bbox.rs +++ b/src/xy/bbox.rs @@ -1,19 +1,21 @@ +use aksr::Builder; + use crate::Nms; /// Bounding Box 2D. /// /// This struct represents a 2D bounding box with properties such as position, size, /// class ID, confidence score, optional name, and an ID representing the born state. -#[derive(Clone, PartialEq, PartialOrd)] +#[derive(Builder, Clone, PartialEq, PartialOrd)] pub struct Bbox { x: f32, y: f32, w: f32, h: f32, id: isize, + id_born: isize, confidence: f32, name: Option, - id_born: isize, } impl Nms for Bbox { @@ -36,9 +38,9 @@ impl Default for Bbox { w: 0., h: 0., id: -1, + id_born: -1, confidence: 0., name: None, - id_born: -1, } } } @@ -160,62 +162,6 @@ impl Bbox { self } - /// Sets the class ID of the bounding box. - /// - /// # Arguments - /// - /// * `x` - The class ID to be set. - /// - /// # Returns - /// - /// A `Bbox` instance with updated class ID. - pub fn with_id(mut self, x: isize) -> Self { - self.id = x; - self - } - - /// Sets the ID representing the born state of the bounding box. - /// - /// # Arguments - /// - /// * `x` - The ID to be set. - /// - /// # Returns - /// - /// A `Bbox` instance with updated born state ID. - pub fn with_id_born(mut self, x: isize) -> Self { - self.id_born = x; - self - } - - /// Sets the confidence score of the bounding box. - /// - /// # Arguments - /// - /// * `x` - The confidence score to be set. - /// - /// # Returns - /// - /// A `Bbox` instance with updated confidence score. 
- pub fn with_confidence(mut self, x: f32) -> Self { - self.confidence = x; - self - } - - /// Sets the optional name of the bounding box. - /// - /// # Arguments - /// - /// * `x` - The name to be set. - /// - /// # Returns - /// - /// A `Bbox` instance with updated name. - pub fn with_name(mut self, x: &str) -> Self { - self.name = Some(x.to_string()); - self - } - /// Returns the width of the bounding box. pub fn width(&self) -> f32 { self.w @@ -271,26 +217,6 @@ impl Bbox { (self.cx(), self.cy(), self.w, self.h) } - /// Returns the class ID of the bounding box. - pub fn id(&self) -> isize { - self.id - } - - /// Returns the born state ID of the bounding box. - pub fn id_born(&self) -> isize { - self.id_born - } - - /// Returns the optional name associated with the bounding box, if any. - pub fn name(&self) -> Option<&String> { - self.name.as_ref() - } - - // /// Returns the confidence score of the bounding box. - // pub fn confidence(&self) -> f32 { - // self.confidence - // } - /// A label string representing the bounding box, optionally including name and confidence score. pub fn label(&self, with_name: bool, with_conf: bool, decimal_places: usize) -> String { let mut label = String::new(); diff --git a/src/ys/keypoint.rs b/src/xy/keypoint.rs similarity index 92% rename from src/ys/keypoint.rs rename to src/xy/keypoint.rs index 54a43c9..e75d00a 100644 --- a/src/ys/keypoint.rs +++ b/src/xy/keypoint.rs @@ -1,7 +1,8 @@ +use aksr::Builder; use std::ops::{Add, Div, Mul, Sub}; /// Keypoint 2D. 
-#[derive(PartialEq, Clone)] +#[derive(Builder, PartialEq, Clone)] pub struct Keypoint { x: f32, y: f32, @@ -149,6 +150,18 @@ impl From<[f32; 2]> for Keypoint { } } +impl From<(f32, f32, isize)> for Keypoint { + fn from((x, y, id): (f32, f32, isize)) -> Self { + Self { + x, + y, + id, + confidence: 1., + ..Default::default() + } + } +} + impl From<(f32, f32, isize, f32)> for Keypoint { fn from((x, y, id, confidence): (f32, f32, isize, f32)) -> Self { Self { @@ -180,41 +193,6 @@ impl Keypoint { self } - pub fn with_confidence(mut self, x: f32) -> Self { - self.confidence = x; - self - } - - pub fn with_id(mut self, x: isize) -> Self { - self.id = x; - self - } - - pub fn with_name(mut self, x: &str) -> Self { - self.name = Some(x.to_string()); - self - } - - pub fn x(&self) -> f32 { - self.x - } - - pub fn y(&self) -> f32 { - self.y - } - - pub fn confidence(&self) -> f32 { - self.confidence - } - - pub fn id(&self) -> isize { - self.id - } - - pub fn name(&self) -> Option<&String> { - self.name.as_ref() - } - pub fn label(&self, with_name: bool, with_conf: bool, decimal_places: usize) -> String { let mut label = String::new(); if with_name { diff --git a/src/ys/mask.rs b/src/xy/mask.rs similarity index 61% rename from src/ys/mask.rs rename to src/xy/mask.rs index f3e91ea..1913e3d 100644 --- a/src/ys/mask.rs +++ b/src/xy/mask.rs @@ -1,7 +1,8 @@ +use aksr::Builder; use image::GrayImage; /// Mask: Gray Image. 
-#[derive(Clone, PartialEq)] +#[derive(Builder, Clone, PartialEq)] pub struct Mask { mask: GrayImage, id: isize, @@ -31,41 +32,10 @@ impl std::fmt::Debug for Mask { } impl Mask { - pub fn with_mask(mut self, x: GrayImage) -> Self { - self.mask = x; - self - } - - pub fn with_id(mut self, x: isize) -> Self { - self.id = x; - self - } - - pub fn with_name(mut self, x: &str) -> Self { - self.name = Some(x.to_string()); - self - } - - pub fn mask(&self) -> &GrayImage { - &self.mask - } - pub fn to_vec(&self) -> Vec { self.mask.to_vec() } - pub fn id(&self) -> isize { - self.id - } - - pub fn name(&self) -> Option<&String> { - self.name.as_ref() - } - - pub fn confidence(&self) -> f32 { - self.confidence - } - pub fn height(&self) -> u32 { self.mask.height() } diff --git a/src/ys/mbr.rs b/src/xy/mbr.rs similarity index 94% rename from src/ys/mbr.rs rename to src/xy/mbr.rs index 7325508..c76c9f8 100644 --- a/src/ys/mbr.rs +++ b/src/xy/mbr.rs @@ -1,9 +1,10 @@ +use aksr::Builder; use geo::{coord, line_string, Area, BooleanOps, Coord, EuclideanDistance, LineString, Polygon}; use crate::Nms; /// Minimum Bounding Rectangle. 
-#[derive(Clone, PartialEq)] +#[derive(Builder, Clone, PartialEq)] pub struct Mbr { ls: LineString, id: isize, @@ -91,29 +92,6 @@ impl Mbr { } } - pub fn with_id(mut self, id: isize) -> Self { - self.id = id; - self - } - - pub fn with_confidence(mut self, x: f32) -> Self { - self.confidence = x; - self - } - - pub fn with_name(mut self, x: &str) -> Self { - self.name = Some(x.to_string()); - self - } - - pub fn id(&self) -> isize { - self.id - } - - pub fn name(&self) -> Option<&String> { - self.name.as_ref() - } - pub fn label(&self, with_name: bool, with_conf: bool, decimal_places: usize) -> String { let mut label = String::new(); if with_name { diff --git a/src/ys/mod.rs b/src/xy/mod.rs similarity index 72% rename from src/ys/mod.rs rename to src/xy/mod.rs index f07d38a..626e151 100644 --- a/src/ys/mod.rs +++ b/src/xy/mod.rs @@ -1,20 +1,27 @@ mod bbox; -mod embedding; +// mod embedding; mod keypoint; mod mask; mod mbr; mod polygon; mod prob; +mod text; +mod x; +mod xs; mod y; +mod ys; pub use bbox::Bbox; -pub use embedding::Embedding; pub use keypoint::Keypoint; pub use mask::Mask; pub use mbr::Mbr; pub use polygon::Polygon; pub use prob::Prob; +pub use text::Text; +pub use x::X; +pub use xs::Xs; pub use y::Y; +pub use ys::Ys; pub trait Nms { fn iou(&self, other: &Self) -> f32; diff --git a/src/ys/polygon.rs b/src/xy/polygon.rs similarity index 90% rename from src/ys/polygon.rs rename to src/xy/polygon.rs index be20b32..0d93b93 100644 --- a/src/ys/polygon.rs +++ b/src/xy/polygon.rs @@ -1,3 +1,4 @@ +use aksr::Builder; use geo::{ coord, point, polygon, Area, BoundingRect, Centroid, ConvexHull, EuclideanLength, LineString, MinimumRotatedRect, Point, Simplify, @@ -6,7 +7,7 @@ use geo::{ use crate::{Bbox, Mbr}; /// Polygon. 
-#[derive(Clone, PartialEq)] +#[derive(Builder, Clone, PartialEq)] pub struct Polygon { polygon: geo::Polygon, id: isize, @@ -59,38 +60,6 @@ impl Polygon { self } - pub fn with_polygon(mut self, x: geo::Polygon) -> Self { - self.polygon = x; - self - } - - pub fn with_id(mut self, x: isize) -> Self { - self.id = x; - self - } - - pub fn with_name(mut self, x: &str) -> Self { - self.name = Some(x.to_string()); - self - } - - pub fn with_confidence(mut self, x: f32) -> Self { - self.confidence = x; - self - } - - pub fn id(&self) -> isize { - self.id - } - - pub fn name(&self) -> Option<&String> { - self.name.as_ref() - } - - pub fn confidence(&self) -> f32 { - self.confidence - } - pub fn label(&self, with_name: bool, with_conf: bool, decimal_places: usize) -> String { let mut label = String::new(); if with_name { @@ -112,10 +81,6 @@ impl Polygon { label } - pub fn polygon(&self) -> &geo::Polygon { - &self.polygon - } - pub fn is_closed(&self) -> bool { self.polygon.exterior().is_closed() } diff --git a/src/ys/prob.rs b/src/xy/prob.rs similarity index 66% rename from src/ys/prob.rs rename to src/xy/prob.rs index f2aba92..ee974ba 100644 --- a/src/ys/prob.rs +++ b/src/xy/prob.rs @@ -1,5 +1,7 @@ +use aksr::Builder; + /// Probabilities for classification. 
-#[derive(Clone, PartialEq, Default)] +#[derive(Builder, Clone, PartialEq, Default)] pub struct Prob { probs: Vec, names: Option>, @@ -12,25 +14,6 @@ impl std::fmt::Debug for Prob { } impl Prob { - pub fn with_names(mut self, names: &[&str]) -> Self { - let names = names.iter().map(|x| x.to_string()).collect::>(); - self.names = Some(names); - self - } - - pub fn with_probs(mut self, x: &[f32]) -> Self { - self.probs = x.to_vec(); - self - } - - pub fn probs(&self) -> &Vec { - &self.probs - } - - pub fn names(&self) -> Option<&Vec> { - self.names.as_ref() - } - pub fn topk(&self, k: usize) -> Vec<(usize, f32, Option)> { let mut probs = self .probs diff --git a/src/xy/text.rs b/src/xy/text.rs new file mode 100644 index 0000000..0c67b5f --- /dev/null +++ b/src/xy/text.rs @@ -0,0 +1,17 @@ +/// Wrapper over [`String`] +#[derive(aksr::Builder, Debug, Clone, Default, PartialEq)] +pub struct Text(String); + +impl std::ops::Deref for Text { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl> From for Text { + fn from(x: T) -> Self { + Self(x.as_ref().to_string()) + } +} diff --git a/src/core/x.rs b/src/xy/x.rs similarity index 54% rename from src/core/x.rs rename to src/xy/x.rs index 3a245e2..6b4b482 100644 --- a/src/core/x.rs +++ b/src/xy/x.rs @@ -2,10 +2,10 @@ use anyhow::Result; use image::DynamicImage; use ndarray::{Array, Dim, IntoDimension, IxDyn, IxDynImpl}; -use crate::Ops; +use crate::{Ops, ResizeMode}; -/// Model input, wrapper over [`Array`] -#[derive(Debug, Clone, Default)] +/// Wrapper over [`Array`] +#[derive(Debug, Clone, Default, PartialEq)] pub struct X(pub Array); impl From> for X { @@ -20,6 +20,42 @@ impl From> for X { } } +impl TryFrom> for X { + type Error = anyhow::Error; + + fn try_from(values: Vec<(u32, u32)>) -> Result { + if values.is_empty() { + Ok(Self::default()) + } else { + let mut flattened: Vec = Vec::new(); + for &(a, b) in values.iter() { + flattened.push(a); + flattened.push(b); + } + let shape = 
(values.len(), 2); + let x = Array::from_shape_vec(shape, flattened)? + .map(|x| *x as f32) + .into_dyn(); + Ok(Self(x)) + } + } +} + +impl TryFrom>> for X { + type Error = anyhow::Error; + + fn try_from(xs: Vec>) -> Result { + if xs.is_empty() { + Ok(Self::default()) + } else { + let shape = (xs.len(), xs[0].len()); + let flattened: Vec = xs.iter().flatten().cloned().collect(); + let x = Array::from_shape_vec(shape, flattened)?.into_dyn(); + Ok(Self(x)) + } + } +} + impl std::ops::Deref for X { type Target = Array; @@ -29,6 +65,8 @@ impl std::ops::Deref for X { } impl X { + // TODO: Add some slice and index method + pub fn zeros(shape: &[usize]) -> Self { Self::from(Array::zeros(Dim(IxDynImpl::from(shape.to_vec())))) } @@ -37,11 +75,27 @@ impl X { Self::from(Array::ones(Dim(IxDynImpl::from(shape.to_vec())))) } + pub fn zeros_like(x: Self) -> Self { + Self::from(Array::zeros(x.raw_dim())) + } + + pub fn ones_like(x: Self) -> Self { + Self::from(Array::ones(x.raw_dim())) + } + + pub fn full(shape: &[usize], x: f32) -> Self { + Self::from(Array::from_elem(shape, x)) + } + + pub fn from_shape_vec(shape: &[usize], xs: Vec) -> Result { + Ok(Self::from(Array::from_shape_vec(shape, xs)?)) + } + pub fn apply(ops: &[Ops]) -> Result { let mut y = Self::default(); for op in ops { y = match op { - Ops::Resize(xs, h, w, filter) => Self::resize(xs, *h, *w, filter)?, + Ops::FitExact(xs, h, w, filter) => Self::resize(xs, *h, *w, filter)?, Ops::Letterbox(xs, h, w, filter, bg, resize_by, center) => { Self::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)? 
} @@ -103,6 +157,12 @@ impl X { Ok(self) } + pub fn concat(xs: &[Self], d: usize) -> Result { + let xs = xs.iter().cloned().map(|x| x.0).collect::>(); + let x = Ops::concat(&xs, d)?; + Ok(x.into()) + } + pub fn dims(&self) -> &[usize] { self.0.shape() } @@ -126,6 +186,11 @@ impl X { Ok(self) } + pub fn unsigned(mut self) -> Self { + self.0 = self.0.mapv(|x| if x < 0.0 { 0.0 } else { x }); + self + } + pub fn resize(xs: &[DynamicImage], height: u32, width: u32, filter: &str) -> Result { Ok(Self::from(Ops::resize(xs, height, width, filter)?)) } @@ -143,4 +208,47 @@ impl X { xs, height, width, filter, bg, resize_by, center, )?)) } + + #[allow(clippy::too_many_arguments)] + pub fn preprocess( + xs: &[image::DynamicImage], + image_width: u32, + image_height: u32, + resize_mode: &ResizeMode, + resizer_filter: &str, + padding_value: u8, + letterbox_center: bool, + normalize: bool, + image_std: &[f32], + image_mean: &[f32], + nchw: bool, + ) -> Result { + let mut x = match resize_mode { + ResizeMode::FitExact => X::resize(xs, image_height, image_width, resizer_filter)?, + ResizeMode::Letterbox => X::letterbox( + xs, + image_height, + image_width, + resizer_filter, + padding_value, + "auto", + letterbox_center, + )?, + _ => unimplemented!(), + }; + + if normalize { + x = x.normalize(0., 255.)?; + } + + if !image_std.is_empty() && !image_mean.is_empty() { + x = x.standardize(image_mean, image_std, 3)?; + } + + if nchw { + x = x.nhwc2nchw()?; + } + + Ok(x) + } } diff --git a/src/core/xs.rs b/src/xy/xs.rs similarity index 86% rename from src/core/xs.rs rename to src/xy/xs.rs index 8b6a11c..d54c347 100644 --- a/src/core/xs.rs +++ b/src/xy/xs.rs @@ -1,13 +1,19 @@ +use aksr::Builder; use anyhow::Result; +use image::DynamicImage; use std::collections::HashMap; use std::ops::{Deref, Index}; use crate::{string_random, X}; -#[derive(Debug, Default, Clone)] +#[derive(Builder, Debug, Default, Clone)] pub struct Xs { map: HashMap, names: Vec, + + // TODO: move to Processor + pub images: 
Vec>, + pub texts: Vec>, } impl From for Xs { @@ -35,6 +41,14 @@ impl Xs { } } + pub fn derive(&self) -> Self { + Self { + map: Default::default(), + names: Default::default(), + ..self.clone() + } + } + pub fn push(&mut self, value: X) { loop { let key = string_random(5); @@ -55,10 +69,6 @@ impl Xs { anyhow::bail!("Xs already contains key: {:?}", key) } } - - pub fn names(&self) -> &Vec { - &self.names - } } impl Deref for Xs { diff --git a/src/xy/y.rs b/src/xy/y.rs new file mode 100644 index 0000000..cd263d1 --- /dev/null +++ b/src/xy/y.rs @@ -0,0 +1,124 @@ +use aksr::Builder; + +use crate::{Bbox, Keypoint, Mask, Mbr, Nms, Polygon, Prob, Text, X}; + +/// Container for inference results for each image. +/// +/// This struct holds various possible outputs from an image inference process, +/// including probabilities, bounding boxes, keypoints, minimum bounding rectangles, +/// polygons, masks, text annotations, and embeddings. +/// +/// # Fields +/// +/// * `texts` - Optionally contains a vector of texts. +/// * `embedding` - Optionally contains the embedding representation. +/// * `probs` - Optionally contains the probability scores for the detected objects. +/// * `bboxes` - Optionally contains a vector of bounding boxes. +/// * `keypoints` - Optionally contains a nested vector of keypoints. +/// * `mbrs` - Optionally contains a vector of minimum bounding rectangles. +/// * `polygons` - Optionally contains a vector of polygons. +/// * `masks` - Optionally contains a vector of masks. 
+#[derive(Builder, Clone, PartialEq, Default)] +pub struct Y { + texts: Option>, + embedding: Option, + probs: Option, + bboxes: Option>, + keypoints: Option>>, + mbrs: Option>, + polygons: Option>, + masks: Option>, +} + +impl std::fmt::Debug for Y { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut f = f.debug_struct("Y"); + if let Some(xs) = &self.texts { + if !xs.is_empty() { + f.field("Texts", &xs); + } + } + if let Some(xs) = &self.probs { + f.field("Probs", &xs); + } + if let Some(xs) = &self.bboxes { + if !xs.is_empty() { + f.field("BBoxes", &xs); + } + } + if let Some(xs) = &self.mbrs { + if !xs.is_empty() { + f.field("OBBs", &xs); + } + } + if let Some(xs) = &self.keypoints { + if !xs.is_empty() { + f.field("Kpts", &xs); + } + } + if let Some(xs) = &self.polygons { + if !xs.is_empty() { + f.field("Polys", &xs); + } + } + if let Some(xs) = &self.masks { + if !xs.is_empty() { + f.field("Masks", &xs); + } + } + if let Some(x) = &self.embedding { + f.field("Embedding", &x); + } + f.finish() + } +} + +impl Y { + pub fn hbbs(&self) -> Option<&[Bbox]> { + self.bboxes.as_deref() + } + + pub fn obbs(&self) -> Option<&[Mbr]> { + self.mbrs.as_deref() + } + + pub fn apply_nms(mut self, iou_threshold: f32) -> Self { + match &mut self.bboxes { + None => match &mut self.mbrs { + None => self, + Some(ref mut mbrs) => { + Self::nms(mbrs, iou_threshold); + self + } + }, + Some(ref mut bboxes) => { + Self::nms(bboxes, iou_threshold); + self + } + } + } + + pub fn nms(xxx: &mut Vec, iou_threshold: f32) { + xxx.sort_by(|b1, b2| { + b2.confidence() + .partial_cmp(&b1.confidence()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + let mut current_index = 0; + for index in 0..xxx.len() { + let mut drop = false; + for prev_index in 0..current_index { + let iou = xxx[prev_index].iou(&xxx[index]); + if iou > iou_threshold { + drop = true; + break; + } + } + if !drop { + xxx.swap(current_index, index); + current_index += 1; + } + } + 
xxx.truncate(current_index); + } +} diff --git a/src/xy/ys.rs b/src/xy/ys.rs new file mode 100644 index 0000000..b4303b1 --- /dev/null +++ b/src/xy/ys.rs @@ -0,0 +1,19 @@ +use crate::Y; + +/// Wrapper over `Vec` +#[derive(aksr::Builder, Default, Debug)] +pub struct Ys(pub Vec); + +impl From> for Ys { + fn from(xs: Vec) -> Self { + Self(xs) + } +} + +impl std::ops::Deref for Ys { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/ys/embedding.rs b/src/ys/embedding.rs deleted file mode 100644 index daaf63d..0000000 --- a/src/ys/embedding.rs +++ /dev/null @@ -1,53 +0,0 @@ -use anyhow::Result; -use ndarray::{Array, Axis, Ix2, IxDyn}; - -use crate::X; - -/// Embedding for image or text. -#[derive(Clone, PartialEq, Default)] -pub struct Embedding(Array); - -impl std::fmt::Debug for Embedding { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("").field("Shape", &self.0.shape()).finish() - } -} - -impl From for Embedding { - fn from(x: X) -> Self { - Self(x.0) - } -} - -impl Embedding { - pub fn new(x: Array) -> Self { - Self(x) - } - - pub fn with_embedding(mut self, x: Array) -> Self { - self.0 = x; - self - } - - pub fn data(&self) -> &Array { - &self.0 - } - - pub fn norm(mut self) -> Self { - let std_ = self.0.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt); - self.0 = self.0 / std_; - self - } - - pub fn dot2(&self, other: &Embedding) -> Result>> { - // (m, ndim) * (n, ndim).t => (m, n) - let query = self.0.to_owned().into_dimensionality::()?; - let gallery = other.0.to_owned().into_dimensionality::()?; - let matrix = query.dot(&gallery.t()); - let exps = matrix.mapv(|x| x.exp()); - let stds = exps.sum_axis(Axis(1)); - let matrix = exps / stds.insert_axis(Axis(1)); - let matrix: Vec> = matrix.axis_iter(Axis(0)).map(|row| row.to_vec()).collect(); - Ok(matrix) - } -} diff --git a/src/ys/y.rs b/src/ys/y.rs deleted file mode 100644 index d7b8827..0000000 --- a/src/ys/y.rs +++ /dev/null @@ 
-1,298 +0,0 @@ -use crate::{Bbox, Embedding, Keypoint, Mask, Mbr, Nms, Polygon, Prob}; - -/// Container for inference results for each image. -/// -/// This struct holds various possible outputs from an image inference process, -/// including probabilities, bounding boxes, keypoints, minimum bounding rectangles, -/// polygons, masks, text annotations, and embeddings. -/// -/// # Fields -/// -/// * `probs` - Optionally contains the probability scores for the detected objects. -/// * `bboxes` - Optionally contains a vector of bounding boxes. -/// * `keypoints` - Optionally contains a nested vector of keypoints. -/// * `mbrs` - Optionally contains a vector of minimum bounding rectangles. -/// * `polygons` - Optionally contains a vector of polygons. -/// * `texts` - Optionally contains a vector of text annotations. -/// * `masks` - Optionally contains a vector of masks. -/// * `embedding` - Optionally contains the embedding representation. -#[derive(Clone, PartialEq, Default)] -pub struct Y { - probs: Option, - bboxes: Option>, - keypoints: Option>>, - mbrs: Option>, - polygons: Option>, - texts: Option>, - masks: Option>, - embedding: Option, -} - -impl std::fmt::Debug for Y { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut f = f.debug_struct("Y"); - if let Some(x) = &self.texts { - if !x.is_empty() { - f.field("Texts", &x); - } - } - if let Some(x) = &self.probs { - f.field("Probabilities", &x); - } - if let Some(x) = &self.bboxes { - if !x.is_empty() { - f.field("BoundingBoxes", &x); - } - } - if let Some(x) = &self.mbrs { - if !x.is_empty() { - f.field("MinimumBoundingRectangles", &x); - } - } - if let Some(x) = &self.keypoints { - if !x.is_empty() { - f.field("Keypoints", &x); - } - } - if let Some(x) = &self.polygons { - if !x.is_empty() { - f.field("Polygons", &x); - } - } - if let Some(x) = &self.masks { - if !x.is_empty() { - f.field("Masks", &x); - } - } - if let Some(x) = &self.embedding { - f.field("Embedding", &x); - } - 
f.finish() - } -} - -impl Y { - /// Sets the `masks` field with the provided vector of masks. - /// - /// # Arguments - /// - /// * `masks` - A slice of `Mask` to be set. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new masks set. - pub fn with_masks(mut self, masks: &[Mask]) -> Self { - self.masks = Some(masks.to_vec()); - self - } - - /// Sets the `probs` field with the provided probability scores. - /// - /// # Arguments - /// - /// * `probs` - A reference to a `Prob` instance to be cloned and set in the struct. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new probabilities set. - pub fn with_probs(mut self, probs: &Prob) -> Self { - self.probs = Some(probs.clone()); - self - } - - /// Sets the `texts` field with the provided vector of text annotations. - /// - /// # Arguments - /// - /// * `texts` - A slice of `String` to be set. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new texts set. - pub fn with_texts(mut self, texts: &[String]) -> Self { - self.texts = Some(texts.to_vec()); - self - } - - /// Sets the `mbrs` field with the provided vector of minimum bounding rectangles. - /// - /// # Arguments - /// - /// * `mbrs` - A slice of `Mbr` to be set. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new minimum bounding rectangles set. - pub fn with_mbrs(mut self, mbrs: &[Mbr]) -> Self { - self.mbrs = Some(mbrs.to_vec()); - self - } - - /// Sets the `bboxes` field with the provided vector of bounding boxes. - /// - /// # Arguments - /// - /// * `bboxes` - A slice of `Bbox` to be set. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new bounding boxes set. - pub fn with_bboxes(mut self, bboxes: &[Bbox]) -> Self { - self.bboxes = Some(bboxes.to_vec()); - self - } - - /// Sets the `embedding` field with the provided embedding. 
- /// - /// # Arguments - /// - /// * `embedding` - A reference to an `Embedding` instance to be cloned and set in the struct. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new embedding set. - pub fn with_embedding(mut self, embedding: &Embedding) -> Self { - self.embedding = Some(embedding.clone()); - self - } - - /// Sets the `keypoints` field with the provided nested vector of keypoints. - /// - /// # Arguments - /// - /// * `keypoints` - A slice of vectors of `Keypoint` to be set. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new keypoints set. - pub fn with_keypoints(mut self, keypoints: &[Vec]) -> Self { - self.keypoints = Some(keypoints.to_vec()); - self - } - - /// Sets the `polygons` field with the provided vector of polygons. - /// - /// # Arguments - /// - /// * `polygons` - A slice of `Polygon` to be set. - /// - /// # Returns - /// - /// * `Self` - The updated struct instance with the new polygons set. - pub fn with_polygons(mut self, polygons: &[Polygon]) -> Self { - self.polygons = Some(polygons.to_vec()); - self - } - - /// Returns a reference to the `masks` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Vec>` - A reference to the vector of masks, or `None` if it is not set. - pub fn masks(&self) -> Option<&Vec> { - self.masks.as_ref() - } - - /// Returns a reference to the `probs` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Prob>` - A reference to the probabilities, or `None` if it is not set. - pub fn probs(&self) -> Option<&Prob> { - self.probs.as_ref() - } - - /// Returns a reference to the `keypoints` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Vec>>` - A reference to the nested vector of keypoints, or `None` if it is not set. - pub fn keypoints(&self) -> Option<&Vec>> { - self.keypoints.as_ref() - } - - /// Returns a reference to the `polygons` field, if it exists. 
- /// - /// # Returns - /// - /// * `Option<&Vec>` - A reference to the vector of polygons, or `None` if it is not set. - pub fn polygons(&self) -> Option<&Vec> { - self.polygons.as_ref() - } - - /// Returns a reference to the `bboxes` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Vec>` - A reference to the vector of bounding boxes, or `None` if it is not set. - pub fn bboxes(&self) -> Option<&Vec> { - self.bboxes.as_ref() - } - - /// Returns a reference to the `mbrs` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Vec>` - A reference to the vector of minimum bounding rectangles, or `None` if it is not set. - pub fn mbrs(&self) -> Option<&Vec> { - self.mbrs.as_ref() - } - - /// Returns a reference to the `texts` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Vec>` - A reference to the vector of texts, or `None` if it is not set. - pub fn texts(&self) -> Option<&Vec> { - self.texts.as_ref() - } - - /// Returns a reference to the `embedding` field, if it exists. - /// - /// # Returns - /// - /// * `Option<&Embedding>` - A reference to the embedding, or `None` if it is not set. 
- pub fn embedding(&self) -> Option<&Embedding> { - self.embedding.as_ref() - } - - pub fn apply_nms(mut self, iou_threshold: f32) -> Self { - match &mut self.bboxes { - None => match &mut self.mbrs { - None => self, - Some(ref mut mbrs) => { - Self::nms(mbrs, iou_threshold); - self - } - }, - Some(ref mut bboxes) => { - Self::nms(bboxes, iou_threshold); - self - } - } - } - - pub fn nms(xxx: &mut Vec, iou_threshold: f32) { - xxx.sort_by(|b1, b2| { - b2.confidence() - .partial_cmp(&b1.confidence()) - .unwrap_or(std::cmp::Ordering::Equal) - }); - let mut current_index = 0; - for index in 0..xxx.len() { - let mut drop = false; - for prev_index in 0..current_index { - let iou = xxx[prev_index].iou(&xxx[index]); - if iou > iou_threshold { - drop = true; - break; - } - } - if !drop { - xxx.swap(current_index, index); - current_index += 1; - } - } - xxx.truncate(current_index); - } -}