mirror of
https://github.com/rasterio/rasterio.git
synced 2025-12-08 17:36:12 +00:00
157 lines
5.1 KiB
Python
157 lines
5.1 KiB
Python
"""Tests of rasterio.merge"""
|
|
|
|
from glob import glob
|
|
from xml.etree import ElementTree as ET
|
|
|
|
import boto3
|
|
from hypothesis import given, settings
|
|
from hypothesis.strategies import floats
|
|
import numpy
|
|
import pytest
|
|
|
|
import affine
|
|
import rasterio
|
|
from rasterio.crs import CRS
|
|
from rasterio.enums import Resampling
|
|
from rasterio.errors import RasterioError
|
|
from rasterio.merge import merge, virtual_merge
|
|
|
|
|
|
# Non-coincident datasets test fixture.
# Three overlapping GeoTIFFs, two to the NW and one to the SE.
@pytest.fixture(scope="function")
def test_data_dir_overlapping(tmp_path):
    """Write three overlapping 10x10 single-band uint8 GeoTIFFs into tmp_path.

    nw1.tif (all 1s) and nw3.tif (all 3s) share an origin of (-114, 46);
    se.tif (all 2s) is shifted to an origin of (-113, 45). All use a 0.2
    degree pixel size in EPSG:4326 and nodata=0. Returns tmp_path.
    """
    profile = dict(
        crs="EPSG:4326",
        count=1,
        dtype=rasterio.uint8,
        driver="GTiff",
        width=10,
        height=10,
        nodata=0,
    )
    nw_transform = affine.Affine(0.2, 0, -114, 0, -0.2, 46)
    se_transform = affine.Affine(0.2, 0, -113, 0, -0.2, 45)

    layers = (
        ("nw1.tif", nw_transform, 1),
        ("nw3.tif", nw_transform, 3),
        ("se.tif", se_transform, 2),
    )
    for filename, transform, fill in layers:
        with rasterio.open(
            tmp_path / filename, "w", transform=transform, **profile
        ) as dst:
            dst.write(numpy.full((10, 10), fill, dtype=rasterio.uint8), indexes=1)

    return tmp_path
|
|
|
|
|
|
def test_different_crs(test_data_dir_overlapping):
    """Merging datasets whose CRS differ raises RasterioError."""
    # Use a fixed fixture file: iterdir() order is arbitrary, so indexing
    # into it (as inputs[-1]) picked a nondeterministic source raster.
    source_path = test_data_dir_overlapping / "se.tif"

    # Copy the source raster, relabelled with a different CRS.
    with rasterio.open(source_path) as ds_src:
        kwds = ds_src.profile
        kwds["crs"] = CRS.from_epsg(3499)
        with rasterio.open(
            test_data_dir_overlapping / "new.tif", "w", **kwds
        ) as ds_out:
            ds_out.write(ds_src.read())

    # Merge only after new.tif has been closed (and therefore flushed).
    with pytest.raises(RasterioError):
        merge(sorted(test_data_dir_overlapping.iterdir()))
|
|
|
|
|
|
@pytest.mark.parametrize(
    "method,value",
    [("first", 1), ("last", 2), ("min", 1), ("max", 3), ("sum", 6), ("count", 3)],
)
def test_merge_method(test_data_dir_overlapping, method, value):
    """Merge method produces expected values in intersection.

    Sorting the fixture files gives the order nw1.tif, nw3.tif, se.tif, so the
    expected "first" and "last" values depend on that ordering.
    """
    from contextlib import ExitStack

    inputs = sorted(test_data_dir_overlapping.iterdir())  # nw is first.

    # ExitStack closes every opened dataset, including those opened before a
    # failure partway through the list; previously none were ever closed.
    with ExitStack() as stack:
        datasets = [stack.enter_context(rasterio.open(x)) for x in inputs]
        arr, _ = merge(
            datasets, output_count=1, method=method, dtype=numpy.uint64
        )

    # Rows/cols 5:10 cover the region where all three inputs overlap.
    numpy.testing.assert_array_equal(arr[:, 5:10, 5:10], value)
