diff --git a/rasterio/merge.py b/rasterio/merge.py index 0b5c316d..09f37a57 100644 --- a/rasterio/merge.py +++ b/rasterio/merge.py @@ -629,7 +629,9 @@ def virtual_merge( src.indexes, src.colorinterp, src.block_shapes, src.dtypes ): vrtrasterband = vrtdataset.find(f"VRTRasterBand[@band='{bidx}']") - complexsource = ET.SubElement(vrtrasterband, "ComplexSource") + complexsource = ET.SubElement( + vrtrasterband, "ComplexSource", resampling=resampling.name + ) ET.SubElement( complexsource, "SourceFilename", relativeToVRT="0", shared="0" ).text = _parse_path(src.name).as_vsi() @@ -685,7 +687,9 @@ def virtual_merge( for bidx, ci, block_shape, dtype in zip( src.indexes, src.colorinterp, src.block_shapes, src.dtypes ): - simplesource = ET.SubElement(vrtrasterband, "SimpleSource") + simplesource = ET.SubElement( + vrtrasterband, "SimpleSource", resampling=resampling.name + ) ET.SubElement( simplesource, "SourceFilename", relativeToVRT="0", shared="0" ).text = _parse_path(src.name).as_vsi() diff --git a/tests/test_merge.py b/tests/test_merge.py index 24dffe9b..0fe0b845 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -1,6 +1,7 @@ """Tests of rasterio.merge""" from glob import glob +from xml.etree import ElementTree as ET import boto3 from hypothesis import given, settings @@ -11,6 +12,7 @@ import pytest import affine import rasterio from rasterio.crs import CRS +from rasterio.enums import Resampling from rasterio.errors import RasterioError from rasterio.merge import merge, virtual_merge @@ -128,6 +130,8 @@ def test_issue2202(dx, dy): def test_virtual_merge(tmp_path): """Test.""" xml = virtual_merge(glob("tests/data/rgb?.tif")) + assert b'resampling="nearest"' in xml + tmp_path.joinpath("test.vrt").write_bytes(xml) with rasterio.open(tmp_path.joinpath("test.vrt")) as dataset: rgb = dataset.read() @@ -135,3 +139,18 @@ def test_virtual_merge(tmp_path): import matplotlib.pyplot as plt plt.imshow(numpy.moveaxis(rgb, 0, -1)) plt.savefig("test_virtual_merge.png") + + +@pytest.mark.parametrize("resampling", [Resampling.nearest, Resampling.bilinear]) +def test_virtual_merge_resampling(tmp_path, resampling): + """Test.""" + xml = virtual_merge(glob("tests/data/rgb?.tif"), resampling=resampling) + root = ET.fromstring(xml) + assert all( + elem.attrib["resampling"] == resampling.name + for elem in root.findall(".//ComplexSource") + ) + assert all( + elem.attrib["resampling"] == resampling.name + for elem in root.findall(".//SimpleSource") + )