diff --git a/rasterio/_warp.pyx b/rasterio/_warp.pyx index d602b178..82c888c1 100644 --- a/rasterio/_warp.pyx +++ b/rasterio/_warp.pyx @@ -572,7 +572,10 @@ def _reproject( try: exc_wrap_int(oWarper.Initialize(psWOptions)) - rows, cols = destination.shape[-2:] + if isinstance(destination, tuple): + rows, cols = destination[3] + else: + rows, cols = destination.shape[-2:] log.debug( "Chunk and warp window: %d, %d, %d, %d.", diff --git a/tests/data/rgb1_fake_nir_epsg3857.tif b/tests/data/rgb1_fake_nir_epsg3857.tif new file mode 100644 index 00000000..5f24262c Binary files /dev/null and b/tests/data/rgb1_fake_nir_epsg3857.tif differ diff --git a/tests/test_warp.py b/tests/test_warp.py index 4dd01ea9..93a315a7 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -2030,6 +2030,54 @@ def test_reproject_error_propagation(http_error_server, caplog): assert len([rec for rec in caplog.records if "Retrying again" in rec.message]) == 2 +def test_reproject_to_specified_output_bands(): + """ + Reproject multiple input rasters to a single output raster, joining their bands. + + In this example, we concatenate a RGB raster with a mocked NIR image, while reprojecting and resampling both. + """ + with rasterio.open('tests/data/rgb1.tif') as src_rgb, \ + rasterio.open('tests/data/rgb1_fake_nir_epsg3857.tif') as src_nir: + output_crs = CRS.from_epsg(4326) + output_transform, output_width, output_height = calculate_default_transform( + src_rgb.crs, + output_crs, + src_rgb.width, + src_rgb.height, + *src_rgb.bounds) + with rasterio.MemoryFile() as mem: + with mem.open( + driver="GTiff", + width=output_width, + height=output_height, + count=5, + transform=output_transform, + dtype="uint8", + nodata=0, + crs=output_crs, + ) as out: # type: rasterio.io.DatasetWriter + reproject( + rasterio.band(src_rgb, src_rgb.indexes), + rasterio.band(out, src_rgb.indexes), + resampling=Resampling.nearest, + ) + reproject( + rasterio.band(src_nir, 1), + rasterio.band(out, 4), + resampling=Resampling.nearest, + ) + + with mem.open() as out: # type: rasterio.DatasetReader + assert out.count == 5 + + for i in range(1, 5): + band_data = out.read(i) + assert (band_data != 0).any() + + band_data = out.read(5) + assert (band_data == 0).all() + + def test_rpcs_non_epsg4326(): with pytest.raises(RPCError): with rasterio.open('tests/data/RGB.byte.rpc.vrt') as src: @@ -2040,4 +2088,4 @@ def test_rpcs_non_epsg4326(): rpcs=src_rpcs, dst_crs="EPSG:4326", resampling=Resampling.nearest, - ) \ No newline at end of file + )