|
|
|
|
|
|
def test_issue2163():
    """Demonstrate fix for issue 2163.

    Merging a single float raster that has nodata must return its pixel
    values in the upper-left of the output.
    """
    with rasterio.open("tests/data/float_raster_with_nodata.tif") as src:
        expected = src.read()
        merged, _ = merge([src])

    _, n_rows, n_cols = expected.shape
    assert numpy.allclose(expected, merged[:, :n_rows, :n_cols])
|
|
|
|
|
|
def test_unsafe_casting():
    """Demonstrate fix for issue 2179.

    Merging a float raster into a uint8 output must not raise, even though
    the cast is unsafe.
    """
    raster_path = "tests/data/float_raster_with_nodata.tif"
    with rasterio.open(raster_path) as src:
        merged, _ = merge([src], dtype="uint8", nodata=0.0)

    # Every output pixel ends up zero -- this is why it's called "unsafe".
    assert not merged.any()
|
|
|
|
|
|
@pytest.mark.skipif(
    not (boto3.Session().get_credentials()),
    reason="S3 raster access requires credentials",
)
@pytest.mark.network
@pytest.mark.slow
@settings(deadline=None, max_examples=5)
@given(
    dx=floats(min_value=-0.05, max_value=0.05),
    dy=floats(min_value=-0.05, max_value=0.05),
)
def test_issue2202(dx, dy):
    """Merge two adjacent Copernicus DEM tiles over a jittered AOI (issue 2202).

    Hypothesis shifts the AOI polygon by up to +/-0.05 degrees in x and y.
    Merging the two S3-hosted tiles clipped to the AOI bounds must succeed
    and produce a non-empty array.
    """
    from shapely import wkt
    from shapely.affinity import translate

    aoi = wkt.loads(
        r"POLYGON((11.09 47.94, 11.06 48.01, 11.12 48.11, 11.18 48.11, 11.18 47.94, 11.09 47.94))"
    )
    aoi = translate(aoi, dx, dy)

    urls = [
        "/vsis3/copernicus-dem-30m/Copernicus_DSM_COG_10_N47_00_E011_00_DEM/Copernicus_DSM_COG_10_N47_00_E011_00_DEM.tif",
        "/vsis3/copernicus-dem-30m/Copernicus_DSM_COG_10_N48_00_E011_00_DEM/Copernicus_DSM_COG_10_N48_00_E011_00_DEM.tif",
    ]

    with rasterio.Env(AWS_NO_SIGN_REQUEST=True):
        datasets = [rasterio.open(url) for url in urls]
        try:
            # Pass datasets positionally so the call does not depend on the
            # name of merge()'s first parameter.
            aux_array, aux_transform = merge(datasets, bounds=aoi.bounds)
        finally:
            # The remote datasets were previously never closed.
            for dataset in datasets:
                dataset.close()

    # Replaces an interactive rasterio.plot.show() call, which needed a
    # display/matplotlib and asserted nothing.
    assert aux_array.ndim == 3
    assert aux_array.size > 0
|
|
|
|
|
|
def test_virtual_merge(tmp_path):
    """virtual_merge builds a readable VRT that uses nearest resampling by default."""
    xml = virtual_merge(glob("tests/data/rgb?.tif"))
    assert b'resampling="nearest"' in xml

    vrt_path = tmp_path.joinpath("test.vrt")
    vrt_path.write_bytes(xml)
    with rasterio.open(vrt_path) as dataset:
        rgb = dataset.read()

    import matplotlib.pyplot as plt

    # Render into tmp_path rather than the current working directory, and
    # close the figure so it does not leak into later tests.
    fig, ax = plt.subplots()
    try:
        ax.imshow(numpy.moveaxis(rgb, 0, -1))
        fig.savefig(tmp_path.joinpath("test_virtual_merge.png"))
    finally:
        plt.close(fig)
|
|
|
|
|
|
@pytest.mark.parametrize("resampling", [Resampling.nearest, Resampling.bilinear])
def test_virtual_merge_resampling(tmp_path, resampling):
    """Every ComplexSource and SimpleSource in the VRT uses the requested resampling."""
    sources = glob("tests/data/rgb?.tif")
    root = ET.fromstring(virtual_merge(sources, resampling=resampling))
    wanted = resampling.name

    # ComplexSource elements are checked first, then SimpleSource elements.
    for xpath in (".//ComplexSource", ".//SimpleSource"):
        assert all(
            elem.attrib["resampling"] == wanted for elem in root.findall(xpath)
        )
|