From 45813c172d8e2f81514811bad3b7e2e8a2bf6020 Mon Sep 17 00:00:00 2001 From: Zachary Coleman <42484306+zachcoleman@users.noreply.github.com> Date: Wed, 28 Dec 2022 19:46:40 -0700 Subject: [PATCH] Release 1.3.0 (#30) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * updating project metadata * fixing CI yaml * using venv * trying again * redoing CI * fixing tests * changing some settings * updates * fixing build * trying to fix this * fixing release * bumping version * better code organization * updates * adding initial cm impl * adding unique, dispatching pattern to Py objects, renaming ext * rustfmt * cm dispatched * rustfmt * tests and benchmarks added * bump version * 100% test coverage * updating readme * Threading enabled (#9) * bumping version * major refactor leveraging macros * bumping version and updating test * adding executed notebook * fixing performance w/ bool * multiclass implemented and tested ready for 1.0.0 * shifting to u32 for numpy compatability * bumping version * changing to i64 for better compatability * Major additions to Python API to include high-level helpers * changing to dev deps * ooops * bandit! (#25) * Add 3.11 (#26) * add 3.11 * add 3.11 to tests * Adding `mypy` type checking (#27) * passing mypy type checking * not needed * adding mypy to actions * bumping to 1.2.1 * Zero/Constant Memory Calls (#29) * testing * continued testing * progress * switching to iterating w/ no duplicated memory * removing unsafe code * Bump version: 1.2.1 → 1.3.0 * adding bump2version --- .bumpversion.cfg | 12 ++ .gitignore | 3 +- Cargo.toml | 2 +- examples/benchmarks.ipynb | 58 ++++---- examples/stats.ipynb | 102 ++++++++++---- examples/threading.ipynb | 120 ---------------- fast_stats/confusion_matrix.py | 2 +- fast_stats/multiclass.py | 8 +- pyproject.toml | 3 +- src/binary.rs | 249 ++++++++++++++++----------------- src/cm.rs | 53 ++++--- src/multiclass.rs | 92 ++++-------- src/utils.rs | 32 ++--- 13 files changed, 314 insertions(+), 422 deletions(-) create mode 100644 .bumpversion.cfg delete mode 100644 examples/threading.ipynb diff --git a/.bumpversion.cfg b/.bumpversion.cfg new file mode 100644 index 0000000..451837d --- /dev/null +++ b/.bumpversion.cfg @@ -0,0 +1,12 @@ +[bumpversion] +current_version = 1.3.0 +commit = True +tag = False + +[bumpversion:file:pyproject.toml] +search = version = "{current_version}" +replace = version = "{new_version}" + +[bumpversion:file:Cargo.toml] +search = version = "{current_version}" +replace = version = "{new_version}" diff --git a/.gitignore b/.gitignore index 10b06fa..1b2ee63 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,5 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb -.DS_Store \ No newline at end of file +.DS_Store +fil-result/ diff --git a/Cargo.toml b/Cargo.toml index 1f5ce73..62f8ea6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fast-stats" -version = "1.2.1" +version = "1.3.0" edition = "2021" [lib] diff --git a/examples/benchmarks.ipynb b/examples/benchmarks.ipynb index 2d8163e..347db35 100644 --- a/examples/benchmarks.ipynb +++ b/examples/benchmarks.ipynb @@ -14,13 +14,13 @@ "outputs": [], "source": [ "import fast_stats\n", + "import numpy as np\n", "from sklearn.metrics import (\n", " precision_score, \n", " recall_score, \n", " f1_score, \n", " confusion_matrix\n", - ")\n", - "import numpy as np" + ")" ] }, { @@ -66,7 +66,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "574 ms ± 7.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "569 ms ± 6.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -84,7 +84,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "576 ms ± 9.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "566 ms ± 925 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -102,7 +102,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "585 ms ± 8.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "567 ms ± 2.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -131,7 +131,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.39 ms ± 151 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.91 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -149,7 +149,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.4 ms ± 165 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.93 ms ± 23.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -167,7 +167,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4.37 ms ± 93.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "4.42 ms ± 7.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -221,7 +221,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "5.73 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "4.49 ms ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -239,7 +239,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.15 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.91 ms ± 3.14 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -274,7 +274,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "684 ms ± 7.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "677 ms ± 1.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -297,7 +297,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "67.3 ms ± 848 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "56.7 ms ± 831 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -340,7 +340,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "683 ms ± 8.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "686 ms ± 5.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -363,7 +363,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "67.9 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "57 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -406,7 +406,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "685 ms ± 10.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "684 ms ± 6.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -429,7 +429,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "69 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "56.7 ms ± 679 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -540,7 +540,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "277 ms ± 6.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "258 ms ± 4.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -558,7 +558,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "144 ms ± 768 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "133 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -591,7 +591,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "132 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "107 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -609,7 +609,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "68.2 ms ± 539 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "57.5 ms ± 669 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -658,7 +658,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.87 ms ± 49.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + "1.7 ms ± 27.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -676,7 +676,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "318 ns ± 0.241 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" + "107 ns ± 0.273 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], @@ -687,11 +687,8 @@ } ], "metadata": { - "interpreter": { - "hash": "a3a671d63c09fb4878d313d605bf6366336b9695c04e11736a5d015abf9b1e42" - }, "kernelspec": { - "display_name": "Python 3.9.11 ('.venv39': venv)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -705,9 +702,14 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.11" + "version": "3.10.8" }, - "orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "2994604185f86bc96c0c5cb0b57fabd703f5f0106f9413ed938cdec2350fab6e" + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/examples/stats.ipynb b/examples/stats.ipynb index 3f0ea7b..509fb75 100644 --- a/examples/stats.ipynb +++ b/examples/stats.ipynb @@ -12,6 +12,15 @@ "execution_count": 1, "metadata": {}, "outputs": [], + "source": [ + "%load_ext filprofiler" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], "source": [ "import numpy as np\n", "import fast_stats" @@ -26,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -53,23 +62,42 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "=fil-profile= Preparing to write to fil-result/tmpzxndpb6p\n", + "=fil-profile= Wrote flamegraph to \"fil-result/tmpzxndpb6p/peak-memory.svg\"\n", + "=fil-profile= Wrote flamegraph to \"fil-result/tmpzxndpb6p/peak-memory-reversed.svg\"\n" + ] + }, { "data": { + "text/html": [ + "\n", + " \n", + " " + ], "text/plain": [ - "{'precision': 0.49939381724124243,\n", - " 'recall': 0.4994250828781588,\n", - " 'f1-score': 0.4994094495703526}" + "" ] }, - "execution_count": 4, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ + "%%filprofile\n", "fast_stats.binary_stats(y_true, y_pred)" ] }, @@ -82,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -92,38 +120,58 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "=fil-profile= Preparing to write to fil-result/tmpcg_juijl\n", + "=fil-profile= Wrote flamegraph to \"fil-result/tmpcg_juijl/peak-memory.svg\"\n", + "=fil-profile= Wrote flamegraph to \"fil-result/tmpcg_juijl/peak-memory-reversed.svg\"\n" + ] + }, { "data": { + "text/html": [ + "\n", + " \n", + " " + ], "text/plain": [ - "{'precision': array([0.1256168 , 0.12500038, 0.12486642, 0.1248673 , 0.12500914,\n", - " 0.12636344, 0.12454387, 0.12488666]),\n", - " 'recall': array([0.12568051, 0.12524121, 0.1245743 , 0.12500458, 0.12535152,\n", - " 0.12616009, 0.12444887, 0.12469365]),\n", - " 'f1-score': array([0.12564865, 0.12512068, 0.12472019, 0.1249359 , 0.1251801 ,\n", - " 0.12626169, 0.12449635, 0.12479008])}" + "" ] }, - "execution_count": 6, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ + "%%filprofile\n", "fast_stats.stats(y_true, y_pred)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { - "interpreter": { - "hash": "a3a671d63c09fb4878d313d605bf6366336b9695c04e11736a5d015abf9b1e42" - }, "kernelspec": { - "display_name": "Python 3.9.11 ('.venv39': venv)", + "display_name": "Python 3 with Fil", "language": "python", - "name": "python3" + "name": "filprofile" }, "language_info": { "codemirror_mode": { @@ -135,9 +183,13 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.11" + "version": "3.10.8" }, - "orig_nbformat": 4 + "vscode": { + "interpreter": { + "hash": "2994604185f86bc96c0c5cb0b57fabd703f5f0106f9413ed938cdec2350fab6e" + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/examples/threading.ipynb b/examples/threading.ipynb deleted file mode 100644 index e91114e..0000000 --- a/examples/threading.ipynb +++ /dev/null @@ -1,120 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from concurrent.futures import ThreadPoolExecutor\n", - "\n", - "import fast_stats\n", - "import numpy as np" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Settings" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "SIZE = (10, 2048, 2048)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "y_true = np.random.randint(0, 2, SIZE).astype(bool)\n", - "y_pred = np.random.randint(0, 2, SIZE).astype(bool)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Threaded timing" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "462 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "for _ in range(10):\n", - " _ = fast_stats.binary_precision(y_true, y_pred)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "258 ms ± 3.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "with ThreadPoolExecutor(2) as ex:\n", - " for _ in range(10):\n", - " _ = ex.submit(fast_stats.binary_precision, y_true, y_pred)" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "a3a671d63c09fb4878d313d605bf6366336b9695c04e11736a5d015abf9b1e42" - }, - "kernelspec": { - "display_name": "Python 3.9.11 ('.venv39': venv)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.11" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/fast_stats/confusion_matrix.py b/fast_stats/confusion_matrix.py index d81e5df..ce8c41d 100644 --- a/fast_stats/confusion_matrix.py +++ b/fast_stats/confusion_matrix.py @@ -33,7 +33,7 @@ def confusion_matrix( if labels is None: labels = np.array( - sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + sorted(list(_unique(y_true).union(_unique(y_pred)))), dtype=y_true.dtype ) elif isinstance(labels, list): labels = np.array(labels, dtype=y_true.dtype) diff --git a/fast_stats/multiclass.py b/fast_stats/multiclass.py index 71d6461..dc077d6 100644 --- a/fast_stats/multiclass.py +++ b/fast_stats/multiclass.py @@ -69,7 +69,7 @@ def precision( if labels is None: labels = np.array( - sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + sorted(list(_unique(y_true).union(_unique(y_pred)))), dtype=y_true.dtype ) elif isinstance(labels, list): labels = np.array(labels, dtype=y_true.dtype) @@ -119,7 +119,7 @@ def recall( if labels is None: labels = np.array( - sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + sorted(list(_unique(y_true).union(_unique(y_pred)))), dtype=y_true.dtype ) elif isinstance(labels, list): labels = np.array(labels, dtype=y_true.dtype) @@ -169,7 +169,7 @@ def f1_score( if labels is None: labels = np.array( - sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + sorted(list(_unique(y_true).union(_unique(y_pred)))), dtype=y_true.dtype ) elif isinstance(labels, list): labels = np.array(labels, dtype=y_true.dtype) @@ -225,7 +225,7 @@ def stats( if labels is None: labels = np.array( - sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype + sorted(list(_unique(y_true).union(_unique(y_pred)))), dtype=y_true.dtype ) elif isinstance(labels, list): labels = np.array(labels, dtype=y_true.dtype) diff --git a/pyproject.toml b/pyproject.toml index e775935..3d692ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fast-stats" -version = "1.2.1" +version = "1.3.0" description = "A fast and simple library for calculating basic statistics" readme = "README.md" license = {text="Apache 2.0"} @@ -30,6 +30,7 @@ repository = "https://github.com/zachcoleman/fast-stats" [project.optional-dependencies] dev = [ + "bump2version", "dictdiffer", "pre-commit", "pytest", diff --git a/src/binary.rs b/src/binary.rs index f8ab731..5ecb92b 100644 --- a/src/binary.rs +++ b/src/binary.rs @@ -16,20 +16,16 @@ pub fn py_binary_precision_reqs<'a>( actual.extract::>(), pred.extract::>(), ) { - return binary_precision_reqs_owned::( + binary_precision_reqs_bool(py, i, j) + } else { + numpy_dispatch_no_bool!( py, - i.to_owned_array().mapv(|e| e as u8), - j.to_owned_array().mapv(|e| e as u8), - ); + binary_precision_reqs, + PyResult<(i128, i128, i128)>, + actual, + pred + ) } - - numpy_dispatch_no_bool!( - py, - binary_precision_reqs, - PyResult<(i128, i128, i128)>, - actual, - pred - ) } /// Binary recall computational requirements @@ -45,20 +41,16 @@ pub fn py_binary_recall_reqs<'a>( actual.extract::>(), pred.extract::>(), ) { - return binary_recall_reqs_owned::( + binary_recall_reqs_bool(py, i, j) + } else { + numpy_dispatch_no_bool!( py, - i.to_owned_array().mapv(|e| e as u8), - j.to_owned_array().mapv(|e| e as u8), - ); + binary_recall_reqs, + PyResult<(i128, i128, i128)>, + actual, + pred + ) } - - numpy_dispatch_no_bool!( - py, - binary_recall_reqs, - PyResult<(i128, i128, i128)>, - actual, - pred - ) } /// Binary f1 computational requirements @@ -74,36 +66,20 @@ pub fn py_binary_f1_score_reqs<'a>( actual.extract::>(), pred.extract::>(), ) { - return binary_f1_score_reqs_owned::( + binary_f1_score_reqs_bool(py, i, j) + } else { + numpy_dispatch_no_bool!( py, - i.to_owned_array().mapv(|e| e as u8), - j.to_owned_array().mapv(|e| e as u8), - ); - } - - numpy_dispatch_no_bool!( - py, - binary_f1_score_reqs, - PyResult<(i128, i128, i128)>, - actual, - pred - ) -} - -// move into utils? -fn custom_sum(arr: ndarray::ArrayD) -> i128 -where - T: Clone + std::ops::Add + num_traits::Num + Into, -{ - let mut sum = 0; - for row in arr.rows() { - sum = sum + row.iter().fold(0, |acc, elt| acc + elt.clone().into()); + binary_f1_score_reqs, + PyResult<(i128, i128, i128)>, + actual, + pred + ) } - sum } fn binary_precision_reqs<'a, T>( - py: Python<'a>, + _py: Python<'a>, actual: numpy::PyReadonlyArrayDyn, pred: numpy::PyReadonlyArrayDyn, ) -> PyResult<(i128, i128, i128)> @@ -115,36 +91,41 @@ where + num_traits::Num + Into, { - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - - // TP, TP + FP, 0 - Ok(py.allow_threads(move || { - return (custom_sum(actual * &pred), custom_sum(pred), 0); - })) + let mut reqs = (0, 0, 0); + for (r1, r2) in std::iter::zip(pred.as_array().rows(), actual.as_array().rows()) { + let row_reqs = std::iter::zip(r1, r2).fold((0, 0), |acc, elt| { + ( + acc.0 + (elt.0.clone() * elt.1.clone()).into(), + acc.1 + elt.0.clone().into(), + ) + }); + reqs.0 = reqs.0 + row_reqs.0; + reqs.1 = reqs.1 + row_reqs.1; + } + Ok(reqs) } -fn binary_precision_reqs_owned<'a, T>( - py: Python<'a>, - actual: ndarray::ArrayD, - pred: ndarray::ArrayD, -) -> PyResult<(i128, i128, i128)> -where - T: Clone - + std::marker::Send - + numpy::Element - + std::ops::Add - + num_traits::Num - + Into, -{ - // TP, TP + FP, 0 - Ok(py.allow_threads(move || { - return (custom_sum(actual * &pred), custom_sum(pred), 0); - })) +fn binary_precision_reqs_bool<'a>( + _py: Python<'a>, + actual: numpy::PyReadonlyArrayDyn, + pred: numpy::PyReadonlyArrayDyn, +) -> PyResult<(i128, i128, i128)> { + let mut reqs = (0, 0, 0); + for (r1, r2) in std::iter::zip(pred.as_array().rows(), actual.as_array().rows()) { + let row_reqs = std::iter::zip(r1, r2).fold((0, 0), |acc, elt| { + ( + acc.0 + (elt.0.clone() & elt.1.clone()) as i128, + acc.1 + (elt.0.clone()) as i128, + ) + }); + reqs.0 = reqs.0 + row_reqs.0; + reqs.1 = reqs.1 + row_reqs.1; + } + Ok(reqs) } fn binary_recall_reqs<'a, T>( - py: Python<'a>, + _py: Python<'a>, actual: numpy::PyReadonlyArrayDyn, pred: numpy::PyReadonlyArrayDyn, ) -> PyResult<(i128, i128, i128)> @@ -156,36 +137,41 @@ where + num_traits::Num + Into, { - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - - // TP, TP + FN, 0 - Ok(py.allow_threads(move || { - return (custom_sum(&actual * pred), custom_sum(actual), 0); - })) + let mut reqs = (0, 0, 0); + for (r1, r2) in std::iter::zip(pred.as_array().rows(), actual.as_array().rows()) { + let row_reqs = std::iter::zip(r1, r2).fold((0, 0), |acc, elt| { + ( + acc.0 + (elt.0.clone() * elt.1.clone()).into(), + acc.1 + elt.1.clone().into(), + ) + }); + reqs.0 = reqs.0 + row_reqs.0; + reqs.1 = reqs.1 + row_reqs.1; + } + Ok(reqs) } -fn binary_recall_reqs_owned<'a, T>( - py: Python<'a>, - actual: ndarray::ArrayD, - pred: ndarray::ArrayD, -) -> PyResult<(i128, i128, i128)> -where - T: Clone - + std::marker::Send - + numpy::Element - + std::ops::Add - + num_traits::Num - + Into, -{ - // TP, TP + FN, 0 - Ok(py.allow_threads(move || { - return (custom_sum(&actual * pred), custom_sum(actual), 0); - })) +fn binary_recall_reqs_bool<'a>( + _py: Python<'a>, + actual: numpy::PyReadonlyArrayDyn, + pred: numpy::PyReadonlyArrayDyn, +) -> PyResult<(i128, i128, i128)> { + let mut reqs = (0, 0, 0); + for (r1, r2) in std::iter::zip(pred.as_array().rows(), actual.as_array().rows()) { + let row_reqs = std::iter::zip(r1, r2).fold((0, 0), |acc, elt| { + ( + acc.0 + (elt.0.clone() & elt.1.clone()) as i128, + acc.1 + (elt.1.clone()) as i128, + ) + }); + reqs.0 = reqs.0 + row_reqs.0; + reqs.1 = reqs.1 + row_reqs.1; + } + Ok(reqs) } fn binary_f1_score_reqs<'a, T>( - py: Python<'a>, + _py: Python<'a>, actual: numpy::PyReadonlyArrayDyn, pred: numpy::PyReadonlyArrayDyn, ) -> PyResult<(i128, i128, i128)> @@ -197,38 +183,39 @@ where + num_traits::Num + Into, { - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - - // TP, TP + FP, TP + FN - Ok(py.allow_threads(move || { - return ( - custom_sum(&actual * &pred), - custom_sum(pred), - custom_sum(actual), - ); - })) + let mut reqs = (0, 0, 0); + for (r1, r2) in std::iter::zip(pred.as_array().rows(), actual.as_array().rows()) { + let row_reqs = std::iter::zip(r1, r2).fold((0, 0, 0), |acc, elt| { + ( + acc.0 + (elt.0.clone() * elt.1.clone()).into(), + acc.1 + elt.0.clone().into(), + acc.2 + elt.1.clone().into(), + ) + }); + reqs.0 = reqs.0 + row_reqs.0; + reqs.1 = reqs.1 + row_reqs.1; + reqs.2 = reqs.2 + row_reqs.2; + } + Ok(reqs) } -fn binary_f1_score_reqs_owned<'a, T>( - py: Python<'a>, - actual: ndarray::ArrayD, - pred: ndarray::ArrayD, -) -> PyResult<(i128, i128, i128)> -where - T: Clone - + std::marker::Send - + numpy::Element - + std::ops::Add - + num_traits::Num - + Into, -{ - // TP, TP + FP, TP + FN - Ok(py.allow_threads(move || { - return ( - custom_sum(&actual * &pred), - custom_sum(pred), - custom_sum(actual), - ); - })) +fn binary_f1_score_reqs_bool<'a>( + _py: Python<'a>, + actual: numpy::PyReadonlyArrayDyn, + pred: numpy::PyReadonlyArrayDyn, +) -> PyResult<(i128, i128, i128)> { + let mut reqs = (0, 0, 0); + for (r1, r2) in std::iter::zip(pred.as_array().rows(), actual.as_array().rows()) { + let row_reqs = std::iter::zip(r1, r2).fold((0, 0, 0), |acc, elt| { + ( + acc.0 + (elt.0.clone() & elt.1.clone()) as i128, + acc.1 + (elt.0.clone()) as i128, + acc.2 + (elt.1.clone()) as i128, + ) + }); + reqs.0 = reqs.0 + row_reqs.0; + reqs.1 = reqs.1 + row_reqs.1; + reqs.2 = reqs.2 + row_reqs.2; + } + Ok(reqs) } diff --git a/src/cm.rs b/src/cm.rs index 4cac735..a4292f0 100644 --- a/src/cm.rs +++ b/src/cm.rs @@ -1,4 +1,3 @@ -use ndarray::*; use numpy::*; use pyo3::prelude::*; use std::{collections::HashMap, iter::zip}; @@ -27,42 +26,42 @@ pub fn py_confusion_matrix<'a>( ) } -fn confusion_matrix<'a, T>( - py: Python<'a>, +pub fn _confusion_matrix<'a, T>( + _py: Python<'a>, actual: PyReadonlyArrayDyn, pred: PyReadonlyArrayDyn, labels: PyReadonlyArrayDyn, -) -> PyResult<&'a PyArray2> -where - T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, -{ - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - let labels = labels.to_vec().unwrap(); - - let threadable = |actual: ArrayD, pred: ArrayD| -> ndarray::Array2 { - py.allow_threads(move || return confusion_matrix_owned(actual, pred, labels)) - }; - - Ok(PyArray2::from_array(py, &threadable(actual, pred))) -} - -pub fn confusion_matrix_owned( - actual: ndarray::ArrayD, - pred: ndarray::ArrayD, - labels: Vec, ) -> ndarray::Array2 where - T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, + T: Clone + numpy::Element + std::hash::Hash + std::cmp::Eq, { + let labels = labels.to_vec().unwrap(); let mut cm = ndarray::Array2::::from_elem((labels.len(), labels.len()), 0); - let idx_map: HashMap = - HashMap::from_iter(labels.iter().enumerate().map(|(x, y)| (*y, x))); - - for (y_pred, y_actual) in zip(pred.iter(), actual.iter()) { + let idx_map: HashMap = HashMap::from_iter( + labels + .iter() + .enumerate() + .map(|(x, y)| (y.clone(), x.clone())), + ); + for (y_pred, y_actual) in zip(pred.as_array().iter(), actual.as_array().iter()) { if let (Some(ix1), Some(ix2)) = (idx_map.get(y_actual), idx_map.get(y_pred)) { *cm.get_mut((*ix1, *ix2)).unwrap() = *cm.get_mut((*ix1, *ix2)).unwrap() + 1; } } cm } + +pub fn confusion_matrix<'a, T>( + py: Python<'a>, + actual: PyReadonlyArrayDyn, + pred: PyReadonlyArrayDyn, + labels: PyReadonlyArrayDyn, +) -> PyResult<&'a PyArray2> +where + T: Clone + numpy::Element + std::hash::Hash + std::cmp::Eq, +{ + Ok(PyArray2::from_array( + py, + &_confusion_matrix(py, actual, pred, labels), + )) +} diff --git a/src/multiclass.rs b/src/multiclass.rs index 05a37f7..8461a06 100644 --- a/src/multiclass.rs +++ b/src/multiclass.rs @@ -79,25 +79,15 @@ fn precision<'a, T>( where T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, { - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - let labels = labels.to_vec().unwrap(); - - let threadable = - |actual: ndarray::ArrayD, pred: ndarray::ArrayD| -> ndarray::Array2 { - py.allow_threads(move || { - let cm = cm::confusion_matrix_owned(actual, pred, labels); - let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 2), 0); - for (idx, col) in cm.columns().into_iter().enumerate() { - // get TP - *ret.get_mut((idx, 0)).unwrap() = *col.get(idx).unwrap(); - // get TP + FP - *ret.get_mut((idx, 1)).unwrap() = col.sum(); - } - return ret; - }) - }; - Ok(PyArray2::from_array(py, &threadable(actual, pred))) + let cm = cm::_confusion_matrix(py, actual, pred, labels); + let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 2), 0); + for (idx, col) in cm.columns().into_iter().enumerate() { + // get TP + *ret.get_mut((idx, 0)).unwrap() = *col.get(idx).unwrap(); + // get TP + FP + *ret.get_mut((idx, 1)).unwrap() = col.sum(); + } + Ok(PyArray2::from_array(py, &ret)) } fn recall<'a, T>( @@ -109,25 +99,15 @@ fn recall<'a, T>( where T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, { - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - let labels = labels.to_vec().unwrap(); - - let threadable = - |actual: ndarray::ArrayD, pred: ndarray::ArrayD| -> ndarray::Array2 { - py.allow_threads(move || { - let cm = cm::confusion_matrix_owned(actual, pred, labels); - let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 2), 0); - for (idx, row) in cm.rows().into_iter().enumerate() { - // get TP - *ret.get_mut((idx, 0)).unwrap() = *row.get(idx).unwrap(); - // get TP + FN - *ret.get_mut((idx, 1)).unwrap() = row.sum(); - } - return ret; - }) - }; - Ok(PyArray2::from_array(py, &threadable(actual, pred))) + let cm = cm::_confusion_matrix(py, actual, pred, labels); + let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 2), 0); + for (idx, row) in cm.rows().into_iter().enumerate() { + // get TP + *ret.get_mut((idx, 0)).unwrap() = *row.get(idx).unwrap(); + // get TP + FN + *ret.get_mut((idx, 1)).unwrap() = row.sum(); + } + Ok(PyArray2::from_array(py, &ret)) } fn f1_score<'a, T>( @@ -139,27 +119,17 @@ fn f1_score<'a, T>( where T: Copy + Clone + std::marker::Send + numpy::Element + std::hash::Hash + std::cmp::Eq, { - let actual = actual.to_owned_array(); - let pred = pred.to_owned_array(); - let labels = labels.to_vec().unwrap(); - - let threadable = - |actual: ndarray::ArrayD, pred: ndarray::ArrayD| -> ndarray::Array2 { - py.allow_threads(move || { - let cm = cm::confusion_matrix_owned(actual, pred, labels); - let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 3), 0); - for (idx, col) in cm.columns().into_iter().enumerate() { - // get TP - *ret.get_mut((idx, 0)).unwrap() = *col.get(idx).unwrap(); - // get TP + FP - *ret.get_mut((idx, 1)).unwrap() = col.sum(); - } - for (idx, row) in cm.rows().into_iter().enumerate() { - // get TP + FN - *ret.get_mut((idx, 2)).unwrap() = row.sum(); - } - return ret; - }) - }; - Ok(PyArray2::from_array(py, &threadable(actual, pred))) + let cm = cm::_confusion_matrix(py, actual, pred, labels); + let mut ret = ndarray::Array2::::from_elem((cm.shape()[0], 3), 0); + for (idx, col) in cm.columns().into_iter().enumerate() { + // get TP + *ret.get_mut((idx, 0)).unwrap() = *col.get(idx).unwrap(); + // get TP + FP + *ret.get_mut((idx, 1)).unwrap() = col.sum(); + } + for (idx, row) in cm.rows().into_iter().enumerate() { + // get TP + FN + *ret.get_mut((idx, 2)).unwrap() = row.sum(); + } + Ok(PyArray2::from_array(py, &ret)) } diff --git a/src/utils.rs b/src/utils.rs index 2228f46..5aa1498 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -15,27 +15,15 @@ pub fn py_unique<'a>(py: Python<'a>, arr: &'a PyAny) -> PyResult<&'a PySet> { /// ndarray unique fn unique<'a, T>(py: Python<'a>, arr: numpy::PyReadonlyArrayDyn) -> PyResult<&'a PySet> where - T: Clone - + std::marker::Send - + numpy::Element - + std::hash::Hash - + std::cmp::Eq - + pyo3::ToPyObject, + T: Clone + numpy::Element + std::hash::Hash + std::cmp::Eq + pyo3::ToPyObject, { - let arr = arr.to_owned_array(); - - let threadable = |arr: ndarray::ArrayD| -> Vec { - py.allow_threads(move || { - let mut track = HashSet::::new(); - let mut ret: Vec = vec![]; - for val in arr.iter() { - if !track.contains(val) { - track.insert(val.clone()); - ret.push(val.clone()); - } - } - return ret; - }) - }; - PySet::new(py, threadable(arr).as_slice()) + let mut track = HashSet::::new(); + let mut ret: Vec = vec![]; + for val in arr.as_array().iter() { + if !track.contains(val) { + track.insert(val.clone()); + ret.push(val.clone()); + } + } + PySet::new(py, ret.as_slice()) }