diff --git a/Cargo.toml b/Cargo.toml
index 75b9e0e..b1b14ce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fast-stats"
-version = "0.1.0"
+version = "0.0.4"
 edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -8,6 +8,9 @@ edition = "2021"
 name = "fast_stats"
 crate-type = ["cdylib"]
 
+[package.metadata.maturin]
+name = "fast_stats._fast_stats_ext"
+
 [dependencies]
 pyo3 = { version = "0.16.3", features = ["extension-module"] }
 numpy = "0.16.2"
diff --git a/benchmarks/timeit.ipynb b/benchmarks/timeit.ipynb
index 38ddd69..ba50954 100644
--- a/benchmarks/timeit.ipynb
+++ b/benchmarks/timeit.ipynb
@@ -1,5 +1,12 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Imports"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
@@ -7,7 +14,12 @@
    "outputs": [],
    "source": [
     "import fast_stats\n",
-    "from sklearn.metrics import precision_score, recall_score, f1_score\n",
+    "from sklearn.metrics import (\n",
+    "    precision_score, \n",
+    "    recall_score, \n",
+    "    f1_score, \n",
+    "    confusion_matrix\n",
+    ")\n",
     "import numpy as np"
    ]
   },
@@ -15,7 +27,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Compared to scikit-learn"
+    "### Settings"
    ]
   },
@@ -24,146 +36,163 @@
   {
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
-    "pred = np.random.randint(0, 2, (10, 512, 512)).flatten()\n",
-    "actual = np.random.randint(0, 2, (10, 512, 512)).flatten()"
+    "SIZE = (10, 512, 512)\n",
+    "NUM_CATS = 20"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Binary compared to scikit-learn"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_true = np.random.randint(0, 2, SIZE).astype(bool).flatten()\n",
+    "y_pred = np.random.randint(0, 2, SIZE).astype(bool).flatten()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "606 ms ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "577 ms ± 3.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = precision_score(actual, pred)"
+    "_ = precision_score(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "608 ms ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "577 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = recall_score(actual, pred)"
+    "_ = recall_score(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "601 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "579 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = f1_score(actual, pred)"
+    "_ = f1_score(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# don't actually need to flatten them for fast-stats\n",
-    "pred = np.random.randint(0, 2, (10, 512, 512))\n",
-    "actual = np.random.randint(0, 2, (10, 512, 512))"
+    "# don't need to flatten them for fast-stats\n",
+    "y_true = np.random.randint(0, 2, SIZE).astype(bool)\n",
+    "y_pred = np.random.randint(0, 2, SIZE).astype(bool)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "7.79 ms ± 49.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "3.24 ms ± 46.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = fast_stats.binary_precision(actual, pred)"
+    "_ = fast_stats.binary_precision(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "7.27 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "3.12 ms ± 73 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = fast_stats.binary_recall(actual, pred)"
+    "_ = fast_stats.binary_recall(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "9.89 ms ± 239 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "4.43 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = fast_stats.binary_f1_score(actual, pred)"
+    "_ = fast_stats.binary_f1_score(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
     "assert np.allclose(\n",
-    "    fast_stats.binary_precision(actual.flatten(), pred.flatten()),\n",
-    "    precision_score(actual.flatten(), pred.flatten())\n",
+    "    fast_stats.binary_precision(y_true, y_pred),\n",
+    "    precision_score(y_true.flatten(), y_pred.flatten())\n",
     ")\n",
     "assert np.allclose(\n",
-    "    fast_stats.binary_recall(actual.flatten(), pred.flatten()),\n",
-    "    recall_score(actual.flatten(), pred.flatten())\n",
+    "    fast_stats.binary_recall(y_true, y_pred),\n",
+    "    recall_score(y_true.flatten(), y_pred.flatten())\n",
     ")\n",
     "assert np.allclose(\n",
-    "    fast_stats.binary_f1_score(actual.flatten(), pred.flatten()),\n",
-    "    f1_score(actual.flatten(), pred.flatten())\n",
+    "    fast_stats.binary_f1_score(y_true, y_pred),\n",
+    "    f1_score(y_true.flatten(), y_pred.flatten())\n",
     ")"
    ]
   },
@@ -171,52 +200,223 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Compared to numpy"
+    "### Binary compared to numpy"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
-    "pred, actual = pred.astype(bool), actual.astype(bool)"
+    "y_true, y_pred = y_true.astype(bool), y_pred.astype(bool)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "3.14 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "3.05 ms ± 27.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
     "%%timeit\n",
-    "_ = fast_stats.binary_precision(pred, actual)"
+    "_ = fast_stats.binary_precision(y_true, y_pred)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "5.75 ms ± 53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "_ = np.logical_and(y_true, y_pred).sum() / y_pred.sum()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Confusion matrix compared to sklearn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_true = np.random.randint(0, NUM_CATS, SIZE).flatten()\n",
+    "y_pred = np.random.randint(0, NUM_CATS, SIZE).flatten()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "336 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "_ = confusion_matrix(y_true, y_pred)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "176 ms ± 958 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "# while labels is an optional argument, providing\n",
+    "# labels leads to a significant speedup\n",
+    "# since they will not have to be inferred\n",
+    "_ = confusion_matrix(y_true, y_pred, labels = list(range(NUM_CATS)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# don't need to flatten them for fast-stats,\n",
+    "# which is another source of speedup depending on the use case\n",
+    "y_true = np.random.randint(0, NUM_CATS, SIZE)\n",
+    "y_pred = np.random.randint(0, NUM_CATS, SIZE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "132 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "_ = fast_stats.confusion_matrix(y_true, y_pred)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "67.9 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "# while labels is an optional argument, providing\n",
+    "# labels leads to a significant speedup\n",
+    "# since they will not have to be inferred\n",
+    "_ = fast_stats.confusion_matrix(y_true, y_pred, labels=list(range(NUM_CATS)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert np.allclose(\n",
+    "    confusion_matrix(y_true.flatten(), y_pred.flatten(), labels = list(range(NUM_CATS))),\n",
+    "    fast_stats.confusion_matrix(y_true, y_pred)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Effect of flattening or reshaping for scikit-learn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mat = np.random.randint(0, NUM_CATS, SIZE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
+      "1.75 ms ± 38.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
+    "_ = mat.flatten()"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-      "5.53 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "316 ns ± 1.21 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
-    "_ =np.logical_and(actual, pred).sum() / pred.sum()"
+    "_ = mat.reshape(-1)"
   ]
  }
 ],
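The last two cells explain most of the flatten-vs-not gap: `ndarray.flatten()` always copies the buffer, while `ndarray.reshape(-1)` returns a view whenever the data is contiguous. A quick check of that (standard numpy behavior, independent of this patch):

import numpy as np

mat = np.random.randint(0, 20, (10, 512, 512))

assert mat.reshape(-1).base is mat  # view onto the same buffer: no copy
assert np.ravel(mat).base is mat    # ravel() also avoids the copy here
assert mat.flatten().base is None   # flatten(): a fresh copy every call

So the roughly 5,000x difference between the 1.75 ms and 316 ns cells is the cost of copying ~2.6 million int64s versus creating a view.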
diff --git a/fast_stats/__init__.py b/fast_stats/__init__.py
index 683633a..fea71b7 100644
--- a/fast_stats/__init__.py
+++ b/fast_stats/__init__.py
@@ -1,7 +1,9 @@
-from .fast_stats import (
+from ._fast_stats_ext import (
     _binary_f1_score_reqs,
     _binary_precision_reqs,
     _binary_recall_reqs,
-    _tp_fp_fn_tn,
+    _confusion_matrix,
+    _unique,
 )
-from .stats import binary_f1_score, binary_precision, binary_recall
+from .binary import binary_f1_score, binary_precision, binary_recall
+from .confusion_matrix import confusion_matrix
diff --git a/fast_stats/stats.py b/fast_stats/binary.py
similarity index 76%
rename from fast_stats/stats.py
rename to fast_stats/binary.py
index 7f7985b..2b52d7d 100644
--- a/fast_stats/stats.py
+++ b/fast_stats/binary.py
@@ -2,15 +2,12 @@
 
 import numpy as np
 
-from .fast_stats import (
+from ._fast_stats_ext import (
     _binary_f1_score_reqs,
     _binary_precision_reqs,
     _binary_recall_reqs,
 )
 
-# from math import isnan # for Rust returning float nan
-
-
 Result = Union[None, float]
 
@@ -36,11 +33,14 @@ def binary_precision(
     y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"
 ) -> Result:
     assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
-    assert isinstance(y_pred, np.ndarray) and isinstance(
-        y_true, np.ndarray
+    assert all(
+        [
+            isinstance(y_pred, np.ndarray),
+            isinstance(y_true, np.ndarray),
+        ]
     ), "y_true and y_pred must be numpy arrays"
 
-    tp, tp_fp = _binary_precision_reqs(y_true, y_pred)
+    tp, tp_fp, _ = _binary_precision_reqs(y_true, y_pred)
     return _precision(tp, tp_fp, zero_division)
 
@@ -48,20 +48,26 @@ def binary_recall(
     y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"
 ) -> Result:
     assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
-    assert isinstance(y_pred, np.ndarray) and isinstance(
-        y_true, np.ndarray
+    assert all(
+        [
+            isinstance(y_pred, np.ndarray),
+            isinstance(y_true, np.ndarray),
+        ]
     ), "y_true and y_pred must be numpy arrays"
 
-    tp, tp_fn = _binary_recall_reqs(y_true, y_pred)
+    tp, tp_fn, _ = _binary_recall_reqs(y_true, y_pred)
     return _recall(tp, tp_fn, zero_division)
 
 
 def binary_f1_score(
     y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none"
-):
+) -> Result:
     assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
-    assert isinstance(y_pred, np.ndarray) and isinstance(
-        y_true, np.ndarray
+    assert all(
+        [
+            isinstance(y_pred, np.ndarray),
+            isinstance(y_true, np.ndarray),
+        ]
     ), "y_true and y_pred must be numpy arrays"
 
     tp, tp_fp, tp_fn = _binary_f1_score_reqs(y_true, y_pred)
diff --git a/fast_stats/confusion_matrix.py b/fast_stats/confusion_matrix.py
new file mode 100644
index 0000000..adb0a20
--- /dev/null
+++ b/fast_stats/confusion_matrix.py
@@ -0,0 +1,22 @@
+from typing import List, Optional
+
+import numpy as np
+
+from ._fast_stats_ext import _confusion_matrix, _unique
+
+
+def confusion_matrix(
+    y_true: np.ndarray, y_pred: np.ndarray, labels: Optional[List] = None
+) -> np.ndarray:
+    assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
+    assert all(
+        [
+            isinstance(y_pred, np.ndarray),
+            isinstance(y_true, np.ndarray),
+        ]
+    ), "y_true and y_pred must be numpy arrays"
+
+    if labels is None:
+        labels = sorted(list(_unique(np.concatenate([y_true, y_pred]))))
+
+    return _confusion_matrix(y_true, y_pred, labels)
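A usage sketch of the Python layer wired up above (assumes the compiled fast_stats._fast_stats_ext extension is available, e.g. after `maturin develop`; shapes and values are illustrative):

import numpy as np
import fast_stats

y_true = np.random.randint(0, 2, (4, 128, 128)).astype(bool)
y_pred = np.random.randint(0, 2, (4, 128, 128)).astype(bool)

# binary metrics accept any array shape; no flattening required
print(fast_stats.binary_precision(y_true, y_pred))                     # float, or None when undefined
print(fast_stats.binary_recall(y_true, y_pred, zero_division="zero"))  # 0.0 instead of None
print(fast_stats.binary_f1_score(y_true, y_pred))

# multiclass confusion matrix; labels are inferred via _unique when omitted
cats_true = np.random.randint(0, 5, (100,))
cats_pred = np.random.randint(0, 5, (100,))
print(fast_stats.confusion_matrix(cats_true, cats_pred))               # 5x5 ndarray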
diff --git a/pyproject.toml b/pyproject.toml
index 8939f87..bfaf880 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "fast-stats"
-version = "0.0.3"
+version = "0.0.4"
 description = "A fast and simple library for calculating basic statistics"
 readme = "README.md"
 license = {text="Apache 2.0"}
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..9f85dbd
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+testpaths = tests
+addopts = -v -s --cov=fast_stats --cov-fail-under=90 --cov-report=term-missing
\ No newline at end of file
diff --git a/src/binary.rs b/src/binary.rs
new file mode 100644
index 0000000..dc9f898
--- /dev/null
+++ b/src/binary.rs
@@ -0,0 +1,383 @@
+use numpy::*;
+use pyo3::{exceptions, prelude::*};
+
+/// Binary precision computational requirements
+#[pyfunction]
+#[pyo3(name = "_binary_precision_reqs")]
+#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
+pub fn py_binary_precision_reqs(
+    _py: Python<'_>,
+    actual: &PyAny,
+    pred: &PyAny,
+) -> PyResult<(i128, i128, i128)> {
+    dispatch(_py, "precision", actual, pred)
+}
+
+/// Binary recall computational requirements
+#[pyfunction]
+#[pyo3(name = "_binary_recall_reqs")]
+#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
+pub fn py_binary_recall_reqs(
+    _py: Python<'_>,
+    actual: &PyAny,
+    pred: &PyAny,
+) -> PyResult<(i128, i128, i128)> {
+    dispatch(_py, "recall", actual, pred)
+}
+
+/// Binary f1 computational requirements
+#[pyfunction]
+#[pyo3(name = "_binary_f1_score_reqs")]
+#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
+pub fn py_binary_f1_score_reqs(
+    _py: Python<'_>,
+    actual: &PyAny,
+    pred: &PyAny,
+) -> PyResult<(i128, i128, i128)> {
+    dispatch(_py, "f1", actual, pred)
+}
+
+fn custom_sum<T>(arr: ndarray::ArrayD<T>) -> i128
+where
+    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
+{
+    let mut sum = 0;
+    for row in arr.rows() {
+        sum = sum + row.iter().fold(0, |acc, elt| acc + elt.clone().into());
+    }
+    sum
+}
+
+fn binary_precision_reqs<T>(
+    actual: ndarray::ArrayD<T>,
+    pred: ndarray::ArrayD<T>,
+) -> (i128, i128, i128)
+where
+    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
+{
+    // TP, TP + FP
+    (custom_sum(actual * &pred), custom_sum(pred), 0)
+}
+
+fn binary_recall_reqs<T>(actual: ndarray::ArrayD<T>, pred: ndarray::ArrayD<T>) -> (i128, i128, i128)
+where
+    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
+{
+    // TP, TP + FN
+    (custom_sum(&actual * pred), custom_sum(actual), 0)
+}
+
+fn binary_f1_score_reqs<T>(
+    actual: ndarray::ArrayD<T>,
+    pred: ndarray::ArrayD<T>,
+) -> (i128, i128, i128)
+where
+    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
+{
+    // TP, TP + FP, TP + FN
+    (
+        custom_sum(&actual * &pred),
+        custom_sum(pred),
+        custom_sum(actual),
+    )
+}
+
+/// dispatching
+fn dispatch(
+    _py: Python<'_>,
+    stat: &str,
+    actual: &PyAny,
+    pred: &PyAny,
+) -> PyResult<(i128, i128, i128)> {
+    // bool
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<bool>>(),
+        pred.extract::<PyReadonlyArrayDyn<bool>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<u8>(
+                    i.to_owned_array().mapv(|e| e as u8),
+                    j.to_owned_array().mapv(|e| e as u8),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<u8>(
+                    i.to_owned_array().mapv(|e| e as u8),
+                    j.to_owned_array().mapv(|e| e as u8),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<u8>(
+                    i.to_owned_array().mapv(|e| e as u8),
+                    j.to_owned_array().mapv(|e| e as u8),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // i8
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<i8>>(),
+        pred.extract::<PyReadonlyArrayDyn<i8>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<i8>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<i8>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<i8>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // i16
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<i16>>(),
+        pred.extract::<PyReadonlyArrayDyn<i16>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<i16>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<i16>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<i16>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // i32
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<i32>>(),
+        pred.extract::<PyReadonlyArrayDyn<i32>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<i32>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<i32>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<i32>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // i64
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<i64>>(),
+        pred.extract::<PyReadonlyArrayDyn<i64>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<i64>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<i64>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<i64>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // u8
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<u8>>(),
+        pred.extract::<PyReadonlyArrayDyn<u8>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<u8>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<u8>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<u8>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // u16
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<u16>>(),
+        pred.extract::<PyReadonlyArrayDyn<u16>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<u16>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<u16>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<u16>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // u32
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<u32>>(),
+        pred.extract::<PyReadonlyArrayDyn<u32>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<u32>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<u32>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<u32>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    // u64
+    if let (Ok(i), Ok(j)) = (
+        actual.extract::<PyReadonlyArrayDyn<u64>>(),
+        pred.extract::<PyReadonlyArrayDyn<u64>>(),
+    ) {
+        match stat {
+            "recall" => {
+                return Ok(binary_recall_reqs::<u64>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "precision" => {
+                return Ok(binary_precision_reqs::<u64>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            "f1" => {
+                return Ok(binary_f1_score_reqs::<u64>(
+                    i.to_owned_array(),
+                    j.to_owned_array(),
+                ));
+            }
+            _ => {
+                return Err(PyErr::new::<exceptions::PyTypeError, _>(
+                    "Internal Error: not implemented stat type",
+                ));
+            }
+        }
+    }
+
+    Err(PyErr::new::<exceptions::PyTypeError, _>(
+        "Unsupported numpy dtype",
+    ))
+}
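The nine dtype branches in dispatch above differ only in the element type, and the old src/lib.rs carried a "TODO macro this out" for the same pattern. One possible shape for that cleanup — a sketch only, not part of this patch, and `try_dispatch` is a hypothetical name:

// Hypothetical macro collapsing the per-dtype branches; bool would still
// need its own branch because of the `mapv(|e| e as u8)` conversion.
macro_rules! try_dispatch {
    ($actual:expr, $pred:expr, $stat:expr, $ty:ty) => {
        if let (Ok(i), Ok(j)) = (
            $actual.extract::<PyReadonlyArrayDyn<$ty>>(),
            $pred.extract::<PyReadonlyArrayDyn<$ty>>(),
        ) {
            return match $stat {
                "recall" => Ok(binary_recall_reqs::<$ty>(i.to_owned_array(), j.to_owned_array())),
                "precision" => Ok(binary_precision_reqs::<$ty>(i.to_owned_array(), j.to_owned_array())),
                "f1" => Ok(binary_f1_score_reqs::<$ty>(i.to_owned_array(), j.to_owned_array())),
                _ => Err(PyErr::new::<exceptions::PyTypeError, _>(
                    "Internal Error: not implemented stat type",
                )),
            };
        }
    };
}

// dispatch would then shrink to one line per dtype:
// try_dispatch!(actual, pred, stat, i8);
// try_dispatch!(actual, pred, stat, i16);
// ... and so on through u64.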
diff --git a/src/cm.rs b/src/cm.rs
new file mode 100644
index 0000000..fcd813c
--- /dev/null
+++ b/src/cm.rs
@@ -0,0 +1,131 @@
+use numpy::*;
+use pyo3::prelude::*;
+use std::{collections::HashMap, iter::zip};
+
+/// Confusion Matrix
+#[pyfunction]
+#[pyo3(name = "_confusion_matrix")]
+#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, labels: List[int], /)")]
+pub fn py_confusion_matrix<'a>(
+    py: Python<'a>,
+    actual: &PyAny,
+    pred: &PyAny,
+    labels: &PyAny,
+) -> PyResult<&'a PyArray2<i64>> {
+    dispatch(py, actual, pred, labels)
+}
+
+pub fn confusion_matrix<'a, T>(
+    py: Python<'a>,
+    actual: &PyArrayDyn<T>,
+    pred: &PyArrayDyn<T>,
+    labels: Vec<T>,
+) -> &'a PyArray2<i64>
+where
+    T: Copy + Clone + numpy::Element + std::hash::Hash + std::cmp::Eq,
+{
+    let mut cm = ndarray::Array2::<i64>::from_elem((labels.len(), labels.len()), 0);
+    let idx_map: HashMap<T, usize> =
+        HashMap::from_iter(labels.iter().enumerate().map(|(x, y)| (*y, x)));
+
+    for (y_pred, y_actual) in zip(pred.to_owned_array().iter(), actual.to_owned_array().iter()) {
+        if let (Some(ix1), Some(ix2)) = (idx_map.get(y_actual), idx_map.get(y_pred)) {
+            *cm.get_mut((*ix1, *ix2)).unwrap() = *cm.get_mut((*ix1, *ix2)).unwrap() + 1;
+        }
+    }
+
+    return PyArray2::from_array(py, &cm);
+}
+
+/// dispatching
+fn dispatch<'a>(
+    py: Python<'a>,
+    actual: &PyAny,
+    pred: &PyAny,
+    labels: &PyAny,
+) -> PyResult<&'a PyArray2<i64>> {
+    // bool
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<bool>>(),
+        pred.extract::<&PyArrayDyn<bool>>(),
+        labels.extract::<Vec<bool>>(),
+    ) {
+        return Ok(confusion_matrix::<bool>(py, &i, &j, l));
+    }
+
+    // i8
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<i8>>(),
+        pred.extract::<&PyArrayDyn<i8>>(),
+        labels.extract::<Vec<i8>>(),
+    ) {
+        return Ok(confusion_matrix::<i8>(py, &i, &j, l));
+    }
+
+    // i16
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<i16>>(),
+        pred.extract::<&PyArrayDyn<i16>>(),
+        labels.extract::<Vec<i16>>(),
+    ) {
+        return Ok(confusion_matrix::<i16>(py, &i, &j, l));
+    }
+
+    // i32
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<i32>>(),
+        pred.extract::<&PyArrayDyn<i32>>(),
+        labels.extract::<Vec<i32>>(),
+    ) {
+        return Ok(confusion_matrix::<i32>(py, &i, &j, l));
+    }
+
+    // i64
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<i64>>(),
+        pred.extract::<&PyArrayDyn<i64>>(),
+        labels.extract::<Vec<i64>>(),
+    ) {
+        return Ok(confusion_matrix::<i64>(py, &i, &j, l));
+    }
+
+    // u8
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<u8>>(),
+        pred.extract::<&PyArrayDyn<u8>>(),
+        labels.extract::<Vec<u8>>(),
+    ) {
+        return Ok(confusion_matrix::<u8>(py, &i, &j, l));
+    }
+
+    // u16
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<u16>>(),
+        pred.extract::<&PyArrayDyn<u16>>(),
+        labels.extract::<Vec<u16>>(),
+    ) {
+        return Ok(confusion_matrix::<u16>(py, &i, &j, l));
+    }
+
+    // u32
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<u32>>(),
+        pred.extract::<&PyArrayDyn<u32>>(),
+        labels.extract::<Vec<u32>>(),
+    ) {
+        return Ok(confusion_matrix::<u32>(py, &i, &j, l));
+    }
+
+    // u64
+    if let (Ok(i), Ok(j), Ok(l)) = (
+        actual.extract::<&PyArrayDyn<u64>>(),
+        pred.extract::<&PyArrayDyn<u64>>(),
+        labels.extract::<Vec<u64>>(),
+    ) {
+        return Ok(confusion_matrix::<u64>(py, &i, &j, l));
+    }
+
+    Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
+        "Unsupported numpy dtype",
+    ))
+}
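One behavioral consequence of the `if let (Some(ix1), Some(ix2))` guard in confusion_matrix above: pairs whose true or predicted value is not in `labels` are silently skipped rather than raising. A small illustration (assumes the extension is built):

import numpy as np
import fast_stats

y_true = np.array([0, 1, 2, 2])
y_pred = np.array([0, 1, 2, 9])  # 9 is not in labels below

# the (2, 9) pair contributes to no cell; the result is still 3x3
print(fast_stats.confusion_matrix(y_true, y_pred, labels=[0, 1, 2]))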
diff --git a/src/lib.rs b/src/lib.rs
index b2fd95b..8ae893b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,398 +1,22 @@
-use numpy::*;
-use pyo3::{exceptions, prelude::*};
-use std::iter::zip;
+use pyo3::prelude::*;
 
-fn sum<T>(arr: ndarray::ArrayD<T>) -> i128
-where
-    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
-{
-    let mut sum = 0;
-    for row in arr.rows() {
-        sum = sum + row.iter().fold(0, |acc, elt| acc + elt.clone().into());
-    }
-    sum
-}
-
-/// Get tp, fp, fn, tn counts by looping
-#[pyfunction]
-#[pyo3(name = "_tp_fp_fn_tn")]
-#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
-fn tp_fp_fn_tn(
-    _py: Python<'_>,
-    actual: &PyArrayDyn<i64>,
-    pred: &PyArrayDyn<i64>,
-) -> (usize, usize, usize, usize) {
-    let mut _tp = 0;
-    let mut _fp: usize = 0;
-    let mut _fn: usize = 0;
-    let mut _tn: usize = 0;
-
-    for (y_pred, y_actual) in zip(
-        pred.readonly().as_array().iter(),
-        actual.readonly().as_array().iter(),
-    ) {
-        if *y_pred == 1 && *y_actual == 1 {
-            _tp = _tp + 1;
-        } else if *y_pred == 1 && *y_actual == 0 {
-            _fp = _fp + 1;
-        } else if *y_pred == 0 && *y_actual == 1 {
-            _fn = _fn + 1;
-        } else {
-            _tn = _tn + 1;
-        }
-    }
-    (_tp, _fp, _fn, _tn)
-}
-
-/// Array-based binary precision req calculating
-#[pyfunction]
-#[pyo3(name = "_binary_precision_reqs")]
-#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
-fn py_binary_precision_reqs(
-    _py: Python<'_>,
-    actual: &PyAny,
-    pred: &PyAny,
-) -> PyResult<(i128, i128)> {
-    // TODO macro this out
-    // bool
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<bool>>(),
-        pred.extract::<PyReadonlyArrayDyn<bool>>(),
-    ) {
-        return Ok(binary_precision_reqs::<u8>(
-            i.to_owned_array().mapv(|e| e as u8),
-            j.to_owned_array().mapv(|e| e as u8),
-        ));
-    }
-    // i8
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i8>>(),
-        pred.extract::<PyReadonlyArrayDyn<i8>>(),
-    ) {
-        return Ok(binary_precision_reqs::<i8>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i16
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i16>>(),
-        pred.extract::<PyReadonlyArrayDyn<i16>>(),
-    ) {
-        return Ok(binary_precision_reqs::<i16>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i32
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i32>>(),
-        pred.extract::<PyReadonlyArrayDyn<i32>>(),
-    ) {
-        return Ok(binary_precision_reqs::<i32>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i64
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i64>>(),
-        pred.extract::<PyReadonlyArrayDyn<i64>>(),
-    ) {
-        return Ok(binary_precision_reqs::<i64>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u8
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u8>>(),
-        pred.extract::<PyReadonlyArrayDyn<u8>>(),
-    ) {
-        return Ok(binary_precision_reqs::<u8>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u16
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u16>>(),
-        pred.extract::<PyReadonlyArrayDyn<u16>>(),
-    ) {
-        return Ok(binary_precision_reqs::<u16>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u32
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u32>>(),
-        pred.extract::<PyReadonlyArrayDyn<u32>>(),
-    ) {
-        return Ok(binary_precision_reqs::<u32>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u64
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u64>>(),
-        pred.extract::<PyReadonlyArrayDyn<u64>>(),
-    ) {
-        return Ok(binary_precision_reqs::<u64>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-
-    Err(PyErr::new::<exceptions::PyTypeError, _>(
-        "Unsupport numpy dtype",
-    ))
-}
-
-fn binary_precision_reqs<T>(actual: ndarray::ArrayD<T>, pred: ndarray::ArrayD<T>) -> (i128, i128)
-where
-    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
-{
-    // TP, TP + FP
-    (sum(actual * &pred), sum(pred))
-}
-
-/// Array-based binary recall req calculating
-#[pyfunction]
-#[pyo3(name = "_binary_recall_reqs")]
-#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
-fn py_binary_recall_reqs(_py: Python<'_>, actual: &PyAny, pred: &PyAny) -> PyResult<(i128, i128)> {
-    // TODO macro this out
-    // bool
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<bool>>(),
-        pred.extract::<PyReadonlyArrayDyn<bool>>(),
-    ) {
-        return Ok(binary_recall_reqs::<u8>(
-            i.to_owned_array().mapv(|e| e as u8),
-            j.to_owned_array().mapv(|e| e as u8),
-        ));
-    }
-    // i8
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i8>>(),
-        pred.extract::<PyReadonlyArrayDyn<i8>>(),
-    ) {
-        return Ok(binary_recall_reqs::<i8>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i16
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i16>>(),
-        pred.extract::<PyReadonlyArrayDyn<i16>>(),
-    ) {
-        return Ok(binary_recall_reqs::<i16>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i32
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i32>>(),
-        pred.extract::<PyReadonlyArrayDyn<i32>>(),
-    ) {
-        return Ok(binary_recall_reqs::<i32>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i64
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i64>>(),
-        pred.extract::<PyReadonlyArrayDyn<i64>>(),
-    ) {
-        return Ok(binary_recall_reqs::<i64>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u8
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u8>>(),
-        pred.extract::<PyReadonlyArrayDyn<u8>>(),
-    ) {
-        return Ok(binary_recall_reqs::<u8>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u16
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u16>>(),
-        pred.extract::<PyReadonlyArrayDyn<u16>>(),
-    ) {
-        return Ok(binary_recall_reqs::<u16>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u32
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u32>>(),
-        pred.extract::<PyReadonlyArrayDyn<u32>>(),
-    ) {
-        return Ok(binary_recall_reqs::<u32>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u64
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u64>>(),
-        pred.extract::<PyReadonlyArrayDyn<u64>>(),
-    ) {
-        return Ok(binary_recall_reqs::<u64>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-
-    Err(PyErr::new::<exceptions::PyTypeError, _>(
-        "Unsupport numpy dtype",
-    ))
-}
-
-fn binary_recall_reqs<T>(actual: ndarray::ArrayD<T>, pred: ndarray::ArrayD<T>) -> (i128, i128)
-where
-    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
-{
-    // TP, TP + FN
-    (sum(&actual * pred), sum(actual))
-}
+mod binary;
+mod cm;
+mod utils;
-
-/// Array-based binary recall req calculating
-#[pyfunction]
-#[pyo3(name = "_binary_f1_score_reqs")]
-#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, /)")]
-fn py_binary_f1_score_reqs(
-    _py: Python<'_>,
-    actual: &PyAny,
-    pred: &PyAny,
-) -> PyResult<(i128, i128, i128)> {
-    // TODO macro this out
-
-    // bool
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<bool>>(),
-        pred.extract::<PyReadonlyArrayDyn<bool>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<u8>(
-            i.to_owned_array().mapv(|e| e as u8),
-            j.to_owned_array().mapv(|e| e as u8),
-        ));
-    }
-
-    // i8
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i8>>(),
-        pred.extract::<PyReadonlyArrayDyn<i8>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<i8>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i16
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i16>>(),
-        pred.extract::<PyReadonlyArrayDyn<i16>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<i16>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i32
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i32>>(),
-        pred.extract::<PyReadonlyArrayDyn<i32>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<i32>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // i64
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<i64>>(),
-        pred.extract::<PyReadonlyArrayDyn<i64>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<i64>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u8
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u8>>(),
-        pred.extract::<PyReadonlyArrayDyn<u8>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<u8>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u16
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u16>>(),
-        pred.extract::<PyReadonlyArrayDyn<u16>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<u16>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u32
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u32>>(),
-        pred.extract::<PyReadonlyArrayDyn<u32>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<u32>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-    // u64
-    if let (Ok(i), Ok(j)) = (
-        actual.extract::<PyReadonlyArrayDyn<u64>>(),
-        pred.extract::<PyReadonlyArrayDyn<u64>>(),
-    ) {
-        return Ok(binary_f1_score_reqs::<u64>(
-            i.to_owned_array(),
-            j.to_owned_array(),
-        ));
-    }
-
-    Err(PyErr::new::<exceptions::PyTypeError, _>(
-        "Unsupport numpy dtype",
-    ))
-}
-
-fn binary_f1_score_reqs<T>(
-    actual: ndarray::ArrayD<T>,
-    pred: ndarray::ArrayD<T>,
-) -> (i128, i128, i128)
-where
-    T: Clone + std::ops::Add + num_traits::Num + Into<i128>,
-{
-    // TP, TP + FP, TP + FN
-    (sum(&actual * &pred), sum(pred), sum(actual))
-}
+/// A Python module implemented in Rust.
+#[pymodule]
+fn _fast_stats_ext(_py: Python, m: &PyModule) -> PyResult<()> {
+    // cm
+    m.add_function(wrap_pyfunction!(cm::py_confusion_matrix, m)?)?;
+
+    // utils
+    m.add_function(wrap_pyfunction!(utils::py_unique, m)?)?;
+
+    // binary calcs
+    m.add_function(wrap_pyfunction!(binary::py_binary_precision_reqs, m)?)?;
+    m.add_function(wrap_pyfunction!(binary::py_binary_recall_reqs, m)?)?;
+    m.add_function(wrap_pyfunction!(binary::py_binary_f1_score_reqs, m)?)?;
+
-/// A Python module implemented in Rust.
-#[pymodule]
-fn fast_stats(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(tp_fp_fn_tn, m)?)?;
-    m.add_function(wrap_pyfunction!(py_binary_precision_reqs, m)?)?;
-    m.add_function(wrap_pyfunction!(py_binary_recall_reqs, m)?)?;
-    m.add_function(wrap_pyfunction!(py_binary_f1_score_reqs, m)?)?;
     Ok(())
 }
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..0bec442
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,80 @@
+use numpy::*;
+use pyo3::{prelude::*, types::PySet};
+use std::collections::HashSet;
+
+/// unique
+#[pyfunction]
+#[pyo3(name = "_unique")]
+#[pyo3(text_signature = "(arr: np.ndarray, /)")]
+pub fn py_unique<'a>(_py: Python<'a>, arr: &PyAny) -> PyResult<&'a PySet> {
+    dispatch(_py, arr)
+}
+
+// &PyArrayDyn?
+fn unique<'a, T>(py: Python<'a>, arr: numpy::PyReadonlyArrayDyn<T>) -> PyResult<&'a PySet>
+where
+    T: Clone + numpy::Element + std::hash::Hash + std::cmp::Eq + pyo3::ToPyObject,
+{
+    let mut track = HashSet::<T>::new();
+    let mut ret: Vec<T> = vec![];
+    for val in arr.readonly().as_array().iter() {
+        if !track.contains(val) {
+            track.insert(val.clone());
+            ret.push(val.clone());
+        }
+    }
+
+    PySet::new(py, ret.as_slice())
+}
+
+/// dispatching
+fn dispatch<'a>(py: Python<'a>, actual: &PyAny) -> PyResult<&'a PySet> {
+    // bool
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<bool>>() {
+        return unique::<bool>(py, i);
+    }
+
+    // i8
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<i8>>() {
+        return unique::<i8>(py, i);
+    }
+
+    // i16
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<i16>>() {
+        return unique::<i16>(py, i);
+    }
+
+    // i32
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<i32>>() {
+        return unique::<i32>(py, i);
+    }
+
+    // i64
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<i64>>() {
+        return unique::<i64>(py, i);
+    }
+
+    // u8
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<u8>>() {
+        return unique::<u8>(py, i);
+    }
+
+    // u16
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<u16>>() {
+        return unique::<u16>(py, i);
+    }
+
+    // u32
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<u32>>() {
+        return unique::<u32>(py, i);
+    }
+
+    // u64
+    if let Ok(i) = actual.extract::<PyReadonlyArrayDyn<u64>>() {
+        return unique::<u64>(py, i);
+    }
+
+    Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
+        "Unsupported numpy dtype",
+    ))
+}
diff --git a/tests/test_stats.py b/tests/test_binary.py
similarity index 71%
rename from tests/test_stats.py
rename to tests/test_binary.py
index 5eac344..bdd2c90 100644
--- a/tests/test_stats.py
+++ b/tests/test_binary.py
@@ -104,3 +104,44 @@ def test_precision(y_true, y_pred, zero_division, expected):
 )
 def test_recall(y_true, y_pred, zero_division, expected):
     assert fast_stats.binary_recall(y_true, y_pred, zero_division) == expected
+
+
+@pytest.mark.parametrize(
+    "y_true,y_pred,zero_division,expected",
+    [
+        (
+            np.ones(4, dtype=np.uint64),
+            np.zeros(4, dtype=np.uint64),
+            "zero",
+            0.0,
+        ),  # all FN
+        (
+            np.ones(4, dtype=np.uint64),
+            np.ones(4, dtype=np.uint64),
+            "zero",
+            1.0,
+        ),  # all TP
+        (
+            np.zeros(4, dtype=np.uint64),
+            np.zeros(4, dtype=np.uint64),
+            "zero",
+            0.0,
+        ),  # No TP & No FP, & No FN
+        (
+            np.ones(4, dtype=np.uint64),
+            np.array([1, 0, 0, 0], dtype=np.uint64),
+            "none",
+            2 * (1 / 4 * 1.0) / (1 / 4 + 1.0),
+        ),  # 1 TP & 3 FN
+        (
+            np.array([1, 1, 0, 0], dtype=np.uint64),
+            np.array([0, 1, 1, 0], dtype=np.uint64),
+            "none",
+            0.5,
+        ),
+    ],
+)
+def test_f1(y_true, y_pred, zero_division, expected):
+    assert np.allclose(
+        fast_stats.binary_f1_score(y_true, y_pred, zero_division), expected
+    )
diff --git a/tests/test_cm.py b/tests/test_cm.py
new file mode 100644
index 0000000..d5a60af
--- /dev/null
+++ b/tests/test_cm.py
@@ -0,0 +1,56 @@
+import numpy as np
+import pytest
+
+from fast_stats import confusion_matrix
+
+
+@pytest.mark.parametrize(
+    "y_true,y_pred,expected",
+    [
+        (
+            np.ones(4, dtype=np.uint64),
+            np.ones(4, dtype=np.uint64),
+            np.array([[4]]),
+        ),  # all one value
+        (
+            np.array(
+                [
+                    [1, 2],
+                    [1, 2],
+                ]
+            ),
+            np.array(
+                [
+                    [1, 1],
+                    [2, 2],
+                ]
+            ),
+            np.ones((2, 2)),
+        ),  # 2x2
+        (
+            np.array(
+                [
+                    [1, 2, 3],
+                    [1, 2, 3],
+                    [1, 2, 3],
+                ]
+            ),
+            np.array(
+                [
+                    [1, 2, 1],
+                    [1, 1, 1],
+                    [1, 1, 3],
+                ]
+            ),
+            np.array(
+                [
+                    [3, 0, 0],
+                    [2, 1, 0],
+                    [2, 0, 1],
+                ]
+            ),
+        ),  # 3x3
+    ],
+)
+def test_confusion_matrix(y_true, y_pred, expected):
+    assert np.allclose(confusion_matrix(y_true, y_pred), expected)
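With the new pytest.ini, a bare `pytest` from the repository root picks up tests/ and enforces the coverage gate. The --cov flags require the pytest-cov plugin, and the extension must be compiled first — e.g. (command names assume maturin and pip are available):

pip install pytest pytest-cov
maturin develop --release
pytest  # verbose, with term-missing coverage report; fails under 90% coverage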