From 57c2efd01eee0fee85fd5032930211069fb9c2c2 Mon Sep 17 00:00:00 2001 From: "Alan D. Snow" Date: Mon, 13 Jul 2020 18:46:18 -0500 Subject: [PATCH] ENH: Add support for objects with __array__ in reproject() (#1959) --- rasterio/_warp.pyx | 8 +++---- tests/test_warp.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/rasterio/_warp.pyx b/rasterio/_warp.pyx index a6c4ae41..6010003c 100644 --- a/rasterio/_warp.pyx +++ b/rasterio/_warp.pyx @@ -338,13 +338,13 @@ def _reproject( # We need a src_transform and src_dst in this case. These will # be copied to the MEM dataset. if dtypes.is_ndarray(source): - if not src_crs: raise CRSError("Missing src_crs.") - if src_nodata is None and hasattr(source, 'fill_value'): # source is a masked array src_nodata = source.fill_value + # ensure data converted to numpy array + source = np.array(source, copy=False) # Convert 2D single-band arrays to 3D multi-band. if len(source.shape) == 2: source = source.reshape(1, *source.shape) @@ -381,10 +381,10 @@ def _reproject( try: if dtypes.is_ndarray(destination): - if not dst_crs: raise CRSError("Missing dst_crs.") - + # ensure data converted to numpy array + destination = np.array(destination, copy=False) if len(destination.shape) == 2: destination = destination.reshape(1, *destination.shape) diff --git a/tests/test_warp.py b/tests/test_warp.py index db918760..06a5024e 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -1092,6 +1092,59 @@ def test_reproject_resampling(path_rgb_byte_tif, method): assert np.count_nonzero(out) in expected[method] +@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)]) +def test_reproject_array_interface(test3d, count_nonzero, path_rgb_byte_tif): + class DataArray: + def __init__(self, data): + self.data = data + + def __array__(self, dtype=None): + return self.data + + @property + def dtype(self): + return self.data.dtype + + with rasterio.open(path_rgb_byte_tif) as src: + if test3d: + source = DataArray(src.read()) + else: + source = DataArray(src.read(1)) + out = DataArray(np.empty(source.data.shape, dtype=np.uint8)) + reproject( + source, + out, + src_transform=src.transform, + src_crs=src.crs, + src_nodata=src.nodata, + dst_transform=DST_TRANSFORM, + dst_crs="EPSG:3857", + dst_nodata=99, + ) + assert isinstance(out, DataArray) + assert np.count_nonzero(out.data[out.data != 99]) == count_nonzero + + +@pytest.mark.parametrize("test3d,count_nonzero", [(True, 1309625), (False, 437686)]) +def test_reproject_masked(test3d, count_nonzero, path_rgb_byte_tif): + with rasterio.open(path_rgb_byte_tif) as src: + if test3d: + source = src.read(masked=True) + else: + source = src.read(1, masked=True) + out = np.empty(source.shape, dtype=np.uint8) + reproject( + source, + out, + src_transform=src.transform, + src_crs=src.crs, + dst_transform=DST_TRANSFORM, + dst_crs="EPSG:3857", + dst_nodata=99, + ) + assert np.ma.is_masked(source) + assert np.count_nonzero(out[out != 99]) == count_nonzero + @pytest.mark.parametrize("method", SUPPORTED_RESAMPLING) def test_reproject_resampling_alpha(method):