From 8c13e95a3bdd9ad31a66d459854513c9b28793ac Mon Sep 17 00:00:00 2001 From: Zachary Coleman <42484306+zachcoleman@users.noreply.github.com> Date: Sun, 29 May 2022 18:52:32 -0500 Subject: [PATCH] Multiclass Statistics Added (#12) * updating project metadata * fixing CI yaml * using venv * trying again * redoing CI * fixing tests * changing some settings * updates * fixing build * trying to fix this * fixing release * bumping version * better code organization * updates * adding initial cm impl * adding unique, dispatching pattern to Py objects, renaming ext * rustfmt * cm dispatched * rustfmt * tests and benchmarks added * bump version * 100% test coverage * updating readme * Threading enabled (#9) * bumping version * major refactor leveraging macros * bumping version and updating test * adding executed notebook * fixing performance w/ bool * multiclass implemented and tested ready for 1.0.0 --- .flake8 | 3 +- Cargo.toml | 2 +- README.md | 9 +- benchmarks/timeit.ipynb | 265 +++++++++++++++++++-- fast_stats/__init__.py | 8 +- fast_stats/binary.py | 72 +++++- fast_stats/confusion_matrix.py | 10 + fast_stats/multiclass.py | 192 +++++++++++++++ pyproject.toml | 2 +- src/cm.rs | 39 +-- src/lib.rs | 6 + src/multiclass.rs | 165 +++++++++++++ tests/conftest.py | 0 tests/test_multiclass.py | 421 +++++++++++++++++++++++++++++++++ 14 files changed, 1128 insertions(+), 66 deletions(-) create mode 100644 fast_stats/multiclass.py create mode 100644 src/multiclass.rs delete mode 100644 tests/conftest.py create mode 100644 tests/test_multiclass.py diff --git a/.flake8 b/.flake8 index 63d0365..c85af2c 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] max-line-length = 88 per-file-ignores= - **/__init__.py:F401 \ No newline at end of file + **/__init__.py:F401 + **/*.py:E731 diff --git a/Cargo.toml b/Cargo.toml index 34db3cb..65c4627 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fast-stats" -version = "0.0.6" +version = "1.0.0" edition = "2021" [lib] diff --git a/README.md b/README.md index 9161633..7ab19b6 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,17 @@ [![License](https://img.shields.io/badge/license-Apache2.0-green)](./LICENSE) # fast-stats -`fast-stats` is a fast and simple library for calculating basic statistics such as: precision, recall, and f1-score. The library also supports the calculation of multi-class confusion matrices. +`fast-stats` is a fast and simple library for calculating basic statistics such as: precision, recall, and f1-score. The library also supports the calculation of confusion matrices. For examples, please look at the `benchmarks/` folder. -The project was developed using the [maturin](https://maturin.rs) framework. - -This project is still in development. +The project was developed using the [maturin](https://maturin.rs) framework. ## Installation From PyPI: ```shell pip install fast-stats ``` -Build from source + +Build from source: ``` maturin build -r -i=path/to/python pip install .../fast-stats/target/wheels/.whl diff --git a/benchmarks/timeit.ipynb b/benchmarks/timeit.ipynb index a2a11b3..91afdf8 100644 --- a/benchmarks/timeit.ipynb +++ b/benchmarks/timeit.ipynb @@ -37,7 +37,7 @@ "outputs": [], "source": [ "SIZE = (10, 512, 512)\n", - "NUM_CATS = 20" + "NUM_CATS = 8" ] }, { @@ -66,7 +66,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "580 ms ± 6.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "580 ms ± 12.1 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], @@ -84,7 +84,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "581 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "589 ms ± 5.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -102,7 +102,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "579 ms ± 5.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "585 ms ± 6.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -131,7 +131,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.27 ms ± 45.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.51 ms ± 79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -149,7 +149,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.12 ms ± 46.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.19 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -167,7 +167,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4.35 ms ± 31.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "4.47 ms ± 65.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -221,7 +221,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "5.7 ms ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "5.67 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -239,7 +239,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.17 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.11 ms ± 35.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -252,7 +252,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Confusion matrix compared to sklearn" + "### Multiclass compared to sklearn" ] }, { @@ -260,6 +260,221 @@ "execution_count": 15, "metadata": {}, "outputs": [], + "source": [ + "y_true = np.random.randint(0, NUM_CATS, SIZE)\n", + "y_pred = np.random.randint(0, NUM_CATS, SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "671 ms ± 3.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "_ = precision_score(\n", + " y_true.reshape(-1), \n", + " y_pred.reshape(-1), \n", + " labels=list(range(NUM_CATS)), \n", + " average=None\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "67.7 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "tmp = fast_stats.precision(\n", + " y_true,\n", + " y_pred,\n", + " list(range(NUM_CATS)),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(\n", + " precision_score(\n", + " y_true.reshape(-1), \n", + " y_pred.reshape(-1), \n", + " labels=list(range(NUM_CATS)), \n", + " average=None\n", + " ),\n", + " fast_stats.precision(\n", + " y_true,\n", + " y_pred,\n", + " list(range(NUM_CATS)),\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "674 ms ± 3.44 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "_ = recall_score(\n", + " y_true.reshape(-1), \n", + " y_pred.reshape(-1), \n", + " labels=list(range(NUM_CATS)), \n", + " average=None\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "68.3 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "tmp = fast_stats.recall(\n", + " y_true,\n", + " y_pred,\n", + " np.array(list(range(NUM_CATS)), dtype=y_true.dtype)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(\n", + " recall_score(\n", + " y_true.reshape(-1), \n", + " y_pred.reshape(-1), \n", + " labels=list(range(NUM_CATS)), \n", + " average=None\n", + " ),\n", + " fast_stats.recall(\n", + " y_true,\n", + " y_pred,\n", + " list(range(NUM_CATS)),\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "679 ms ± 3.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "_ = f1_score(\n", + " y_true.reshape(-1), \n", + " y_pred.reshape(-1), \n", + " labels=list(range(NUM_CATS)), \n", + " average=None\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "66.5 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "tmp = fast_stats.f1_score(\n", + " y_true,\n", + " y_pred,\n", + " np.array(list(range(NUM_CATS)), dtype=y_true.dtype)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(\n", + " f1_score(\n", + " y_true.reshape(-1), \n", + " y_pred.reshape(-1), \n", + " labels=list(range(NUM_CATS)), \n", + " average=None\n", + " ),\n", + " fast_stats.f1_score(\n", + " y_true,\n", + " y_pred,\n", + " list(range(NUM_CATS)),\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Confusion matrix compared to sklearn" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], "source": [ "y_true = np.random.randint(0, NUM_CATS, SIZE).flatten()\n", "y_pred = np.random.randint(0, NUM_CATS, SIZE).flatten()" @@ -267,14 +482,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "316 ms ± 608 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "274 ms ± 663 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -285,14 +500,14 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "169 ms ± 470 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "149 ms ± 655 µs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" ] } ], @@ -306,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -318,14 +533,14 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "142 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "135 ms ± 824 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -336,14 +551,14 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "72 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "67.5 ms ± 560 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -357,7 +572,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -376,7 +591,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -385,14 +600,14 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1.69 ms ± 32.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + "1.71 ms ± 31.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -403,14 +618,14 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "326 ns ± 0.755 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" + "325 ns ± 1.53 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], diff --git a/fast_stats/__init__.py b/fast_stats/__init__.py index fea71b7..3c6e8be 100644 --- a/fast_stats/__init__.py +++ b/fast_stats/__init__.py @@ -1,9 +1,3 @@ -from ._fast_stats_ext import ( - _binary_f1_score_reqs, - _binary_precision_reqs, - _binary_recall_reqs, - _confusion_matrix, - _unique, -) from .binary import binary_f1_score, binary_precision, binary_recall from .confusion_matrix import confusion_matrix +from .multiclass import f1_score, precision, recall diff --git a/fast_stats/binary.py b/fast_stats/binary.py index 2b52d7d..cf3eecf 100644 --- a/fast_stats/binary.py +++ b/fast_stats/binary.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Union import numpy as np @@ -11,27 +12,47 @@ Result = Union[None, float] -def _precision(tp: int, tp_fp: int, zero_division: str = "none") -> Result: +class ZeroDivision(Enum): + ZERO = "zero" + NONE = "none" + + +def _precision( + tp: int, tp_fp: int, zero_division: ZeroDivision = ZeroDivision.NONE +) -> Result: if tp_fp == 0: - if zero_division == "none": + if zero_division == ZeroDivision.NONE: return None - elif zero_division == "zero": + elif zero_division == ZeroDivision.ZERO: return 0.0 return tp / tp_fp -def _recall(tp: int, tp_fn: int, zero_division: str = "none") -> Result: +def _recall( + tp: int, tp_fn: int, zero_division: ZeroDivision = ZeroDivision.NONE +) -> Result: if tp_fn == 0: - if zero_division == "none": + if zero_division == ZeroDivision.NONE: return None - elif zero_division == "zero": + elif zero_division == ZeroDivision.ZERO: return 0.0 return tp / tp_fn def binary_precision( - y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none" + y_true: np.ndarray, + y_pred: np.ndarray, + zero_division: ZeroDivision = ZeroDivision.NONE, ) -> Result: + """Binary calculation for precision + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + zero_division (str): determines how to handle division by zero + Returns: + Result: None or float depending on values and zero division + """ assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" assert all( [ @@ -39,14 +60,26 @@ def binary_precision( isinstance(y_true, np.ndarray), ] ), "y_true and y_pred must be numpy arrays" + zero_division = ZeroDivision(zero_division) tp, tp_fp, _ = _binary_precision_reqs(y_true, y_pred) return _precision(tp, tp_fp, zero_division) def binary_recall( - y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none" + y_true: np.ndarray, + y_pred: np.ndarray, + zero_division: ZeroDivision = ZeroDivision.NONE, ) -> Result: + """Binary calculation for recall + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + zero_division (str): determines how to handle division by zero + Returns: + Result: None or float depending on values and zero division + """ assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" assert all( [ @@ -54,14 +87,26 @@ def binary_recall( isinstance(y_true, np.ndarray), ] ), "y_true and y_pred must be numpy arrays" + zero_division = ZeroDivision(zero_division) tp, tp_fn, _ = _binary_recall_reqs(y_true, y_pred) return _recall(tp, tp_fn, zero_division) def binary_f1_score( - y_true: np.ndarray, y_pred: np.ndarray, zero_division: str = "none" + y_true: np.ndarray, + y_pred: np.ndarray, + zero_division: ZeroDivision = 
ZeroDivision.NONE, ) -> Result: + """Binary calculation for f1 score + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + zero_division (str): determines how to handle division by zero + Returns: + Result: None or float depending on values and zero division + """ assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" assert all( [ @@ -69,14 +114,17 @@ def binary_f1_score( isinstance(y_true, np.ndarray), ] ), "y_true and y_pred must be numpy arrays" + zero_division = ZeroDivision(zero_division) tp, tp_fp, tp_fn = _binary_f1_score_reqs(y_true, y_pred) - p, r = _precision(tp, tp_fp, "zero"), _recall(tp, tp_fn, "zero") + p, r = _precision(tp, tp_fp, zero_division.ZERO), _recall( + tp, tp_fn, zero_division.ZERO + ) if p + r == 0: - if zero_division == "none": + if zero_division == ZeroDivision.NONE: return None - elif zero_division == "zero": + elif zero_division == ZeroDivision.ZERO: return 0.0 return 2 * p * r / (p + r) diff --git a/fast_stats/confusion_matrix.py b/fast_stats/confusion_matrix.py index 72a2fe9..daa091b 100644 --- a/fast_stats/confusion_matrix.py +++ b/fast_stats/confusion_matrix.py @@ -8,6 +8,16 @@ def confusion_matrix( y_true: np.ndarray, y_pred: np.ndarray, labels: Union[List, np.ndarray] = None ) -> np.ndarray: + """Calculation of confusion matrix + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + labels (optional | list or np.ndarray): + labels to calculate confusion matrix for (must be bool or int types) + Returns: + confusion matrix (np.ndarray): 2D np.ndarray confusion matrix + """ assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" assert all( [ diff --git a/fast_stats/multiclass.py b/fast_stats/multiclass.py new file mode 100644 index 0000000..8984831 --- /dev/null +++ b/fast_stats/multiclass.py @@ -0,0 +1,192 @@ +from enum import Enum +from functools import partial +from typing import List, Union + +import numpy as np + +from ._fast_stats_ext import _f1_score, _precision, _recall, _unique + +Result = Union[None, float, np.ndarray] + + +class ZeroDivision(Enum): + ZERO = "zero" + NONE = "none" + + +class AverageType(Enum): + NONE = "none" + MICRO = "micro" + MACRO = "macro" + + +def precision( + y_true: np.ndarray, + y_pred: np.ndarray, + labels: Union[List, np.ndarray] = None, + zero_division: ZeroDivision = ZeroDivision.NONE, + average: AverageType = AverageType.NONE, +) -> Result: + """Multi-class calculation of precision + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + labels (optional | list or np.ndarray): + labels to calculate confusion matrix for (must be bool or int types) + zero_division (optional | str): strategy to handle division by 0 + average (optional | str): strategy for averaging across classes + Returns: + precision (np.ndarray): 1D array or scalar values depending on averaging + """ + assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" + assert all( + [ + isinstance(y_pred, np.ndarray), + isinstance(y_true, np.ndarray), + ] + ), "y_true and y_pred must be numpy arrays" + zero_division = ZeroDivision(zero_division) + average = AverageType(average) + + if labels is None: + labels = np.array( + sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + ) + 
elif isinstance(labels, list): + labels = np.array(labels, dtype=y_true.dtype) + + x = _precision(y_true, y_pred, labels) + + if zero_division == ZeroDivision.NONE: + zero_handle = partial( + np.nan_to_num, copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan + ) + elif zero_division == zero_division.ZERO: + zero_handle = partial( + np.nan_to_num, copy=False, nan=0.0, posinf=0.0, neginf=0.0 + ) + + with np.errstate(divide="ignore", invalid="ignore"): + if average == AverageType.NONE: + return zero_handle(x[:, 0] / x[:, 1]) + elif average == AverageType.MICRO: + return zero_handle(x[:, 0].sum() / x[:, 1].sum()) + elif average == AverageType.MACRO: + return np.nanmean(zero_handle(x[:, 0] / x[:, 1])) + + +def recall( + y_true: np.ndarray, + y_pred: np.ndarray, + labels: Union[List, np.ndarray] = None, + zero_division: ZeroDivision = ZeroDivision.NONE, + average: AverageType = AverageType.NONE, +) -> Result: + """Multi-class calculation of recall + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + labels (optional | list or np.ndarray): + labels to calculate confusion matrix for (must be bool or int types) + zero_division (optional | str): strategy to handle division by 0 + average (optional | str): strategy for averaging across classes + Returns: + recall (np.ndarray): 1D array or scalar values depending on averaging + """ + assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" + assert all( + [ + isinstance(y_pred, np.ndarray), + isinstance(y_true, np.ndarray), + ] + ), "y_true and y_pred must be numpy arrays" + zero_division = ZeroDivision(zero_division) + average = AverageType(average) + + if labels is None: + labels = np.array( + sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + ) + elif isinstance(labels, list): + labels = np.array(labels, dtype=y_true.dtype) + + x = _recall(y_true, y_pred, labels) + + if zero_division == ZeroDivision.NONE: + zero_handle = partial( + np.nan_to_num, copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan + ) + elif zero_division == zero_division.ZERO: + zero_handle = partial( + np.nan_to_num, copy=False, nan=0.0, posinf=0.0, neginf=0.0 + ) + + with np.errstate(divide="ignore", invalid="ignore"): + if average == AverageType.NONE: + return zero_handle(x[:, 0] / x[:, 1]) + elif average == AverageType.MICRO: + return zero_handle(x[:, 0].sum() / x[:, 1].sum()) + elif average == AverageType.MACRO: + return np.nanmean(zero_handle(x[:, 0] / x[:, 1])) + + +def f1_score( + y_true: np.ndarray, + y_pred: np.ndarray, + labels: Union[List, np.ndarray] = None, + zero_division: ZeroDivision = ZeroDivision.NONE, + average: AverageType = AverageType.NONE, +) -> Result: + """Multi-class calculation of f1 score + + Args: + y_true (np.ndarray): array of true values (must be bool or int types) + y_pred (np.ndarray): array of pred values (must be bool or int types) + labels (optional | list or np.ndarray): + labels to calculate confusion matrix for (must be bool or int types) + zero_division (optional | str): strategy to handle division by 0 + average (optional | str): strategy for averaging across classes + Returns: + f1 score (np.ndarray): 1D array or scalar values depending on averaging + """ + assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape" + assert all( + [ + isinstance(y_pred, np.ndarray), + isinstance(y_true, np.ndarray), + ] + ), "y_true and y_pred must be numpy arrays" + zero_division = 
ZeroDivision(zero_division) + average = AverageType(average) + + if labels is None: + labels = np.array( + sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + ) + elif isinstance(labels, list): + labels = np.array(labels, dtype=y_true.dtype) + + x = _f1_score(y_true, y_pred, labels) + + if zero_division == ZeroDivision.NONE: + zero_handle = partial( + np.nan_to_num, copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan + ) + elif zero_division == zero_division.ZERO: + zero_handle = partial( + np.nan_to_num, copy=False, nan=0.0, posinf=0.0, neginf=0.0 + ) + + def f1_from_ext(x, y, z): + p, r = x / y, x / z + return 2 * p * r / (p + r) + + with np.errstate(divide="ignore", invalid="ignore"): + if average == AverageType.NONE: + return zero_handle(f1_from_ext(x[:, 0], x[:, 1], x[:, 2])) + elif average == AverageType.MICRO: + return zero_handle(f1_from_ext(x[:, 0].sum(), x[:, 1].sum(), x[:, 2].sum())) + elif average == AverageType.MACRO: + return np.nanmean(f1_from_ext(x[:, 0], x[:, 1], x[:, 2])) diff --git a/pyproject.toml b/pyproject.toml index 81efcea..26ab1d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fast-stats" -version = "0.0.6" +version = "1.0.0" description = "A fast and simple library for calculating basic statistics" readme = "README.md" license = {text="Apache 2.0"} diff --git a/src/cm.rs b/src/cm.rs index c673fd2..a01b45b 100644 --- a/src/cm.rs +++ b/src/cm.rs @@ -1,4 +1,4 @@ -use ndarray::ArrayD; +use ndarray::*; use numpy::*; use pyo3::prelude::*; use std::{collections::HashMap, iter::zip}; @@ -8,7 +8,9 @@ use crate::numpy_dispatch_bool; /// Confusion Matrix #[pyfunction] #[pyo3(name = "_confusion_matrix")] -#[pyo3(text_signature = "(actual: np.ndarray, pred: np.ndarray, labels: List[int], /)")] +#[pyo3( + text_signature = "(actual: np.ndarray, pred: np.ndarray, labels: Union[List, np.ndarray], /)" +)] pub fn py_confusion_matrix<'a>( py: Python<'a>, actual: &'a PyAny, @@ -39,19 +41,28 @@ where let labels = labels.to_vec().unwrap(); let threadable = |actual: ArrayD, pred: ArrayD| -> ndarray::Array2 { - py.allow_threads(move || { - let mut cm = ndarray::Array2::::from_elem((labels.len(), labels.len()), 0); - let idx_map: HashMap = - HashMap::from_iter(labels.iter().enumerate().map(|(x, y)| (*y, x))); - - for (y_pred, y_actual) in zip(pred.iter(), actual.iter()) { - if let (Some(ix1), Some(ix2)) = (idx_map.get(y_actual), idx_map.get(y_pred)) { - *cm.get_mut((*ix1, *ix2)).unwrap() = *cm.get_mut((*ix1, *ix2)).unwrap() + 1; - } - } - return cm; - }) + py.allow_threads(move || return confusion_matrix_owned(actual, pred, labels)) }; Ok(PyArray2::from_array(py, &threadable(actual, pred))) } + +pub fn confusion_matrix_owned( + actual: ndarray::ArrayD, + pred: ndarray::ArrayD, + labels: Vec, +) -> ndarray::Array2 +where + T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, +{ + let mut cm = ndarray::Array2::::from_elem((labels.len(), labels.len()), 0); + let idx_map: HashMap = + HashMap::from_iter(labels.iter().enumerate().map(|(x, y)| (*y, x))); + + for (y_pred, y_actual) in zip(pred.iter(), actual.iter()) { + if let (Some(ix1), Some(ix2)) = (idx_map.get(y_actual), idx_map.get(y_pred)) { + *cm.get_mut((*ix1, *ix2)).unwrap() = *cm.get_mut((*ix1, *ix2)).unwrap() + 1; + } + } + cm +} diff --git a/src/lib.rs b/src/lib.rs index 2c9263f..93d6fbf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; mod binary; mod cm; mod dispatch; +mod multiclass; mod utils; /// A Python 
module implemented in Rust. @@ -19,5 +20,10 @@ fn _fast_stats_ext(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(binary::py_binary_recall_reqs, m)?)?; m.add_function(wrap_pyfunction!(binary::py_binary_f1_score_reqs, m)?)?; + // multiclass calcs + m.add_function(wrap_pyfunction!(multiclass::py_precision, m)?)?; + m.add_function(wrap_pyfunction!(multiclass::py_recall, m)?)?; + m.add_function(wrap_pyfunction!(multiclass::py_f1_score, m)?)?; + Ok(()) } diff --git a/src/multiclass.rs b/src/multiclass.rs new file mode 100644 index 0000000..051b594 --- /dev/null +++ b/src/multiclass.rs @@ -0,0 +1,165 @@ +use numpy::*; +use pyo3::prelude::*; + +use crate::cm; +use crate::numpy_dispatch_bool; + +/// Precision computational requirements +#[pyfunction] +#[pyo3(name = "_precision")] +#[pyo3( + text_signature = "(actual: np.ndarray, pred: np.ndarray, labels: Union[List, np.ndarray], /)" +)] +pub fn py_precision<'a>( + py: Python<'a>, + actual: &'a PyAny, + pred: &'a PyAny, + labels: &'a PyAny, +) -> PyResult<&'a PyArray2> { + numpy_dispatch_bool!( + py, + precision, + PyResult<&'a PyArray2>, + actual, + pred, + labels + ) +} + +/// Recall computational requirements +#[pyfunction] +#[pyo3(name = "_recall")] +#[pyo3( + text_signature = "(actual: np.ndarray, pred: np.ndarray, labels: Union[List, np.ndarray], /)" +)] +pub fn py_recall<'a>( + py: Python<'a>, + actual: &'a PyAny, + pred: &'a PyAny, + labels: &'a PyAny, +) -> PyResult<&'a PyArray2> { + numpy_dispatch_bool!( + py, + recall, + PyResult<&'a PyArray2>, + actual, + pred, + labels + ) +} + +/// f1 score computational requirements +#[pyfunction] +#[pyo3(name = "_f1_score")] +#[pyo3( + text_signature = "(actual: np.ndarray, pred: np.ndarray, labels: Union[List, np.ndarray], /)" +)] +pub fn py_f1_score<'a>( + py: Python<'a>, + actual: &'a PyAny, + pred: &'a PyAny, + labels: &'a PyAny, +) -> PyResult<&'a PyArray2> { + numpy_dispatch_bool!( + py, + f1_score, + PyResult<&'a PyArray2>, + actual, + pred, + labels + ) +} + +fn precision<'a, T>( + py: Python<'a>, + actual: PyReadonlyArrayDyn, + pred: PyReadonlyArrayDyn, + labels: PyReadonlyArrayDyn, +) -> PyResult<&'a PyArray2> +where + T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, +{ + let actual = actual.to_owned_array(); + let pred = pred.to_owned_array(); + let labels = labels.to_vec().unwrap(); + + let threadable = + |actual: ndarray::ArrayD, pred: ndarray::ArrayD| -> ndarray::Array2 { + py.allow_threads(move || { + let cm = cm::confusion_matrix_owned(actual, pred, labels); + let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 2), 0); + for (idx, col) in cm.columns().into_iter().enumerate() { + // get TP + *ret.get_mut((idx, 0)).unwrap() = *col.get(idx).unwrap(); + // get TP + FP + *ret.get_mut((idx, 1)).unwrap() = col.sum(); + } + return ret; + }) + }; + Ok(PyArray2::from_array(py, &threadable(actual, pred))) +} + +fn recall<'a, T>( + py: Python<'a>, + actual: PyReadonlyArrayDyn, + pred: PyReadonlyArrayDyn, + labels: PyReadonlyArrayDyn, +) -> PyResult<&'a PyArray2> +where + T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, +{ + let actual = actual.to_owned_array(); + let pred = pred.to_owned_array(); + let labels = labels.to_vec().unwrap(); + + let threadable = + |actual: ndarray::ArrayD, pred: ndarray::ArrayD| -> ndarray::Array2 { + py.allow_threads(move || { + let cm = cm::confusion_matrix_owned(actual, pred, labels); + let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 2), 0); + 
for (idx, row) in cm.rows().into_iter().enumerate() { + // get TP + *ret.get_mut((idx, 0)).unwrap() = *row.get(idx).unwrap(); + // get TP + FN + *ret.get_mut((idx, 1)).unwrap() = row.sum(); + } + return ret; + }) + }; + Ok(PyArray2::from_array(py, &threadable(actual, pred))) +} + +fn f1_score<'a, T>( + py: Python<'a>, + actual: PyReadonlyArrayDyn, + pred: PyReadonlyArrayDyn, + labels: PyReadonlyArrayDyn, +) -> PyResult<&'a PyArray2> +where + T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, +{ + let actual = actual.to_owned_array(); + let pred = pred.to_owned_array(); + let labels = labels.to_vec().unwrap(); + + let threadable = + |actual: ndarray::ArrayD, pred: ndarray::ArrayD| -> ndarray::Array2 { + py.allow_threads(move || { + let cm = cm::confusion_matrix_owned(actual, pred, labels); + let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 3), 0); + for (idx, col) in cm.columns().into_iter().enumerate() { + // get TP + *ret.get_mut((idx, 0)).unwrap() = *col.get(idx).unwrap(); + // get TP + FP + *ret.get_mut((idx, 1)).unwrap() = col.sum(); + } + for (idx, row) in cm.rows().into_iter().enumerate() { + // get TP + FN + *ret.get_mut((idx, 2)).unwrap() = row.sum(); + } + return ret; + }) + }; + Ok(PyArray2::from_array(py, &threadable(actual, pred))) +} diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_multiclass.py b/tests/test_multiclass.py new file mode 100644 index 0000000..2c69efc --- /dev/null +++ b/tests/test_multiclass.py @@ -0,0 +1,421 @@ +import numpy as np +import pytest + +import fast_stats + + +@pytest.mark.parametrize( + "y_true,y_pred,kwargs,expected", + [ + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + {}, + np.array([1.0, 1.0, 1.0]), + ), # perfect + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"labels": [1, 2, 3]}, + np.array([0.5, 0.5, 0.5]), + ), # 50% + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 1.0, + 3.0, + ], + dtype=np.uint64, + ), + {"labels": [1, 2]}, + np.array([0.5, 0.5]), + ), # 50% subset + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"average": "micro", "zero_division": "none"}, + np.array([0.5]), + ), # 50% micro + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"average": "macro", "zero_division": "zero"}, + np.array([0.5]), + ), # 50% macro + ], +) +def test_precision(y_true, y_pred, kwargs, expected): + assert np.allclose(fast_stats.precision(y_true, y_pred, **kwargs), expected) + + +@pytest.mark.parametrize( + "y_true,y_pred,kwargs,expected", + [ + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + {}, + np.array([1.0, 1.0, 1.0]), + ), # perfect + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + 
dtype=np.uint64, + ), + {"labels": [1, 2, 3]}, + np.array([0.5, 0.5, 0.5]), + ), # 50% + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 1.0, + 3.0, + ], + dtype=np.uint64, + ), + {"labels": [1, 2]}, + np.array([0.5, 0.5]), + ), # 50% subset + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"average": "micro", "zero_division": "none"}, + np.array([0.5]), + ), # 50% micro + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"average": "macro", "zero_division": "zero"}, + np.array([0.5]), + ), # 50% macro + ], +) +def test_recall(y_true, y_pred, kwargs, expected): + assert np.allclose(fast_stats.recall(y_true, y_pred, **kwargs), expected) + + +@pytest.mark.parametrize( + "y_true,y_pred,kwargs,expected", + [ + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + {}, + np.array([1.0, 1.0, 1.0]), + ), # perfect + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"labels": [1, 2, 3]}, + np.array([0.5, 0.5, 0.5]), + ), # 50% + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 1.0, + 3.0, + ], + dtype=np.uint64, + ), + {"labels": [1, 2]}, + np.array([0.5, 0.5]), + ), # 50% subset + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"average": "micro", "zero_division": "none"}, + np.array([0.5]), + ), # 50% micro + ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + ], + dtype=np.uint64, + ), + np.array( + [ + 1.0, + 2.0, + 3.0, + 2.0, + 3.0, + 1.0, + ], + dtype=np.uint64, + ), + {"average": "macro", "zero_division": "zero"}, + np.array([0.5]), + ), # 50% macro + ], +) +def test_f1_score(y_true, y_pred, kwargs, expected): + assert np.allclose(fast_stats.f1_score(y_true, y_pred, **kwargs), expected)
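
Usage sketch for the new multiclass API (function and parameter names as added in fast_stats/multiclass.py; the array values below are invented for illustration):

```python
import numpy as np
import fast_stats

y_true = np.array([0, 1, 2, 0, 1, 2], dtype=np.uint64)
y_pred = np.array([0, 2, 1, 0, 1, 2], dtype=np.uint64)

# Per-class scores; `average` defaults to "none". `labels` may be a list or
# an ndarray and defaults to the sorted union of values seen in both arrays.
fast_stats.precision(y_true, y_pred, labels=[0, 1, 2])  # -> [1.0, 0.5, 0.5]

# Scalar averages; plain strings are coerced to the AverageType and
# ZeroDivision enums inside each function.
fast_stats.recall(y_true, y_pred, average="micro")                          # ~0.667
fast_stats.f1_score(y_true, y_pred, average="macro", zero_division="zero")  # ~0.667
```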
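Under the hood, src/multiclass.rs derives each metric from the confusion matrix built by the shared confusion_matrix_owned helper, with rows indexing actual labels and columns indexing predicted labels: TP is the diagonal entry, TP + FP a column sum, TP + FN a row sum. A NumPy sketch of that bookkeeping (an illustration of the approach, not the extension code; the helper name is invented):

```python
import numpy as np

def per_class_from_cm(cm: np.ndarray):
    """cm[i, j] counts samples with actual label i predicted as label j."""
    tp = np.diag(cm).astype(float)        # correct predictions per class
    tp_fp = cm.sum(axis=0).astype(float)  # column sums: predicted totals
    tp_fn = cm.sum(axis=1).astype(float)  # row sums: actual totals
    with np.errstate(divide="ignore", invalid="ignore"):
        precision = tp / tp_fp
        recall = tp / tp_fn
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
```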
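On the Python side, fast_stats/multiclass.py applies averaging and zero-division handling to those raw counts: "micro" pools the counts before dividing, "macro" takes a nanmean of the per-class ratios, and zero_division decides whether undefined ratios surface as NaN ("none") or 0.0 ("zero"). A condensed sketch of that logic (the function name is invented; `denom` stands for the TP+FP or TP+FN column returned by the extension):

```python
import numpy as np

def average_scores(tp, denom, average="none", zero_division="none"):
    fill = np.nan if zero_division == "none" else 0.0
    with np.errstate(divide="ignore", invalid="ignore"):
        if average == "micro":  # pool raw counts, then divide once
            return np.nan_to_num(tp.sum() / denom.sum(), nan=fill, posinf=fill)
        per_class = np.nan_to_num(tp / denom, nan=fill, posinf=fill)
    if average == "macro":  # nanmean skips classes left as NaN
        return np.nanmean(per_class)
    return per_class  # average == "none": one score per label
```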
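The binary functions keep their string-friendly interface: each one normalizes its argument with ZeroDivision(zero_division), so both the plain strings and the new ZeroDivision enum members are accepted. A quick illustration (array values invented):

```python
import numpy as np
from fast_stats import binary_f1_score, binary_precision, binary_recall

y_true = np.array([1, 0, 1, 1], dtype=np.uint64)
y_pred = np.array([1, 0, 0, 1], dtype=np.uint64)

binary_precision(y_true, y_pred)                       # 1.0 (2 TP, 0 FP)
binary_recall(y_true, y_pred)                          # ~0.667 (2 TP, 1 FN)
binary_f1_score(y_true, y_pred, zero_division="zero")  # 0.8
```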