Add resampling

This commit is contained in:
Sean Gillies 2023-05-27 19:22:32 -06:00
parent 8c96ff3f08
commit 14192e3368
2 changed files with 25 additions and 2 deletions

View File

@ -629,7 +629,9 @@ def virtual_merge(
src.indexes, src.colorinterp, src.block_shapes, src.dtypes src.indexes, src.colorinterp, src.block_shapes, src.dtypes
): ):
vrtrasterband = vrtdataset.find(f"VRTRasterBand[@band='{bidx}']") vrtrasterband = vrtdataset.find(f"VRTRasterBand[@band='{bidx}']")
complexsource = ET.SubElement(vrtrasterband, "ComplexSource") complexsource = ET.SubElement(
vrtrasterband, "ComplexSource", resampling=resampling.name
)
ET.SubElement( ET.SubElement(
complexsource, "SourceFilename", relativeToVRT="0", shared="0" complexsource, "SourceFilename", relativeToVRT="0", shared="0"
).text = _parse_path(src.name).as_vsi() ).text = _parse_path(src.name).as_vsi()
@ -685,7 +687,9 @@ def virtual_merge(
for bidx, ci, block_shape, dtype in zip( for bidx, ci, block_shape, dtype in zip(
src.indexes, src.colorinterp, src.block_shapes, src.dtypes src.indexes, src.colorinterp, src.block_shapes, src.dtypes
): ):
simplesource = ET.SubElement(vrtrasterband, "SimpleSource") simplesource = ET.SubElement(
vrtrasterband, "SimpleSource", resampling=resampling.name
)
ET.SubElement( ET.SubElement(
simplesource, "SourceFilename", relativeToVRT="0", shared="0" simplesource, "SourceFilename", relativeToVRT="0", shared="0"
).text = _parse_path(src.name).as_vsi() ).text = _parse_path(src.name).as_vsi()

View File

@ -1,6 +1,7 @@
"""Tests of rasterio.merge""" """Tests of rasterio.merge"""
from glob import glob from glob import glob
from xml.etree import ElementTree as ET
import boto3 import boto3
from hypothesis import given, settings from hypothesis import given, settings
@ -11,6 +12,7 @@ import pytest
import affine import affine
import rasterio import rasterio
from rasterio.crs import CRS from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.errors import RasterioError from rasterio.errors import RasterioError
from rasterio.merge import merge, virtual_merge from rasterio.merge import merge, virtual_merge
@ -128,6 +130,8 @@ def test_issue2202(dx, dy):
def test_virtual_merge(tmp_path): def test_virtual_merge(tmp_path):
"""Test.""" """Test."""
xml = virtual_merge(glob("tests/data/rgb?.tif")) xml = virtual_merge(glob("tests/data/rgb?.tif"))
assert b'resampling="nearest"' in xml
tmp_path.joinpath("test.vrt").write_bytes(xml) tmp_path.joinpath("test.vrt").write_bytes(xml)
with rasterio.open(tmp_path.joinpath("test.vrt")) as dataset: with rasterio.open(tmp_path.joinpath("test.vrt")) as dataset:
rgb = dataset.read() rgb = dataset.read()
@ -135,3 +139,18 @@ def test_virtual_merge(tmp_path):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
plt.imshow(numpy.moveaxis(rgb, 0, -1)) plt.imshow(numpy.moveaxis(rgb, 0, -1))
plt.savefig("test_virtual_merge.png") plt.savefig("test_virtual_merge.png")
@pytest.mark.parametrize("resampling", [Resampling.nearest, Resampling.bilinear])
def test_virtual_merge_resampling(tmp_path, resampling):
"""Test."""
xml = virtual_merge(glob("tests/data/rgb?.tif"), resampling=resampling)
root = ET.fromstring(xml)
assert all(
elem.attrib["resampling"] == resampling.name
for elem in root.findall(".//ComplexSource")
)
assert all(
elem.attrib["resampling"] == resampling.name
for elem in root.findall(".//SimpleSource")
)