From 1ac101d8d77e7c177111911ef9a8ee6d30da7b06 Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Fri, 30 Aug 2024 10:14:35 -0600 Subject: [PATCH] Use mask as alpha when reprojecting a masked array (#3156) * Use mask as alpha when reprojecting a masked array Resolves #2575 * Remove unnecessary nodata wranging * Masked reprojection results depend on GDAL version * Fix test expectations and mask scaling value The tests of this PR are now only run for GDAL 3.8+ because the outputs are too variable. * Update change log --- CHANGES.txt | 3 ++ rasterio/_io.pxd | 2 + rasterio/_warp.pyx | 114 +++++++++++++++++++++++++++++++-------------- rasterio/warp.py | 9 +++- tests/test_warp.py | 58 ++++++++++++++++++++++- 5 files changed, 149 insertions(+), 37 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 3feb2b92..c3647614 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -6,6 +6,9 @@ Next (TBD) Bug fixes: +- When reprojecting a masked array, we now use the mask (reduced) as an alpha + band. There is now also an option to create an alpha band in the output, and + turn that into a mask when returning a mask array (#3156). - Find installed GDAL data directory by searching for gdalvrt.xsd (#3157). - Allow rasterio.open() to receive instances of MemoryFile (#3145). - Leaks of CSL string lists in get/set_proj_data_search_path() have been fixed diff --git a/rasterio/_io.pxd b/rasterio/_io.pxd index 43f04148..e022ab70 100644 --- a/rasterio/_io.pxd +++ b/rasterio/_io.pxd @@ -42,3 +42,5 @@ ctypedef np.float64_t DTYPE_FLOAT64_t cdef bint in_dtype_range(value, dtype) cdef int io_auto(image, GDALRasterBandH band, bint write, int resampling=*) except -1 +cdef int io_band(GDALRasterBandH band, int mode, double x0, double y0, double width, double height, object data, int resampling=*) except -1 +cdef int io_multi_band(GDALDatasetH hds, int mode, double x0, double y0, double width, double height, object data, Py_ssize_t[:] indexes, int resampling=*) except -1 diff --git a/rasterio/_warp.pyx b/rasterio/_warp.pyx index 8efbd253..15eea7cf 100644 --- a/rasterio/_warp.pyx +++ b/rasterio/_warp.pyx @@ -38,7 +38,7 @@ from libc.math cimport HUGE_VAL from rasterio._base cimport get_driver_name from rasterio._err cimport exc_wrap, exc_wrap_pointer, exc_wrap_int, StackChecker from rasterio._io cimport ( - DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto) + DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto, io_band, io_multi_band) from rasterio._features cimport GeomBuilder, OGRGeomBuilder from rasterio.crs cimport CRS @@ -282,9 +282,9 @@ def _reproject( nodata value of the destination image (if set), the value of src_nodata, or 0 (gdal default). src_alpha : int, optional - Index of a band to use as the alpha band when warping. + Index of a band to use as the source alpha band when warping. dst_alpha : int, optional - Index of a band to use as the alpha band when warping. + Index of a band to use as the destination alpha band when warping. resampling : int Resampling method to use. One of the following: Resampling.nearest, @@ -360,35 +360,54 @@ def _reproject( cdef MemoryDataset src_mem = None try: - # If the source is an ndarray, we copy to a MEM dataset. # We need a src_transform and src_dst in this case. These will # be copied to the MEM dataset. if dtypes.is_ndarray(source): if not src_crs: raise CRSError("Missing src_crs.") - if src_nodata is None and hasattr(source, 'fill_value'): - # source is a masked array - src_nodata = source.fill_value + # ensure data converted to numpy array - source = np.asanyarray(source) + if hasattr(source, "mask"): + source = np.ma.asanyarray(source) + else: + source = np.asanyarray(source) + # Convert 2D single-band arrays to 3D multi-band. if len(source.shape) == 2: source = source.reshape(1, *source.shape) + src_count = source.shape[0] src_bidx = range(1, src_count + 1) - src_mem = MemoryDataset(source, - transform=format_transform(src_transform), - gcps=gcps, - rpcs=rpcs, - crs=src_crs, - copy=True) + + if hasattr(source, "mask"): + mask = ~np.logical_or.reduce(source.mask) * np.uint8(255) + source_arr = np.concatenate((source.data, [mask])) + src_alpha = src_alpha or source_arr.shape[0] + else: + source_arr = source + + src_mem = MemoryDataset( + source_arr, + transform=format_transform(src_transform), + gcps=gcps, + rpcs=rpcs, + crs=src_crs, + copy=True, + ) src_dataset = src_mem.handle() + if src_alpha: + for i in range(source_arr.shape[0]): + GDALDeleteRasterNoDataValue(GDALGetRasterBand(src_dataset, i+1)) + GDALSetRasterColorInterpretation( + GDALGetRasterBand(src_dataset, src_alpha), + 6, + ) + # If the source is a rasterio MultiBand, no copy necessary. # A MultiBand is a tuple: (dataset, bidx, dtype, shape(2d)) elif isinstance(source, tuple): - rdr, src_bidx, dtype, shape = source if isinstance(src_bidx, int): src_bidx = [src_bidx] @@ -408,12 +427,17 @@ def _reproject( # Next, do the same for the destination raster. try: - if dtypes.is_ndarray(destination): if not dst_crs: raise CRSError("Missing dst_crs.") - # ensure data converted to numpy array - destination = np.asanyarray(destination) + + dst_nodata = dst_nodata or src_nodata + + if hasattr(destination, "mask"): + destination = np.ma.asanyarray(destination) + else: + destination = np.asanyarray(destination) + if len(destination.shape) == 2: destination = destination.reshape(1, *destination.shape) @@ -426,27 +450,33 @@ def _reproject( raise ValueError("Invalid destination shape") dst_bidx = src_bidx - mem_raster = MemoryDataset(destination, transform=format_transform(dst_transform), crs=dst_crs) + if hasattr(destination, "mask"): + count, height, width = destination.shape + msk = np.logical_or.reduce(destination.mask) + if msk == np.ma.nomask: + msk = np.zeros((height, width), dtype="bool") + msk = ~msk * np.uint8(255) + dest_arr = np.concatenate((destination.data, [msk])) + dst_alpha = dst_alpha or dest_arr.shape[0] + else: + dest_arr = destination + + mem_raster = MemoryDataset(dest_arr, transform=format_transform(dst_transform), crs=dst_crs) dst_dataset = mem_raster.handle() if dst_alpha: - for i in range(destination.shape[0]): + for i in range(dest_arr.shape[0]): GDALDeleteRasterNoDataValue(GDALGetRasterBand(dst_dataset, i+1)) - - GDALSetRasterColorInterpretation(GDALGetRasterBand(dst_dataset, dst_alpha), 6) + GDALSetRasterColorInterpretation( + GDALGetRasterBand(dst_dataset, dst_alpha), + 6, + ) GDALSetDescription( dst_dataset, "Temporary destination dataset for _reproject()") log.debug("Created temp destination dataset.") - if dst_nodata is None: - if hasattr(destination, "fill_value"): - # destination is a masked array - dst_nodata = destination.fill_value - elif src_nodata is not None: - dst_nodata = src_nodata - elif isinstance(destination, tuple): udr, dst_bidx, _, _ = destination if isinstance(dst_bidx, int): @@ -609,8 +639,27 @@ def _reproject( except CPLE_BaseError as base: raise WarpOperationError("Chunk and warp failed") from base - if dtypes.is_ndarray(destination): - exc_wrap_int(io_auto(destination, dst_dataset, 0)) + if mem_raster is not None: + count, height, width = dest_arr.shape + if hasattr(destination, "mask"): + # Pick off the alpha band and make a mask of it. + # TODO: do this efficiently, not copying unless necessary. + indexes = np.arange(1, count, dtype='intp') + io_multi_band(dst_dataset, 0, 0.0, 0.0, width, height, destination, indexes) + alpha_arr = np.empty((height, width), dtype=dest_arr.dtype) + io_band(mem_raster.band(count), 0, 0.0, 0.0, width, height, alpha_arr) + destination = np.ma.masked_array( + destination.data, + mask=np.repeat( + ~(alpha_arr.astype("bool"))[np.newaxis, :, :], + count - 1, + axis=0, + ) + ) + else: + exc_wrap_int(io_auto(destination, dst_dataset, 0)) + + return destination # Clean up transformer, warp options, and dataset handles. finally: @@ -629,7 +678,6 @@ def _reproject( if src_mem is not None: src_mem.close() - def _calculate_default_transform( src_crs, dst_crs, @@ -1024,11 +1072,9 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase): # raise an exception instead. if add_alpha: - if src_alpha_band: raise WarpOptionsError( "The VRT already has an alpha band, adding a new one is not supported") - else: dst_alpha_band = src_dataset.count + 1 self.dst_nodata = None diff --git a/rasterio/warp.py b/rasterio/warp.py index c9fedd47..7697114d 100644 --- a/rasterio/warp.py +++ b/rasterio/warp.py @@ -172,6 +172,7 @@ def reproject( dst_resolution=None, src_alpha=0, dst_alpha=0, + masked=False, resampling=Resampling.nearest, num_threads=1, init_dest_nodata=True, @@ -250,6 +251,8 @@ def reproject( Index of a band to use as the alpha band when warping. dst_alpha : int, optional Index of a band to use as the alpha band when warping. + masked: bool, optional + If True and destination is None, return a masked array. resampling: int, rasterio.enums.Resampling Resampling method to use. Default is :attr:`rasterio.enums.Resampling.nearest`. @@ -378,8 +381,10 @@ def reproject( destination = np.empty( (int(dst_count), int(dst_height), int(dst_width)), dtype=source.dtype ) + if masked: + destination = np.ma.masked_array(destination).filled(dst_nodata) - _reproject( + dest = _reproject( source, destination, src_transform=src_transform, @@ -400,7 +405,7 @@ def reproject( **kwargs ) - return destination, dst_transform + return dest, dst_transform def aligned_target(transform, width, height, resolution): diff --git a/tests/test_warp.py b/tests/test_warp.py index ead43e39..2ea643d0 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -1333,7 +1333,25 @@ def test_reproject_array_interface(test3d, count_nonzero, path_rgb_byte_tif): assert np.count_nonzero(out.data[out.data != 99]) == count_nonzero -@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)]) +@pytest.mark.parametrize( + "test3d,count_nonzero", + [ + pytest.param( + True, + 1308064, + marks=pytest.mark.skipif( + not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x" + ), + ), + pytest.param( + False, + 437686, + marks=pytest.mark.skipif( + not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x" + ), + ), + ], +) def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif): with rasterio.open(path_rgb_byte_tif) as src: if test3d: @@ -1352,6 +1370,44 @@ def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif): ) assert np.ma.is_masked(source) assert np.count_nonzero(out[out != 99]) == count_nonzero + assert not np.ma.is_masked(out) + + +@pytest.mark.parametrize( + "test3d,count_nonzero", + [ + pytest.param( + True, + 1312959, + marks=pytest.mark.skipif( + not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x" + ), + ), + pytest.param( + False, + 438113, + marks=pytest.mark.skipif( + not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x" + ), + ), + ] +) +def test_reproject_masked_masked_output(test3d, count_nonzero, path_rgb_byte_tif): + with rasterio.open(path_rgb_byte_tif) as src: + if test3d: + inp = src.read(masked=True) + else: + inp = src.read(1, masked=True) + out = np.ma.masked_array(np.empty(inp.shape, dtype=np.uint8)) + out, _ = reproject( + inp, + out, + src_transform=src.transform, + src_crs=src.crs, + dst_transform=DST_TRANSFORM, + dst_crs="EPSG:3857", + ) + assert np.count_nonzero(out[out != np.ma.masked]) == count_nonzero @pytest.mark.parametrize("method", SUPPORTED_RESAMPLING)