diff --git a/nibabel/nifti1.py b/nibabel/nifti1.py index c41b9a8ed3..91ed8a2903 100644 --- a/nibabel/nifti1.py +++ b/nibabel/nifti1.py @@ -898,26 +898,28 @@ def set_data_dtype(self, datatype): >>> hdr.set_data_dtype(np.dtype(np.uint8)) >>> hdr.get_data_dtype() dtype('uint8') - >>> hdr.set_data_dtype('implausible') #doctest: +IGNORE_EXCEPTION_DETAIL + >>> hdr.set_data_dtype('implausible') Traceback (most recent call last): ... - HeaderDataError: data dtype "implausible" not recognized - >>> hdr.set_data_dtype('none') #doctest: +IGNORE_EXCEPTION_DETAIL + nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized + >>> hdr.set_data_dtype('none') Traceback (most recent call last): ... - HeaderDataError: data dtype "none" known but not supported - >>> hdr.set_data_dtype(np.void) #doctest: +IGNORE_EXCEPTION_DETAIL + nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported + >>> hdr.set_data_dtype(np.void) Traceback (most recent call last): ... - HeaderDataError: data dtype "" known but not supported - >>> hdr.set_data_dtype('int') #doctest: +IGNORE_EXCEPTION_DETAIL + nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known + but not supported + >>> hdr.set_data_dtype('int') Traceback (most recent call last): ... ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16. - >>> hdr.set_data_dtype(int) #doctest: +IGNORE_EXCEPTION_DETAIL + >>> hdr.set_data_dtype(int) Traceback (most recent call last): ... - ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16. + ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or + numpy.int16. 
# Class-level default on ``Nifti1Pair``.
# If a _dtype_alias has been set, it can only be resolved by inspecting
# the data at serialization time
_dtype_alias = None


def set_data_dtype(self, datatype):
    """ Set numpy dtype for data from code, dtype, type or alias

    Using :py:class:`int` or ``"int"`` is disallowed, as these types
    will be interpreted as ``np.int64``, which is almost never desired.
    ``np.int64`` is permitted for those intent on making poor choices.

    The following aliases are defined to allow for flexible specification:

      * ``'mask'`` - Alias for ``uint8``
      * ``'compat'`` - The nearest Analyze-compatible datatype
        (``uint8``, ``int16``, ``int32``, ``float32``)
      * ``'smallest'`` - The smallest Analyze-compatible integer
        (``uint8``, ``int16``, ``int32``)

    Dynamic aliases are resolved when ``get_data_dtype()`` is called
    with a ``finalize=True`` flag. Until then, these aliases are not
    written to the header and will not persist to new images.

    Parameters
    ----------
    datatype : code, dtype, type or str
        Anything the parent ``set_data_dtype()`` accepts, or one of the
        aliases ``'mask'``, ``'compat'``, ``'smallest'``.

    Raises
    ------
    HeaderDataError
        If the data type is not recognized or not supported (see examples).
    ValueError
        If ``int`` or ``'int'`` is passed (see examples).

    Examples
    --------
    >>> ints = np.arange(24, dtype='i4').reshape((2,3,4))

    >>> img = Nifti1Image(ints, np.eye(4))
    >>> img.set_data_dtype(np.uint8)
    >>> img.get_data_dtype()
    dtype('uint8')
    >>> img.set_data_dtype('mask')
    >>> img.get_data_dtype()
    dtype('uint8')
    >>> img.set_data_dtype('compat')
    >>> img.get_data_dtype()
    'compat'
    >>> img.get_data_dtype(finalize=True)
    dtype('<i4')
    >>> img.get_data_dtype()
    dtype('<i4')
    >>> img.set_data_dtype('smallest')
    >>> img.get_data_dtype()
    'smallest'
    >>> img.get_data_dtype(finalize=True)
    dtype('uint8')
    >>> img.get_data_dtype()
    dtype('uint8')

    Note that floating point values will not be coerced to ``int``

    >>> floats = np.arange(24, dtype='f4').reshape((2,3,4))
    >>> img = Nifti1Image(floats, np.eye(4))
    >>> img.set_data_dtype('smallest')
    >>> img.get_data_dtype(finalize=True)
    Traceback (most recent call last):
       ...
    ValueError: Cannot automatically cast array (of type float32) to an integer
    type with fewer than 64 bits. Please set_data_dtype() to an explicit data type.

    >>> arr = np.arange(1000, 1024, dtype='i4').reshape((2,3,4))
    >>> img = Nifti1Image(arr, np.eye(4))
    >>> img.set_data_dtype('smallest')
    >>> img.set_data_dtype('implausible')
    Traceback (most recent call last):
       ...
    nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized
    >>> img.set_data_dtype('none')
    Traceback (most recent call last):
       ...
    nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported
    >>> img.set_data_dtype(np.void)
    Traceback (most recent call last):
       ...
    nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
    but not supported
    >>> img.set_data_dtype('int')
    Traceback (most recent call last):
       ...
    ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
    >>> img.set_data_dtype(int)
    Traceback (most recent call last):
       ...
    ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
    numpy.int16.
    >>> img.set_data_dtype('int64')
    >>> img.get_data_dtype() == np.dtype('int64')
    True
    """
    # Comparing dtypes to strings, numpy will attempt to call, e.g., dtype('mask'),
    # so only check for aliases if the type is a string
    # See https://github.com/numpy/numpy/issues/7242
    if isinstance(datatype, str):
        # Static aliases
        if datatype == 'mask':
            datatype = 'u1'
        # Dynamic aliases: remember the request, leave the header untouched
        elif datatype in ('compat', 'smallest'):
            self._dtype_alias = datatype
            return

    # Let the parent validate and store the dtype *before* dropping any
    # pending alias, so that a rejected datatype leaves the image unchanged.
    super().set_data_dtype(datatype)
    self._dtype_alias = None


