mirror of
https://github.com/rasterio/rasterio.git
synced 2025-12-08 17:36:12 +00:00
Use mask as alpha when reprojecting a masked array (#3156)
* Use mask as alpha when reprojecting a masked array Resolves #2575 * Remove unnecessary nodata wrangling * Masked reprojection results depend on GDAL version * Fix test expectations and mask scaling value The tests of this PR are now only run for GDAL 3.8+ because the outputs are too variable. * Update change log
This commit is contained in:
parent
d917c5cd38
commit
1ac101d8d7
@ -6,6 +6,9 @@ Next (TBD)
|
||||
|
||||
Bug fixes:
|
||||
|
||||
- When reprojecting a masked array, we now use the mask (reduced) as an alpha
|
||||
band. There is now also an option to create an alpha band in the output, and
|
||||
turn that into a mask when returning a masked array (#3156).
|
||||
- Find installed GDAL data directory by searching for gdalvrt.xsd (#3157).
|
||||
- Allow rasterio.open() to receive instances of MemoryFile (#3145).
|
||||
- Leaks of CSL string lists in get/set_proj_data_search_path() have been fixed
|
||||
|
||||
@ -42,3 +42,5 @@ ctypedef np.float64_t DTYPE_FLOAT64_t
|
||||
cdef bint in_dtype_range(value, dtype)
|
||||
|
||||
cdef int io_auto(image, GDALRasterBandH band, bint write, int resampling=*) except -1
|
||||
cdef int io_band(GDALRasterBandH band, int mode, double x0, double y0, double width, double height, object data, int resampling=*) except -1
|
||||
cdef int io_multi_band(GDALDatasetH hds, int mode, double x0, double y0, double width, double height, object data, Py_ssize_t[:] indexes, int resampling=*) except -1
|
||||
|
||||
@ -38,7 +38,7 @@ from libc.math cimport HUGE_VAL
|
||||
from rasterio._base cimport get_driver_name
|
||||
from rasterio._err cimport exc_wrap, exc_wrap_pointer, exc_wrap_int, StackChecker
|
||||
from rasterio._io cimport (
|
||||
DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto)
|
||||
DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto, io_band, io_multi_band)
|
||||
from rasterio._features cimport GeomBuilder, OGRGeomBuilder
|
||||
from rasterio.crs cimport CRS
|
||||
|
||||
@ -282,9 +282,9 @@ def _reproject(
|
||||
nodata value of the destination image (if set), the value of
|
||||
src_nodata, or 0 (gdal default).
|
||||
src_alpha : int, optional
|
||||
Index of a band to use as the alpha band when warping.
|
||||
Index of a band to use as the source alpha band when warping.
|
||||
dst_alpha : int, optional
|
||||
Index of a band to use as the alpha band when warping.
|
||||
Index of a band to use as the destination alpha band when warping.
|
||||
resampling : int
|
||||
Resampling method to use. One of the following:
|
||||
Resampling.nearest,
|
||||
@ -360,35 +360,54 @@ def _reproject(
|
||||
cdef MemoryDataset src_mem = None
|
||||
|
||||
try:
|
||||
|
||||
# If the source is an ndarray, we copy to a MEM dataset.
|
||||
# We need a src_transform and src_crs in this case. These will
|
||||
# be copied to the MEM dataset.
|
||||
if dtypes.is_ndarray(source):
|
||||
if not src_crs:
|
||||
raise CRSError("Missing src_crs.")
|
||||
if src_nodata is None and hasattr(source, 'fill_value'):
|
||||
# source is a masked array
|
||||
src_nodata = source.fill_value
|
||||
|
||||
# ensure data converted to numpy array
|
||||
source = np.asanyarray(source)
|
||||
if hasattr(source, "mask"):
|
||||
source = np.ma.asanyarray(source)
|
||||
else:
|
||||
source = np.asanyarray(source)
|
||||
|
||||
# Convert 2D single-band arrays to 3D multi-band.
|
||||
if len(source.shape) == 2:
|
||||
source = source.reshape(1, *source.shape)
|
||||
|
||||
src_count = source.shape[0]
|
||||
src_bidx = range(1, src_count + 1)
|
||||
src_mem = MemoryDataset(source,
|
||||
transform=format_transform(src_transform),
|
||||
gcps=gcps,
|
||||
rpcs=rpcs,
|
||||
crs=src_crs,
|
||||
copy=True)
|
||||
|
||||
if hasattr(source, "mask"):
|
||||
mask = ~np.logical_or.reduce(source.mask) * np.uint8(255)
|
||||
source_arr = np.concatenate((source.data, [mask]))
|
||||
src_alpha = src_alpha or source_arr.shape[0]
|
||||
else:
|
||||
source_arr = source
|
||||
|
||||
src_mem = MemoryDataset(
|
||||
source_arr,
|
||||
transform=format_transform(src_transform),
|
||||
gcps=gcps,
|
||||
rpcs=rpcs,
|
||||
crs=src_crs,
|
||||
copy=True,
|
||||
)
|
||||
src_dataset = src_mem.handle()
|
||||
|
||||
if src_alpha:
|
||||
for i in range(source_arr.shape[0]):
|
||||
GDALDeleteRasterNoDataValue(GDALGetRasterBand(src_dataset, i+1))
|
||||
GDALSetRasterColorInterpretation(
|
||||
GDALGetRasterBand(src_dataset, src_alpha),
|
||||
<GDALColorInterp>6,
|
||||
)
|
||||
|
||||
# If the source is a rasterio MultiBand, no copy necessary.
|
||||
# A MultiBand is a tuple: (dataset, bidx, dtype, shape(2d))
|
||||
elif isinstance(source, tuple):
|
||||
|
||||
rdr, src_bidx, dtype, shape = source
|
||||
if isinstance(src_bidx, int):
|
||||
src_bidx = [src_bidx]
|
||||
@ -408,12 +427,17 @@ def _reproject(
|
||||
|
||||
# Next, do the same for the destination raster.
|
||||
try:
|
||||
|
||||
if dtypes.is_ndarray(destination):
|
||||
if not dst_crs:
|
||||
raise CRSError("Missing dst_crs.")
|
||||
# ensure data converted to numpy array
|
||||
destination = np.asanyarray(destination)
|
||||
|
||||
dst_nodata = dst_nodata or src_nodata
|
||||
|
||||
if hasattr(destination, "mask"):
|
||||
destination = np.ma.asanyarray(destination)
|
||||
else:
|
||||
destination = np.asanyarray(destination)
|
||||
|
||||
if len(destination.shape) == 2:
|
||||
destination = destination.reshape(1, *destination.shape)
|
||||
|
||||
@ -426,27 +450,33 @@ def _reproject(
|
||||
raise ValueError("Invalid destination shape")
|
||||
dst_bidx = src_bidx
|
||||
|
||||
mem_raster = MemoryDataset(destination, transform=format_transform(dst_transform), crs=dst_crs)
|
||||
if hasattr(destination, "mask"):
|
||||
count, height, width = destination.shape
|
||||
msk = np.logical_or.reduce(destination.mask)
|
||||
if msk == np.ma.nomask:
|
||||
msk = np.zeros((height, width), dtype="bool")
|
||||
msk = ~msk * np.uint8(255)
|
||||
dest_arr = np.concatenate((destination.data, [msk]))
|
||||
dst_alpha = dst_alpha or dest_arr.shape[0]
|
||||
else:
|
||||
dest_arr = destination
|
||||
|
||||
mem_raster = MemoryDataset(dest_arr, transform=format_transform(dst_transform), crs=dst_crs)
|
||||
dst_dataset = mem_raster.handle()
|
||||
|
||||
if dst_alpha:
|
||||
for i in range(destination.shape[0]):
|
||||
for i in range(dest_arr.shape[0]):
|
||||
GDALDeleteRasterNoDataValue(GDALGetRasterBand(dst_dataset, i+1))
|
||||
|
||||
GDALSetRasterColorInterpretation(GDALGetRasterBand(dst_dataset, dst_alpha), <GDALColorInterp>6)
|
||||
GDALSetRasterColorInterpretation(
|
||||
GDALGetRasterBand(dst_dataset, dst_alpha),
|
||||
<GDALColorInterp>6,
|
||||
)
|
||||
|
||||
GDALSetDescription(
|
||||
dst_dataset, "Temporary destination dataset for _reproject()")
|
||||
|
||||
log.debug("Created temp destination dataset.")
|
||||
|
||||
if dst_nodata is None:
|
||||
if hasattr(destination, "fill_value"):
|
||||
# destination is a masked array
|
||||
dst_nodata = destination.fill_value
|
||||
elif src_nodata is not None:
|
||||
dst_nodata = src_nodata
|
||||
|
||||
elif isinstance(destination, tuple):
|
||||
udr, dst_bidx, _, _ = destination
|
||||
if isinstance(dst_bidx, int):
|
||||
@ -609,8 +639,27 @@ def _reproject(
|
||||
except CPLE_BaseError as base:
|
||||
raise WarpOperationError("Chunk and warp failed") from base
|
||||
|
||||
if dtypes.is_ndarray(destination):
|
||||
exc_wrap_int(io_auto(destination, dst_dataset, 0))
|
||||
if mem_raster is not None:
|
||||
count, height, width = dest_arr.shape
|
||||
if hasattr(destination, "mask"):
|
||||
# Pick off the alpha band and make a mask of it.
|
||||
# TODO: do this efficiently, not copying unless necessary.
|
||||
indexes = np.arange(1, count, dtype='intp')
|
||||
io_multi_band(dst_dataset, 0, 0.0, 0.0, width, height, destination, indexes)
|
||||
alpha_arr = np.empty((height, width), dtype=dest_arr.dtype)
|
||||
io_band(mem_raster.band(count), 0, 0.0, 0.0, width, height, alpha_arr)
|
||||
destination = np.ma.masked_array(
|
||||
destination.data,
|
||||
mask=np.repeat(
|
||||
~(alpha_arr.astype("bool"))[np.newaxis, :, :],
|
||||
count - 1,
|
||||
axis=0,
|
||||
)
|
||||
)
|
||||
else:
|
||||
exc_wrap_int(io_auto(destination, dst_dataset, 0))
|
||||
|
||||
return destination
|
||||
|
||||
# Clean up transformer, warp options, and dataset handles.
|
||||
finally:
|
||||
@ -629,7 +678,6 @@ def _reproject(
|
||||
if src_mem is not None:
|
||||
src_mem.close()
|
||||
|
||||
|
||||
def _calculate_default_transform(
|
||||
src_crs,
|
||||
dst_crs,
|
||||
@ -1024,11 +1072,9 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
|
||||
# raise an exception instead.
|
||||
|
||||
if add_alpha:
|
||||
|
||||
if src_alpha_band:
|
||||
raise WarpOptionsError(
|
||||
"The VRT already has an alpha band, adding a new one is not supported")
|
||||
|
||||
else:
|
||||
dst_alpha_band = src_dataset.count + 1
|
||||
self.dst_nodata = None
|
||||
|
||||
@ -172,6 +172,7 @@ def reproject(
|
||||
dst_resolution=None,
|
||||
src_alpha=0,
|
||||
dst_alpha=0,
|
||||
masked=False,
|
||||
resampling=Resampling.nearest,
|
||||
num_threads=1,
|
||||
init_dest_nodata=True,
|
||||
@ -250,6 +251,8 @@ def reproject(
|
||||
Index of a band to use as the alpha band when warping.
|
||||
dst_alpha : int, optional
|
||||
Index of a band to use as the destination alpha band when warping.
|
||||
masked: bool, optional
|
||||
If True and destination is None, return a masked array.
|
||||
resampling: int, rasterio.enums.Resampling
|
||||
Resampling method to use.
|
||||
Default is :attr:`rasterio.enums.Resampling.nearest`.
|
||||
@ -378,8 +381,10 @@ def reproject(
|
||||
destination = np.empty(
|
||||
(int(dst_count), int(dst_height), int(dst_width)), dtype=source.dtype
|
||||
)
|
||||
if masked:
|
||||
destination = np.ma.masked_array(destination).filled(dst_nodata)
|
||||
|
||||
_reproject(
|
||||
dest = _reproject(
|
||||
source,
|
||||
destination,
|
||||
src_transform=src_transform,
|
||||
@ -400,7 +405,7 @@ def reproject(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return destination, dst_transform
|
||||
return dest, dst_transform
|
||||
|
||||
|
||||
def aligned_target(transform, width, height, resolution):
|
||||
|
||||
@ -1333,7 +1333,25 @@ def test_reproject_array_interface(test3d, count_nonzero, path_rgb_byte_tif):
|
||||
assert np.count_nonzero(out.data[out.data != 99]) == count_nonzero
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)])
|
||||
@pytest.mark.parametrize(
|
||||
"test3d,count_nonzero",
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
1308064,
|
||||
marks=pytest.mark.skipif(
|
||||
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
437686,
|
||||
marks=pytest.mark.skipif(
|
||||
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif):
|
||||
with rasterio.open(path_rgb_byte_tif) as src:
|
||||
if test3d:
|
||||
@ -1352,6 +1370,44 @@ def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif):
|
||||
)
|
||||
assert np.ma.is_masked(source)
|
||||
assert np.count_nonzero(out[out != 99]) == count_nonzero
|
||||
assert not np.ma.is_masked(out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test3d,count_nonzero",
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
1312959,
|
||||
marks=pytest.mark.skipif(
|
||||
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
438113,
|
||||
marks=pytest.mark.skipif(
|
||||
not gdal_version.at_least("3.8"), reason="Requires GDAL 3.8.x"
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_reproject_masked_masked_output(test3d, count_nonzero, path_rgb_byte_tif):
|
||||
with rasterio.open(path_rgb_byte_tif) as src:
|
||||
if test3d:
|
||||
inp = src.read(masked=True)
|
||||
else:
|
||||
inp = src.read(1, masked=True)
|
||||
out = np.ma.masked_array(np.empty(inp.shape, dtype=np.uint8))
|
||||
out, _ = reproject(
|
||||
inp,
|
||||
out,
|
||||
src_transform=src.transform,
|
||||
src_crs=src.crs,
|
||||
dst_transform=DST_TRANSFORM,
|
||||
dst_crs="EPSG:3857",
|
||||
)
|
||||
assert np.count_nonzero(out[out != np.ma.masked]) == count_nonzero
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", SUPPORTED_RESAMPLING)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user