ENH: Add support for objects with __array__ in reproject() (#1959)

This commit is contained in:
Alan D. Snow 2020-07-13 18:46:18 -05:00 committed by GitHub
parent 5433425cd8
commit 57c2efd01e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 4 deletions

View File

@ -338,13 +338,13 @@ def _reproject(
# We need a src_transform and src_dst in this case. These will
# be copied to the MEM dataset.
if dtypes.is_ndarray(source):
if not src_crs:
raise CRSError("Missing src_crs.")
if src_nodata is None and hasattr(source, 'fill_value'):
# source is a masked array
src_nodata = source.fill_value
# ensure data converted to numpy array
source = np.array(source, copy=False)
# Convert 2D single-band arrays to 3D multi-band.
if len(source.shape) == 2:
source = source.reshape(1, *source.shape)
@ -381,10 +381,10 @@ def _reproject(
try:
if dtypes.is_ndarray(destination):
if not dst_crs:
raise CRSError("Missing dst_crs.")
# ensure data converted to numpy array
destination = np.array(destination, copy=False)
if len(destination.shape) == 2:
destination = destination.reshape(1, *destination.shape)

View File

@ -1092,6 +1092,59 @@ def test_reproject_resampling(path_rgb_byte_tif, method):
assert np.count_nonzero(out) in expected[method]
@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)])
def test_reproject_array_interface(test3d, count_nonzero, path_rgb_byte_tif):
class DataArray:
def __init__(self, data):
self.data = data
def __array__(self, dtype=None):
return self.data
@property
def dtype(self):
return self.data.dtype
with rasterio.open(path_rgb_byte_tif) as src:
if test3d:
source = DataArray(src.read())
else:
source = DataArray(src.read(1))
out = DataArray(np.empty(source.data.shape, dtype=np.uint8))
reproject(
source,
out,
src_transform=src.transform,
src_crs=src.crs,
src_nodata=src.nodata,
dst_transform=DST_TRANSFORM,
dst_crs="EPSG:3857",
dst_nodata=99,
)
assert isinstance(out, DataArray)
assert np.count_nonzero(out.data[out.data != 99]) == count_nonzero
@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)])
def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif):
with rasterio.open(path_rgb_byte_tif) as src:
if test3d:
source = src.read(masked=True)
else:
source = src.read(1, masked=True)
out = np.empty(source.shape, dtype=np.uint8)
reproject(
source,
out,
src_transform=src.transform,
src_crs=src.crs,
dst_transform=DST_TRANSFORM,
dst_crs="EPSG:3857",
dst_nodata=99,
)
assert np.ma.is_masked(source)
assert np.count_nonzero(out[out != 99]) == count_nonzero
@pytest.mark.parametrize("method", SUPPORTED_RESAMPLING)
def test_reproject_resampling_alpha(method):