rasterio/tests/test_merge.py
2023-05-27 19:22:32 -06:00

157 lines
5.1 KiB
Python

"""Tests of rasterio.merge"""
from contextlib import ExitStack
from glob import glob
from xml.etree import ElementTree as ET

import affine
import boto3
from hypothesis import given, settings
from hypothesis.strategies import floats
import numpy
import pytest

import rasterio
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.errors import RasterioError
from rasterio.merge import merge, virtual_merge
# Non-coincident datasets test fixture.
# Three overlapping GeoTIFFs, two to the NW and one to the SE.
@pytest.fixture(scope="function")
def test_data_dir_overlapping(tmp_path):
kwargs = {
"crs": "EPSG:4326",
"transform": affine.Affine(0.2, 0, -114, 0, -0.2, 46),
"count": 1,
"dtype": rasterio.uint8,
"driver": "GTiff",
"width": 10,
"height": 10,
"nodata": 0,
}
with rasterio.open(tmp_path.joinpath("nw1.tif"), "w", **kwargs) as dst:
data = numpy.ones((10, 10), dtype=rasterio.uint8)
dst.write(data, indexes=1)
with rasterio.open(tmp_path.joinpath("nw3.tif"), "w", **kwargs) as dst:
data = numpy.ones((10, 10), dtype=rasterio.uint8) * 3
dst.write(data, indexes=1)
kwargs["transform"] = affine.Affine(0.2, 0, -113, 0, -0.2, 45)
with rasterio.open(tmp_path.joinpath("se.tif"), "w", **kwargs) as dst:
data = numpy.ones((10, 10), dtype=rasterio.uint8) * 2
dst.write(data, indexes=1)
return tmp_path
def test_different_crs(test_data_dir_overlapping):
    """merge() raises RasterioError when the inputs do not share a CRS."""
    data_dir = test_data_dir_overlapping
    filenames = [entry.name for entry in data_dir.iterdir()]
    source_path = data_dir.joinpath(filenames[-1])
    mismatched_path = data_dir.joinpath("new.tif")

    # Create new raster with different crs: copy one of the fixture
    # rasters, changing only the CRS recorded in its profile.
    with rasterio.open(source_path) as ds_src:
        profile = ds_src.profile
        profile["crs"] = CRS.from_epsg(3499)
        with rasterio.open(mismatched_path, "w", **profile) as ds_out:
            ds_out.write(ds_src.read())

    # The directory now holds the EPSG:4326 rasters plus the EPSG:3499 copy.
    with pytest.raises(RasterioError):
        merge(list(data_dir.iterdir()))
@pytest.mark.parametrize(
    "method,value",
    [("first", 1), ("last", 2), ("min", 1), ("max", 3), ("sum", 6), ("count", 3)],
)
def test_merge_method(test_data_dir_overlapping, method, value):
    """Merge method produces expected values in intersection.

    The fixture rasters hold the constants 1 (nw1), 3 (nw3) and 2 (se).
    All three overlap in rows/columns 5-9 of the merged array, which is
    where the per-method expected ``value`` is checked.
    """
    # Sorting by path gives nw1.tif, nw3.tif, se.tif: nw is first.
    inputs = sorted(test_data_dir_overlapping.iterdir())
    output_count = 1

    # ExitStack closes every opened dataset even if a later open() or the
    # merge itself raises.
    with ExitStack() as stack:
        datasets = [stack.enter_context(rasterio.open(x)) for x in inputs]
        arr, _ = merge(
            datasets, output_count=output_count, method=method, dtype=numpy.uint64
        )

    numpy.testing.assert_array_equal(arr[:, 5:10, 5:10], value)
def test_issue2163():
    """Merging a single float raster reproduces its pixel values (issue 2163)."""
    source_path = "tests/data/float_raster_with_nodata.tif"
    with rasterio.open(source_path) as src:
        expected = src.read()
        merged, _ = merge([src])

    # Compare only the leading window of the merged array that has the
    # same number of rows and columns as the source data.
    rows, cols = expected.shape[1:3]
    assert numpy.allclose(expected, merged[:, :rows, :cols])