def get_data_dtype(self, finalize=False):
    """ Get numpy dtype for data

    If ``set_data_dtype()`` has been called with an alias
    and ``finalize`` is ``False``, return the alias.
    If ``finalize`` is ``True``, determine the appropriate dtype
    from the image data object and set the final dtype in the
    header before returning it.

    Parameters
    ----------
    finalize : bool, optional
        Resolve a pending dynamic alias against the image data and
        store the result via ``set_data_dtype()``.

    Returns
    -------
    dtype or str
        The parent class's ``get_data_dtype()`` result, or the alias string
        while an alias is pending and ``finalize`` is false.

    Raises
    ------
    ValueError
        If the pending alias is unknown, or no suitable dtype can be found
        for the image data.
    """
    if self._dtype_alias is None:
        return super().get_data_dtype()
    if not finalize:
        return self._dtype_alias

    datatype = None
    if self._dtype_alias == 'compat':
        # This helper raises ValueError itself when nothing fits
        datatype = _get_analyze_compat_dtype(self._dataobj)
        descrip = "an Analyze-compatible dtype"
    elif self._dtype_alias == 'smallest':
        # This helper signals failure by returning None
        datatype = _get_smallest_dtype(self._dataobj)
        descrip = "an integer type with fewer than 64 bits"
    else:
        raise ValueError(f"Unknown dtype alias {self._dtype_alias}.")
    if datatype is None:
        dt = get_obj_dtype(self._dataobj)
        raise ValueError(f"Cannot automatically cast array (of type {dt}) to {descrip}."
                         " Please set_data_dtype() to an explicit data type.")

    self.set_data_dtype(datatype)  # Clears the alias
    return super().get_data_dtype()
