Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(chore) type hints for tests #698

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions numcodecs/tests/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from __future__ import annotations

import array
import json as _json
import os
from glob import glob
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Sequence

Check warning on line 10 in numcodecs/tests/common.py

View check run for this annotation

Codecov / codecov/patch

numcodecs/tests/common.py#L10

Added line #L10 was not covered by tests

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal, assert_array_equal

from numcodecs import * # noqa: F403 # for eval to find names in repr tests
from numcodecs.abc import Codec
from numcodecs.compat import ensure_bytes, ensure_ndarray
from numcodecs.registry import get_codec

Expand All @@ -27,7 +34,7 @@
]


def compare_arrays(arr, res, precision=None):
def compare_arrays(arr: np.ndarray, res: np.ndarray, precision: int | None = None) -> None:
# ensure numpy array with matching dtype
res = ensure_ndarray(res).view(arr.dtype)

Expand All @@ -47,7 +54,7 @@
assert_array_almost_equal(arr, res, decimal=precision)


def check_encode_decode(arr, codec, precision=None):
def check_encode_decode(arr: np.ndarray, codec: Codec, precision: int | None = None) -> None:
# N.B., watch out here with blosc compressor, if the itemsize of
# the source buffer is different then the results of encoding
# (i.e., compression) may be different. Hence we *do not* require that
Expand Down Expand Up @@ -115,7 +122,9 @@
compare_arrays(arr, out, precision=precision)


def check_encode_decode_partial(arr, codec, precision=None):
def check_encode_decode_partial(
arr: np.ndarray, codec: Codec, precision: int | None = None
) -> None:
# N.B., watch out here with blosc compressor, if the itemsize of
# the source buffer is different then the results of encoding
# (i.e., compression) may be different. Hence we *do not* require that
Expand Down Expand Up @@ -183,7 +192,7 @@
compare_arrays(compare_arr, out, precision=precision)


def assert_array_items_equal(res, arr):
def assert_array_items_equal(res: np.ndarray, arr: np.ndarray) -> None:
assert isinstance(res, np.ndarray)
res = res.reshape(-1, order='A')
arr = arr.reshape(-1, order='A')
Expand All @@ -204,7 +213,7 @@
assert a == r


def check_encode_decode_array(arr, codec):
def check_encode_decode_array(arr: np.ndarray, codec: Codec) -> None:
enc = codec.encode(arr)
dec = codec.decode(enc)
assert_array_items_equal(arr, dec)
Expand All @@ -218,7 +227,7 @@
assert_array_items_equal(arr, dec)


def check_encode_decode_array_to_bytes(arr, codec):
def check_encode_decode_array_to_bytes(arr: np.ndarray, codec: Codec) -> None:
enc = codec.encode(arr)
dec = codec.decode(enc)
assert_array_items_equal(arr, dec)
Expand All @@ -228,21 +237,27 @@
assert_array_items_equal(arr, out)


def check_config(codec):
def check_config(codec: Codec) -> None:
config = codec.get_config()
# round-trip through JSON to check serialization
config = _json.loads(_json.dumps(config))
assert codec == get_codec(config)


def check_repr(stmt):
def check_repr(stmt: str) -> None:
# check repr matches instantiation statement
codec = eval(stmt)
actual = repr(codec)
assert stmt == actual


def check_backwards_compatibility(codec_id, arrays, codecs, precision=None, prefix=None):
def check_backwards_compatibility(
codec_id: str,
arrays: Sequence[np.ndarray],
codecs: Sequence[Codec],
precision: Sequence[int | None] | None = None,
prefix: str | None = None,
) -> None:
# setup directory to hold data fixture
if prefix:
fixture_dir = os.path.join('fixture', codec_id, prefix)
Expand Down Expand Up @@ -312,7 +327,7 @@
assert arr_bytes == ensure_bytes(dec)


def check_err_decode_object_buffer(compressor):
def check_err_decode_object_buffer(compressor: Codec) -> None:
# cannot decode directly into object array, leads to segfaults
a = np.arange(10)
enc = compressor.encode(a)
Expand All @@ -321,14 +336,14 @@
compressor.decode(enc, out=out)


def check_err_encode_object_buffer(compressor):
def check_err_encode_object_buffer(compressor: Codec) -> None:
# compressors cannot encode object array
a = np.array(['foo', 'bar', 'baz'], dtype=object)
with pytest.raises(TypeError):
compressor.encode(a)


def check_max_buffer_size(codec):
def check_max_buffer_size(codec: Codec) -> None:
for max_buffer_size in (4, 64, 1024):
old_max_buffer_size = codec.max_buffer_size
try:
Expand Down
12 changes: 6 additions & 6 deletions numcodecs/tests/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
]


def test_encode_decode():
def test_encode_decode() -> None:
for arr in arrays:
codec = AsType(encode_dtype=arr.dtype, decode_dtype=arr.dtype)
check_encode_decode(arr, codec)


def test_decode():
def test_decode() -> None:
encode_dtype, decode_dtype = '<i4', '<i8'
codec = AsType(encode_dtype=encode_dtype, decode_dtype=decode_dtype)
arr = np.arange(10, 20, 1, dtype=encode_dtype)
Expand All @@ -36,7 +36,7 @@ def test_decode():
assert np.dtype(decode_dtype) == actual.dtype


