Skip to content

Commit

Permalink
added postprocess func args and kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
martibosch committed Jul 5, 2021
1 parent 5fc5e1e commit cea8523
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
20 changes: 14 additions & 6 deletions detectree/lidar.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ def to_canopy_mask(
lidar_tree_values,
ref_img_filepath,
*,
output_filepath=None,
postprocess_func=None,
output_filepath=None
postprocess_func_args=None,
postprocess_func_kws=None
):
"""
Transform a LiDAR file into a canopy mask.
Expand All @@ -128,15 +130,19 @@ def to_canopy_mask(
LiDAR point classes that correspond to trees
ref_img_filepath : str, file object or pathlib.Path object
Reference raster image to which the LiDAR data will be rasterized
postprocess_func : function
Post-processing function which takes as input the rasterized lidar
as a boolean ndarray and returns a the post-processed lidar also as
a boolean ndarray.
output_filepath : str, file object or pathlib.Path object, optional
Path to a file, URI, file object opened in binary ('rb') mode, or
a Path object representing where the predicted image is to be
dumped. The value will be passed to `rasterio.open` in 'write'
mode.
postprocess_func : function
Post-processing function which takes as input the rasterized lidar
as a boolean ndarray and returns a the post-processed lidar also as
a boolean ndarray.
postprocess_func_args : list-like, optional
Arguments to be passed to `postprocess_func`.
postprocess_func_kws : dict, optional
Keyword arguments to be passed to `postprocess_func`.
Returns
-------
Expand All @@ -153,7 +159,9 @@ def to_canopy_mask(
)
canopy_arr = lidar_arr >= self.tree_threshold
if postprocess_func is not None:
canopy_arr = postprocess_func(canopy_arr)
canopy_arr = postprocess_func(
canopy_arr, *postprocess_func_args, **postprocess_func_kws
)
canopy_arr = np.where(
canopy_arr, self.output_tree_val, self.output_nodata
).astype(self.output_dtype)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_detectree.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,12 +697,15 @@ def test_lidar_to_canopy(self):
np.ndarray,
)

# test that we can pass a `postprocess_func` to `to_canopy_mask`
# test that we can pass a `postprocess_func` with args/kwargs to
# `to_canopy_mask`
y_pred = ltc.to_canopy_mask(
self.lidar_filepath,
self.lidar_tree_values,
self.ref_img_filepath,
postprocess_func=ndi.binary_dilation,
postprocess_func_args=[ndi.generate_binary_structure(2, 2)],
postprocess_func_kws={"border_value": 0},
)
# test that `to_canopy_mask` with `output_filepath` returns a ndarray
# and dumps it
Expand Down

0 comments on commit cea8523

Please sign in to comment.