** 1000]), ftypes=('float32',)) + >>> _get_smallest_dtype(np.longdouble(2) ** 2000, ftypes=('float32',)) + >>> _get_smallest_dtype(np.array([1+0j]), ftypes=('float32',)) + """ + arr = np.asanyarray(arr) + if np.issubdtype(arr.dtype, np.floating): + test_dts = ftypes + info = np.finfo + elif np.issubdtype(arr.dtype, np.integer): + test_dts = itypes + info = np.iinfo + else: + return None + + mn, mx = np.min(arr), np.max(arr) + for dt in test_dts: + dtinfo = info(dt) + if dtinfo.min <= mn and mx <= dtinfo.max: + return np.dtype(dt) + + +def _get_analyze_compat_dtype(arr): + """ Return an Analyze-compatible dtype that ``arr`` can be safely cast to + + Analyze-compatible types are returned without inspection: + + >>> _get_analyze_compat_dtype(np.uint8([0, 1])) + dtype('uint8') + >>> _get_analyze_compat_dtype(np.int16([0, 1])) + dtype('int16') + >>> _get_analyze_compat_dtype(np.int32([0, 1])) + dtype('int32') + >>> _get_analyze_compat_dtype(np.float32([0, 1])) + dtype('float32') + + Signed ``int8`` are cast to ``uint8`` or ``int16`` based on value ranges: + + >>> _get_analyze_compat_dtype(np.int8([0, 1])) + dtype('uint8') + >>> _get_analyze_compat_dtype(np.int8([-1, 1])) + dtype('int16') + + Unsigned ``uint16`` are cast to ``int16`` or ``int32`` based on value ranges: + + >>> _get_analyze_compat_dtype(np.uint16([32767])) + dtype('int16') + >>> _get_analyze_compat_dtype(np.uint16([65535])) + dtype('int32') + + ``int32`` is returned for integer types and ``float32`` for floating point types: + + >>> _get_analyze_compat_dtype(np.array([-1, 1])) + dtype('int32') + >>> _get_analyze_compat_dtype(np.array([-1., 1.])) + dtype('float32') + + If the value ranges exceed 4 bytes or cannot be cast, then a ``ValueError`` is raised: + + >>> _get_analyze_compat_dtype(np.array([0, 4294967295])) + Traceback (most recent call last): + ... 
+ ValueError: Cannot find analyze-compatible dtype for array with dtype=int64 + (min=0, max=4294967295) + + >>> _get_analyze_compat_dtype([0., 2.e40]) + Traceback (most recent call last): + ... + ValueError: Cannot find analyze-compatible dtype for array with dtype=float64 + (min=0.0, max=2e+40) + + Note that real-valued complex arrays cannot be safely cast. + + >>> _get_analyze_compat_dtype(np.array([1+0j])) + Traceback (most recent call last): + ... + ValueError: Cannot find analyze-compatible dtype for array with dtype=complex128 + (min=(1+0j), max=(1+0j)) + """ + arr = np.asanyarray(arr) + dtype = arr.dtype + if dtype in (np.uint8, np.int16, np.int32, np.float32): + return dtype + + if dtype == np.int8: + return np.dtype('uint8' if arr.min() >= 0 else 'int16') + elif dtype == np.uint16: + return np.dtype('int16' if arr.max() <= np.iinfo(np.int16).max else 'int32') + + mn, mx = arr.min(), arr.max() + if np.can_cast(mn, np.int32) and np.can_cast(mx, np.int32): + return np.dtype('int32') + if np.can_cast(mn, np.float32) and np.can_cast(mx, np.float32): + return np.dtype('float32') + + raise ValueError( + f"Cannot find analyze-compatible dtype for array with dtype={dtype} (min={mn}, max={mx})" + ) diff --git a/nibabel/tests/test_nifti1.py b/nibabel/tests/test_nifti1.py index 7652f77e42..8ed897b036 100644 --- a/nibabel/tests/test_nifti1.py +++ b/nibabel/tests/test_nifti1.py @@ -1119,6 +1119,63 @@ def test_write_scaling(self): with np.errstate(invalid='ignore'): self._check_write_scaling(slope, inter, e_slope, e_inter) + def test_dynamic_dtype_aliases(self): + for in_dt, mn, mx, alias, effective_dt in [ + (np.uint8, 0, 255, 'compat', np.uint8), + (np.int8, 0, 127, 'compat', np.uint8), + (np.int8, -128, 127, 'compat', np.int16), + (np.int16, -32768, 32767, 'compat', np.int16), + (np.uint16, 0, 32767, 'compat', np.int16), + (np.uint16, 0, 65535, 'compat', np.int32), + (np.int32, -2**31, 2**31-1, 'compat', np.int32), + (np.uint32, 0, 2**31-1, 'compat', np.int32), + 
def test_dynamic_dtype_aliases(self):
    """Dynamic aliases stay pending until finalized or serialized."""
    # (input dtype, min value, max value, alias, resolved dtype);
    # a resolved dtype of None means finalizing must raise ValueError.
    cases = [
        (np.uint8, 0, 255, 'compat', np.uint8),
        (np.int8, 0, 127, 'compat', np.uint8),
        (np.int8, -128, 127, 'compat', np.int16),
        (np.int16, -32768, 32767, 'compat', np.int16),
        (np.uint16, 0, 32767, 'compat', np.int16),
        (np.uint16, 0, 65535, 'compat', np.int32),
        (np.int32, -2**31, 2**31-1, 'compat', np.int32),
        (np.uint32, 0, 2**31-1, 'compat', np.int32),
        (np.uint32, 0, 2**32-1, 'compat', None),
        (np.int64, -2**31, 2**31-1, 'compat', np.int32),
        (np.uint64, 0, 2**31-1, 'compat', np.int32),
        (np.int64, 0, 2**32-1, 'compat', None),
        (np.uint64, 0, 2**32-1, 'compat', None),
        (np.float32, 0, 1e30, 'compat', np.float32),
        (np.float64, 0, 1e30, 'compat', np.float32),
        (np.float64, 0, 1e40, 'compat', None),
        (np.int64, 0, 255, 'smallest', np.uint8),
        (np.int64, 0, 256, 'smallest', np.int16),
        (np.int64, -1, 255, 'smallest', np.int16),
        (np.int64, 0, 32768, 'smallest', np.int32),
        (np.int64, 0, 4294967296, 'smallest', None),
        (np.float32, 0, 1, 'smallest', None),
        (np.float64, 0, 1, 'smallest', None),
    ]
    for src_dt, lo, hi, alias, expected in cases:
        data = np.arange(24, dtype=src_dt).reshape((2, 3, 4))
        # Plant the extremes that drive dtype selection
        data[0, 0, :2] = [lo, hi]
        img = self.image_class(data, np.eye(4), dtype=alias)
        # The alias itself is reported until it is resolved
        assert img.get_data_dtype() == alias
        if expected is None:
            with pytest.raises(ValueError):
                img.get_data_dtype(finalize=True)
        else:
            # Finalizing resolves the alias and clears it
            assert img.get_data_dtype(finalize=True) == expected
            assert img.get_data_dtype() == expected
            # Put the alias back before round-tripping
            img.set_data_dtype(alias)
            assert img.get_data_dtype() == alias
            round_tripped = bytesio_round_trip(img)
            assert round_tripped.get_data_dtype() == expected
            # Serializing does not finalize the source image
            assert img.get_data_dtype() == alias


def test_static_dtype_aliases(self):
    """Static aliases resolve immediately, whatever the input dtype."""
    static_aliases = {"mask": np.uint8}
    for alias, expected in static_aliases.items():
        for src_dt in ('u1', 'i8', 'f4'):
            data = np.arange(24, dtype=src_dt).reshape((2, 3, 4))
            img = self.image_class(data, np.eye(4), dtype=alias)
            assert img.get_data_dtype() == expected
            round_tripped = bytesio_round_trip(img)
            assert round_tripped.get_data_dtype() == expected