def test_encode():
def test_encode() -> None:
encode_dtype, decode_dtype = '<i4', '<i8'
codec = AsType(encode_dtype=encode_dtype, decode_dtype=decode_dtype)
arr = np.arange(10, 20, 1, dtype=decode_dtype)
Expand All @@ -46,17 +46,17 @@ def test_encode():
assert np.dtype(encode_dtype) == actual.dtype


def test_config():
def test_config() -> None:
encode_dtype, decode_dtype = '<i4', '<i8'
codec = AsType(encode_dtype=encode_dtype, decode_dtype=decode_dtype)
check_config(codec)


def test_repr():
def test_repr() -> None:
check_repr("AsType(encode_dtype='<i4', decode_dtype='<i2')")


def test_backwards_compatibility():
def test_backwards_compatibility() -> None:
# integers
arrs = [
np.arange(1000, dtype='<i4'),
Expand Down
16 changes: 8 additions & 8 deletions numcodecs/tests/test_base64.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,43 +37,43 @@
]


def test_encode_decode():
def test_encode_decode() -> None:
for arr, codec in itertools.product(arrays, codecs):
check_encode_decode(arr, codec)


def test_repr():
def test_repr() -> None:
check_repr("Base64()")


def test_eq():
def test_eq() -> None:
assert Base64() == Base64()
assert not Base64() != Base64()
assert Base64() != "foo"
assert "foo" != Base64()
assert not Base64() == "foo"


def test_backwards_compatibility():
def test_backwards_compatibility() -> None:
check_backwards_compatibility(Base64.codec_id, arrays, codecs)


def test_err_decode_object_buffer():
def test_err_decode_object_buffer() -> None:
check_err_decode_object_buffer(Base64())


def test_err_encode_object_buffer():
def test_err_encode_object_buffer() -> None:
check_err_encode_object_buffer(Base64())


def test_err_encode_list():
def test_err_encode_list() -> None:
data = ["foo", "bar", "baz"]
for codec in codecs:
with pytest.raises(TypeError):
codec.encode(data)


def test_err_encode_non_contiguous():
def test_err_encode_non_contiguous() -> None:
# non-contiguous memory
arr = np.arange(1000, dtype="i4")[::2]
for codec in codecs:
Expand Down
18 changes: 9 additions & 9 deletions numcodecs/tests/test_bitround.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@

# TODO: add other dtypes
@pytest.fixture(params=["float32", "float64"])
def dtype(request):
def dtype(request: pytest.FixtureRequest) -> str:
return request.param


def round(data, keepbits):
def round(data: np.ndarray, keepbits: int) -> np.ndarray:
codec = BitRound(keepbits=keepbits)
data = data.copy() # otherwise overwrites the input
encoded = codec.encode(data)
return codec.decode(encoded)


def test_round_zero_to_zero(dtype):
def test_round_zero_to_zero(dtype: str) -> None:
a = np.zeros((3, 2), dtype=dtype)
# Don't understand Milan's original test:
# How is it possible to have negative keepbits?
Expand All @@ -29,21 +29,21 @@ def test_round_zero_to_zero(dtype):
np.testing.assert_equal(a, ar)


def test_round_one_to_one(dtype):
def test_round_one_to_one(dtype: str) -> None:
a = np.ones((3, 2), dtype=dtype)
for k in range(max_bits[dtype]):
ar = round(a, k)
np.testing.assert_equal(a, ar)


def test_round_minus_one_to_minus_one(dtype):
def test_round_minus_one_to_minus_one(dtype: str) -> None:
a = -np.ones((3, 2), dtype=dtype)
for k in range(max_bits[dtype]):
ar = round(a, k)
np.testing.assert_equal(a, ar)


def test_no_rounding(dtype):
def test_no_rounding(dtype: str) -> None:
a = np.random.random_sample((300, 200)).astype(dtype)
keepbits = max_bits[dtype]
ar = round(a, keepbits)
Expand All @@ -53,7 +53,7 @@ def test_no_rounding(dtype):
APPROX_KEEPBITS = {"float32": 11, "float64": 18}


def test_approx_equal(dtype):
def test_approx_equal(dtype: str) -> None:
a = np.random.random_sample((300, 200)).astype(dtype)
ar = round(a, APPROX_KEEPBITS[dtype])
# Mimic julia behavior - https://docs.julialang.org/en/v1/base/math/#Base.isapprox
Expand All @@ -64,15 +64,15 @@ def test_approx_equal(dtype):
np.testing.assert_allclose(a, ar, rtol=rtol)


def test_idempotence(dtype):
def test_idempotence(dtype: str) -> None:
a = np.random.random_sample((300, 200)).astype(dtype)
for k in range(20):
ar = round(a, k)
ar2 = round(a, k)
np.testing.assert_equal(ar, ar2)


def test_errors():
def test_errors() -> None:
with pytest.raises(ValueError):
BitRound(keepbits=99).encode(np.array([0], dtype="float32"))
with pytest.raises(TypeError):
Expand Down
Loading
Loading