Use mask as alpha when reprojecting a masked array (#3156)

* Use mask as alpha when reprojecting a masked array

Resolves #2575

* Remove unnecessary nodata wranging

* Masked reprojection results depend on GDAL version

* Fix test expectations and mask scaling value

The tests of this PR are now only run for GDAL 3.8+ because the
outputs are too variable.

* Update change log
This commit is contained in:
Sean Gillies 2024-08-30 10:14:35 -06:00 committed by GitHub
parent d917c5cd38
commit 1ac101d8d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 149 additions and 37 deletions

View File

@ -6,6 +6,9 @@ Next (TBD)
Bug fixes:
- When reprojecting a masked array, we now use the mask (reduced) as an alpha
band. There is now also an option to create an alpha band in the output, and
turn that into a mask when returning a mask array (#3156).
- Find installed GDAL data directory by searching for gdalvrt.xsd (#3157).
- Allow rasterio.open() to receive instances of MemoryFile (#3145).
- Leaks of CSL string lists in get/set_proj_data_search_path() have been fixed

View File

@ -42,3 +42,5 @@ ctypedef np.float64_t DTYPE_FLOAT64_t
cdef bint in_dtype_range(value, dtype)
cdef int io_auto(image, GDALRasterBandH band, bint write, int resampling=*) except -1
cdef int io_band(GDALRasterBandH band, int mode, double x0, double y0, double width, double height, object data, int resampling=*) except -1
cdef int io_multi_band(GDALDatasetH hds, int mode, double x0, double y0, double width, double height, object data, Py_ssize_t[:] indexes, int resampling=*) except -1

View File

@ -38,7 +38,7 @@ from libc.math cimport HUGE_VAL
from rasterio._base cimport get_driver_name
from rasterio._err cimport exc_wrap, exc_wrap_pointer, exc_wrap_int, StackChecker
from rasterio._io cimport (
DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto)
DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto, io_band, io_multi_band)
from rasterio._features cimport GeomBuilder, OGRGeomBuilder
from rasterio.crs cimport CRS
@ -282,9 +282,9 @@ def _reproject(
nodata value of the destination image (if set), the value of
src_nodata, or 0 (gdal default).
src_alpha : int, optional
Index of a band to use as the alpha band when warping.
Index of a band to use as the source alpha band when warping.
dst_alpha : int, optional
Index of a band to use as the alpha band when warping.
Index of a band to use as the destination alpha band when warping.
resampling : int
Resampling method to use. One of the following:
Resampling.nearest,
@ -360,35 +360,54 @@ def _reproject(
cdef MemoryDataset src_mem = None
try:
# If the source is an ndarray, we copy to a MEM dataset.
# We need a src_transform and src_dst in this case. These will
# be copied to the MEM dataset.
if dtypes.is_ndarray(source):
if not src_crs:
raise CRSError("Missing src_crs.")
if src_nodata is None and hasattr(source, 'fill_value'):
# source is a masked array
src_nodata = source.fill_value
# ensure data converted to numpy array
source = np.asanyarray(source)
if hasattr(source, "mask"):
source = np.ma.asanyarray(source)
else:
source = np.asanyarray(source)
# Convert 2D single-band arrays to 3D multi-band.
if len(source.shape) == 2:
source = source.reshape(1, *source.shape)
src_count = source.shape[0]
src_bidx = range(1, src_count + 1)
src_mem = MemoryDataset(source,
transform=format_transform(src_transform),
gcps=gcps,
rpcs=rpcs,
crs=src_crs,
copy=True)
if hasattr(source, "mask"):
mask = ~np.logical_or.reduce(source.mask) * np.uint8(255)
source_arr = np.concatenate((source.data, [mask]))
src_alpha = src_alpha or source_arr.shape[0]
else:
source_arr = source
src_mem = MemoryDataset(
source_arr,
transform=format_transform(src_transform),
gcps=gcps,
rpcs=rpcs,
crs=src_crs,
copy=True,
)
src_dataset = src_mem.handle()
if src_alpha:
for i in range(source_arr.shape[0]):
GDALDeleteRasterNoDataValue(GDALGetRasterBand(src_dataset, i+1))
GDALSetRasterColorInterpretation(
GDALGetRasterBand(src_dataset, src_alpha),
<GDALColorInterp>6,
)
# If the source is a rasterio MultiBand, no copy necessary.
# A MultiBand is a tuple: (dataset, bidx, dtype, shape(2d))
elif isinstance(source, tuple):
rdr, src_bidx, dtype, shape = source
if isinstance(src_bidx, int):
src_bidx = [src_bidx]
@ -408,12 +427,17 @@ def _reproject(
# Next, do the same for the destination raster.
try:
if dtypes.is_ndarray(destination):
if not dst_crs:
raise CRSError("Missing dst_crs.")
# ensure data converted to numpy array
destination = np.asanyarray(destination)
dst_nodata = dst_nodata or src_nodata
if hasattr(destination, "mask"):
destination = np.ma.asanyarray(destination)
else:
destination = np.asanyarray(destination)
if len(destination.shape) == 2:
destination = destination.reshape(1, *destination.shape)
@ -426,27 +450,33 @@ def _reproject(
raise ValueError("Invalid destination shape")
dst_bidx = src_bidx
mem_raster = MemoryDataset(destination, transform=format_transform(dst_transform), crs=dst_crs)
if hasattr(destination, "mask"):
count, height, width = destination.shape
msk = np.logical_or.reduce(destination.mask)
if msk == np.ma.nomask:
msk = np.zeros((height, width), dtype="bool")
msk = ~msk * np.uint8(255)
dest_arr = np.concatenate((destination.data, [msk]))
dst_alpha = dst_alpha or dest_arr.shape[0]
else:
dest_arr = destination
mem_raster = MemoryDataset(dest_arr, transform=format_transform(dst_transform), crs=dst_crs)
dst_dataset = mem_raster.handle()
if dst_alpha:
for i in range(destination.shape[0]):
for i in range(dest_arr.shape[0]):
GDALDeleteRasterNoDataValue(GDALGetRasterBand(dst_dataset, i+1))
GDALSetRasterColorInterpretation(GDALGetRasterBand(dst_dataset, dst_alpha), <GDALColorInterp>6)
GDALSetRasterColorInterpretation(
GDALGetRasterBand(dst_dataset, dst_alpha),
<GDALColorInterp>6,
)
GDALSetDescription(
dst_dataset, "Temporary destination dataset for _reproject()")
log.debug("Created temp destination dataset.")
if dst_nodata is None:
if hasattr(destination, "fill_value"):
# destination is a masked array
dst_nodata = destination.fill_value
elif src_nodata is not None:
dst_nodata = src_nodata
elif isinstance(destination, tuple):
udr, dst_bidx, _, _ = destination
if isinstance(dst_bidx, int):
@ -609,8 +639,27 @@ def _reproject(
except CPLE_BaseError as base:
raise WarpOperationError("Chunk and warp failed") from base
if dtypes.is_ndarray(destination):
exc_wrap_int(io_auto(destination, dst_dataset, 0))
if mem_raster is not None:
count, height, width = dest_arr.shape
if hasattr(destination, "mask"):
# Pick off the alpha band and make a mask of it.
# TODO: do this efficiently, not copying unless necessary.
indexes = np.arange(1, count, dtype='intp')
io_multi_band(dst_dataset, 0, 0.0, 0.0, width, height, destination, indexes)
alpha_arr = np.empty((height, width), dtype=dest_arr.dtype)
io_band(mem_raster.band(count), 0, 0.0, 0.0, width, height, alpha_arr)
destination = np.ma.masked_array(
destination.data,
mask=np.repeat(
~(alpha_arr.astype("bool"))[np.newaxis, :, :],
count - 1,
axis=0,
)
)
else:
exc_wrap_int(io_auto(destination, dst_dataset, 0))
return destination
# Clean up transformer, warp options, and dataset handles.
finally:
@ -629,7 +678,6 @@ def _reproject(
if src_mem is not None:
src_mem.close()
def _calculate_default_transform(
src_crs,
dst_crs,
@ -1024,11 +1072,9 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
# raise an exception instead.
if add_alpha:
if src_alpha_band:
raise WarpOptionsError(
"The VRT already has an alpha band, adding a new one is not supported")
else:
dst_alpha_band = src_dataset.count + 1
self.dst_nodata = None

View File

@ -172,6 +172,7 @@ def reproject(
dst_resolution=None,
src_alpha=0,
dst_alpha=0,
masked=False,
resampling=Resampling.nearest,
num_threads=1,
init_dest_nodata=True,
@ -250,6 +251,8 @@ def reproject(
Index of a band to use as the alpha band when warping.
dst_alpha : int, optional
Index of a band to use as the alpha band when warping.
masked: bool, optional
If True and destination is None, return a masked array.
resampling: int, rasterio.enums.Resampling
Resampling method to use.
Default is :attr:`rasterio.enums.Resampling.nearest`.
@ -378,8 +381,10 @@ def reproject(
destination = np.empty(
(int(dst_count), int(dst_height), int(dst_width)), dtype=source.dtype
)
if masked:
destination = np.ma.masked_array(destination).filled(dst_nodata)
_reproject(
dest = _reproject(
source,
destination,
src_transform=src_transform,
@ -400,7 +405,7 @@ def reproject(
**kwargs
)
return destination, dst_transform
return dest, dst_transform
def aligned_target(transform, width, height, resolution):

View File

@ -1333,7 +1333,25 @@ def test_reproject_array_interface(test3d, count_nonzero, path_rgb_byte_tif):
assert np.count_nonzero(out.data[out.data != 99]) == count_nonzero
@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)])
@pytest.mark.parametrize(
"test3d,count_nonzero",
[
pytest.param(
True,
1308064,
marks=pytest.mark.skipif(
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
),
),
pytest.param(
False,
437686,
marks=pytest.mark.skipif(
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
),
),
],
)
def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif):
with rasterio.open(path_rgb_byte_tif) as src:
if test3d:
@ -1352,6 +1370,44 @@ def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif):
)
assert np.ma.is_masked(source)
assert np.count_nonzero(out[out != 99]) == count_nonzero
assert not np.ma.is_masked(out)
@pytest.mark.parametrize(
"test3d,count_nonzero",
[
pytest.param(
True,
1312959,
marks=pytest.mark.skipif(
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
),
),
pytest.param(
False,
438113,
marks=pytest.mark.skipif(
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
),
),
]
)
def test_reproject_masked_masked_output(test3d, count_nonzero, path_rgb_byte_tif):
with rasterio.open(path_rgb_byte_tif) as src:
if test3d:
inp = src.read(masked=True)
else:
inp = src.read(1, masked=True)
out = np.ma.masked_array(np.empty(inp.shape, dtype=np.uint8))
out, _ = reproject(
inp,
out,
src_transform=src.transform,
src_crs=src.crs,
dst_transform=DST_TRANSFORM,
dst_crs="EPSG:3857",
)
assert np.count_nonzero(out[out != np.ma.masked]) == count_nonzero
@pytest.mark.parametrize("method", SUPPORTED_RESAMPLING)