mirror of
https://github.com/rasterio/rasterio.git
synced 2025-12-08 17:36:12 +00:00
ENH: Add support for objects with __array__ in reproject() (#1959)
This commit is contained in:
parent
5433425cd8
commit
57c2efd01e
@ -338,13 +338,13 @@ def _reproject(
|
||||
# We need a src_transform and src_dst in this case. These will
|
||||
# be copied to the MEM dataset.
|
||||
if dtypes.is_ndarray(source):
|
||||
|
||||
if not src_crs:
|
||||
raise CRSError("Missing src_crs.")
|
||||
|
||||
if src_nodata is None and hasattr(source, 'fill_value'):
|
||||
# source is a masked array
|
||||
src_nodata = source.fill_value
|
||||
# ensure data converted to numpy array
|
||||
source = np.array(source, copy=False)
|
||||
# Convert 2D single-band arrays to 3D multi-band.
|
||||
if len(source.shape) == 2:
|
||||
source = source.reshape(1, *source.shape)
|
||||
@ -381,10 +381,10 @@ def _reproject(
|
||||
try:
|
||||
|
||||
if dtypes.is_ndarray(destination):
|
||||
|
||||
if not dst_crs:
|
||||
raise CRSError("Missing dst_crs.")
|
||||
|
||||
# ensure data converted to numpy array
|
||||
destination = np.array(destination, copy=False)
|
||||
if len(destination.shape) == 2:
|
||||
destination = destination.reshape(1, *destination.shape)
|
||||
|
||||
|
||||
@ -1092,6 +1092,59 @@ def test_reproject_resampling(path_rgb_byte_tif, method):
|
||||
|
||||
assert np.count_nonzero(out) in expected[method]
|
||||
|
||||
@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)])
|
||||
def test_reproject_array_interface(test3d, count_nonzero, path_rgb_byte_tif):
|
||||
class DataArray:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __array__(self, dtype=None):
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.data.dtype
|
||||
|
||||
with rasterio.open(path_rgb_byte_tif) as src:
|
||||
if test3d:
|
||||
source = DataArray(src.read())
|
||||
else:
|
||||
source = DataArray(src.read(1))
|
||||
out = DataArray(np.empty(source.data.shape, dtype=np.uint8))
|
||||
reproject(
|
||||
source,
|
||||
out,
|
||||
src_transform=src.transform,
|
||||
src_crs=src.crs,
|
||||
src_nodata=src.nodata,
|
||||
dst_transform=DST_TRANSFORM,
|
||||
dst_crs="EPSG:3857",
|
||||
dst_nodata=99,
|
||||
)
|
||||
assert isinstance(out, DataArray)
|
||||
assert np.count_nonzero(out.data[out.data != 99]) == count_nonzero
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)])
|
||||
def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif):
|
||||
with rasterio.open(path_rgb_byte_tif) as src:
|
||||
if test3d:
|
||||
source = src.read(masked=True)
|
||||
else:
|
||||
source = src.read(1, masked=True)
|
||||
out = np.empty(source.shape, dtype=np.uint8)
|
||||
reproject(
|
||||
source,
|
||||
out,
|
||||
src_transform=src.transform,
|
||||
src_crs=src.crs,
|
||||
dst_transform=DST_TRANSFORM,
|
||||
dst_crs="EPSG:3857",
|
||||
dst_nodata=99,
|
||||
)
|
||||
assert np.ma.is_masked(source)
|
||||
assert np.count_nonzero(out[out != 99]) == count_nonzero
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", SUPPORTED_RESAMPLING)
|
||||
def test_reproject_resampling_alpha(method):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user