def test_unsafe_casting():
    """Merging the float raster into uint8 yields all zeros (issue 2179)."""
    source_path = "tests/data/float_raster_with_nodata.tif"
    with rasterio.open(source_path) as src:
        merged, _ = merge([src], nodata=0.0, dtype="uint8")

    # this is why it's called "unsafe".
    assert not numpy.any(merged)
@pytest.mark.skipif(
    not (boto3.Session().get_credentials()),
    reason="S3 raster access requires credentials",
)
@pytest.mark.network
@pytest.mark.slow
@settings(deadline=None, max_examples=5)
@given(
    dx=floats(min_value=-0.05, max_value=0.05),
    dy=floats(min_value=-0.05, max_value=0.05),
)
def test_issue2202(dx, dy):
    """Merge two remote DEM tiles over a slightly shifted area of interest.

    Hypothesis draws ``dx`` and ``dy`` offsets in [-0.05, 0.05] and the
    polygon is translated by them before its bounds are passed to
    ``merge``. The test makes no assertion on the result: it passes when
    the merge and the subsequent plot complete without raising.
    """
    import rasterio.merge
    from rasterio.plot import show
    from shapely import wkt
    from shapely.affinity import translate

    aoi = wkt.loads(
        r"POLYGON((11.09 47.94, 11.06 48.01, 11.12 48.11, 11.18 48.11, 11.18 47.94, 11.09 47.94))"
    )
    aoi = translate(aoi, dx, dy)

    tile_paths = [
        "/vsis3/copernicus-dem-30m/Copernicus_DSM_COG_10_N47_00_E011_00_DEM/Copernicus_DSM_COG_10_N47_00_E011_00_DEM.tif",
        "/vsis3/copernicus-dem-30m/Copernicus_DSM_COG_10_N48_00_E011_00_DEM/Copernicus_DSM_COG_10_N48_00_E011_00_DEM.tif",
    ]

    # The ExitStack is entered after the Env, so the datasets are closed
    # first and the Env is still active while they are released.
    with rasterio.Env(AWS_NO_SIGN_REQUEST=True), ExitStack() as stack:
        ds = [stack.enter_context(rasterio.open(i)) for i in tile_paths]
        aux_array, _ = rasterio.merge.merge(datasets=ds, bounds=aoi.bounds)
        show(aux_array)
def test_virtual_merge(tmp_path):
    """virtual_merge() returns VRT XML that rasterio can open and read.

    The XML built from the ``rgb?.tif`` test files must use nearest
    resampling by default. It is then written to disk, reopened as a
    dataset, read, and rendered to a PNG as a smoke test.
    """
    import matplotlib.pyplot as plt

    xml = virtual_merge(glob("tests/data/rgb?.tif"))
    assert b'resampling="nearest"' in xml

    vrt_path = tmp_path.joinpath("test.vrt")
    vrt_path.write_bytes(xml)
    with rasterio.open(vrt_path) as dataset:
        rgb = dataset.read()

    # Move the band axis last so imshow receives (rows, cols, bands).
    plt.imshow(numpy.moveaxis(rgb, 0, -1))
    # Save inside tmp_path so the test leaves no artifact in the working
    # directory, then close the figure so it does not leak into other tests.
    plt.savefig(tmp_path.joinpath("test_virtual_merge.png"))
    plt.close()
@pytest.mark.parametrize("resampling", [Resampling.nearest, Resampling.bilinear])
def test_virtual_merge_resampling(tmp_path, resampling):
    """Every source element in the VRT carries the requested resampling name."""
    xml = virtual_merge(glob("tests/data/rgb?.tif"), resampling=resampling)
    root = ET.fromstring(xml)

    sources = root.findall(".//ComplexSource") + root.findall(".//SimpleSource")

    # all() over an empty sequence is True, so first make sure there is
    # something to check; otherwise the test would pass vacuously.
    assert sources, "VRT contains no ComplexSource or SimpleSource elements"
    assert all(elem.get("resampling") == resampling.name for elem in sources)