Fix error in reproject when destination is a tuple. (#2369)

* Fix error in reproject when destination is a tuple.

* Add unit test for reprojecting to specified output bands.

Co-authored-by: Sean Gillies <sean.gillies@gmail.com>
This commit is contained in:
Samuel Kogler 2022-04-11 23:18:47 +02:00 committed by GitHub
parent dd729ef594
commit c68d74ddb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 2 deletions

View File

@ -572,7 +572,10 @@ def _reproject(
try:
exc_wrap_int(oWarper.Initialize(psWOptions))
rows, cols = destination.shape[-2:]
if isinstance(destination, tuple):
rows, cols = destination[3]
else:
rows, cols = destination.shape[-2:]
log.debug(
"Chunk and warp window: %d, %d, %d, %d.",

Binary file not shown.

View File

@ -2030,6 +2030,54 @@ def test_reproject_error_propagation(http_error_server, caplog):
assert len([rec for rec in caplog.records if "Retrying again" in rec.message]) == 2
def test_reproject_to_specified_output_bands():
"""
Reproject multiple input rasters to a single output raster, joining their bands.
In this example, we concatenate a RGB raster with a mocked NIR image, while reprojecting and resampling both.
"""
with rasterio.open('tests/data/rgb1.tif') as src_rgb, \
rasterio.open('tests/data/rgb1_fake_nir_epsg3857.tif') as src_nir:
output_crs = CRS.from_epsg(4326)
output_transform, output_width, output_height = calculate_default_transform(
src_rgb.crs,
output_crs,
src_rgb.width,
src_rgb.height,
*src_rgb.bounds)
with rasterio.MemoryFile() as mem:
with mem.open(
driver="GTiff",
width=output_width,
height=output_height,
count=5,
transform=output_transform,
dtype="uint8",
nodata=0,
crs=output_crs,
) as out: # type: rasterio.io.DatasetWriter
reproject(
rasterio.band(src_rgb, src_rgb.indexes),
rasterio.band(out, src_rgb.indexes),
resampling=Resampling.nearest,
)
reproject(
rasterio.band(src_nir, 1),
rasterio.band(out, 4),
resampling=Resampling.nearest,
)
with mem.open() as out: # type: rasterio.DatasetReader
assert out.count == 5
for i in range(1, 5):
band_data = out.read(i)
assert (band_data != 0).any()
band_data = out.read(5)
assert (band_data == 0).all()
def test_rpcs_non_epsg4326():
with pytest.raises(RPCError):
with rasterio.open('tests/data/RGB.byte.rpc.vrt') as src:
@ -2040,4 +2088,4 @@ def test_rpcs_non_epsg4326():
rpcs=src_rpcs,
dst_crs="EPSG:4326",
resampling=Resampling.nearest,
)
)