Skip to content

Commit

Permalink
Macro Refactoring (#11)
Browse files Browse the repository at this point in the history
* updating project metadata

* fixing CI yaml

* using venv

* trying again

* redoing CI

* fixing tests

* changing some settings

* updates

* fixing build

* trying to fix this

* fixing release

* bumping version

* better code organization

* updates

* adding initial cm impl

* adding unique, dispatching pattern to Py objects, renaming ext

* rustfmt

* cm dispatched

* rustfmt

* tests and benchmarks added

* bump version

* 100% test coverage

* updating readme

* Threading enabled (#9)

* bumping version

* major refactor leveraging macros

* bumping version and updating test

* adding executed notebook

* fixing performance w/ bool
  • Loading branch information
zachcoleman authored May 29, 2022
1 parent 0bcfc1c commit 6fee6d7
Show file tree
Hide file tree
Showing 12 changed files with 580 additions and 546 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fast-stats"
version = "0.0.5"
version = "0.0.6"
edition = "2021"

[lib]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The project was developed using the [maturin](https://maturin.rs) framework.
This project is still in development.

## Installation
From PyPi:
From PyPI:
```shell
pip install fast-stats
```
Expand Down
7 changes: 3 additions & 4 deletions benchmarks/threading_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
"metadata": {},
"outputs": [],
"source": [
"SIZE = (10, 2048, 2048)\n",
"NUM_CATS = 20"
"SIZE = (10, 2048, 2048)"
]
},
{
Expand Down Expand Up @@ -62,7 +61,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"471 ms ± 7.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"462 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -81,7 +80,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"260 ms ± 5.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"258 ms ± 3.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down
30 changes: 15 additions & 15 deletions benchmarks/timeit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"569 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"580 ms ± 6.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -84,7 +84,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"569 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"581 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -102,7 +102,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"577 ms ± 8.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"579 ms ± 5.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -131,7 +131,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"3.13 ms ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"3.27 ms ± 45.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand All @@ -149,7 +149,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"3.11 ms ± 19.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"3.12 ms ± 46.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand All @@ -167,7 +167,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"4.34 ms ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"4.35 ms ± 31.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand Down Expand Up @@ -221,7 +221,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"5.84 ms ± 96.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"5.7 ms ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand All @@ -239,7 +239,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"3.2 ms ± 77.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"3.17 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand Down Expand Up @@ -274,7 +274,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"331 ms ± 608 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"316 ms ± 608 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -292,7 +292,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"177 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"169 ms ± 470 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand Down Expand Up @@ -325,7 +325,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"135 ms ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"142 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand All @@ -343,7 +343,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"71.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"72 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand All @@ -352,7 +352,7 @@
"# while labels is optional argument providing\n",
"# labels will lead to a significant speedup\n",
"# since it will not have to be inferred\n",
"_ = fast_stats.confusion_matrix(y_true, y_pred, labels=list(range(NUM_CATS)))"
"_ = fast_stats.confusion_matrix(y_true, y_pred, labels = list(range(NUM_CATS)))"
]
},
{
Expand Down Expand Up @@ -392,7 +392,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1.77 ms ± 83.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
"1.69 ms ± 32.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
Expand All @@ -410,7 +410,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"316 ns ± 1.43 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n"
"326 ns ± 0.755 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n"
]
}
],
Expand Down
10 changes: 7 additions & 3 deletions fast_stats/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List
from typing import List, Union

import numpy as np

from ._fast_stats_ext import _confusion_matrix, _unique


def confusion_matrix(
y_true: np.ndarray, y_pred: np.ndarray, labels: List = None
y_true: np.ndarray, y_pred: np.ndarray, labels: Union[List, np.ndarray] = None
) -> np.ndarray:
assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
assert all(
Expand All @@ -17,6 +17,10 @@ def confusion_matrix(
), "y_true and y_pred must be numpy arrays"

if labels is None:
labels = sorted(list(_unique(np.concatenate([y_true, y_pred]))))
labels = np.array(
sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype
)
elif isinstance(labels, list):
labels = np.array(labels, dtype=y_true.dtype)

return _confusion_matrix(y_true, y_pred, labels)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fast-stats"
version = "0.0.5"
version = "0.0.6"
description = "A fast and simple library for calculating basic statistics"
readme = "README.md"
license = {text="Apache 2.0"}
Expand Down
Loading

0 comments on commit 6fee6d7

Please sign in to comment.