Merge branch 'master' of github.com:rasterio/rasterio

This commit is contained in:
Sean Gillies 2022-04-11 15:55:44 -06:00
commit 1dc4b415eb
5 changed files with 81 additions and 18 deletions

View File

@ -572,7 +572,10 @@ def _reproject(
try:
exc_wrap_int(oWarper.Initialize(psWOptions))
rows, cols = destination.shape[-2:]
if isinstance(destination, tuple):
rows, cols = destination[3]
else:
rows, cols = destination.shape[-2:]
log.debug(
"Chunk and warp window: %d, %d, %d, %d.",

View File

@ -267,7 +267,7 @@ class ZipMemoryFile(MemoryFile):
def get_writer_for_driver(driver):
"""Return the writer class appropriate for the specified driver."""
if not driver:
raise ValueError("'driver' is required to write dataset.")
raise ValueError("'driver' is required to read/write dataset.")
cls = None
if driver_can_create(driver):
cls = DatasetWriter

View File

@ -1,11 +1,21 @@
# Workaround for issue #378. A pure Python generator.
import numpy
import numpy as np
import rasterio._loading
with rasterio._loading.add_gdal_dll_directories():
from rasterio.enums import MaskFlags
from rasterio.windows import Window
from rasterio.transform import rowcol
from itertools import zip_longest
def _grouper(iterable, n, fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
# from itertools recipes
args = [iter(iterable)] * n
return zip_longest(*args, fillvalue=fillvalue)
def sample_gen(dataset, xy, indexes=None, masked=False):
@ -31,27 +41,29 @@ def sample_gen(dataset, xy, indexes=None, masked=False):
those indexes.
"""
index = dataset.index
dt = dataset.transform
read = dataset.read
height = dataset.height
width = dataset.width
if indexes is None:
indexes = dataset.indexes
elif isinstance(indexes, int):
indexes = [indexes]
for x, y in xy:
nodata = np.full(len(indexes), (dataset.nodata or 0), dtype=dataset.dtypes[0])
if masked:
# Masks for masked arrays are inverted (False means valid)
mask = [MaskFlags.all_valid not in dataset.mask_flag_enums[i-1] for i in indexes]
nodata = np.ma.array(nodata, mask=mask)
row_off, col_off = index(x, y)
for pts in _grouper(xy, 256):
pts = zip(*filter(None, pts))
if row_off < 0 or col_off < 0 or row_off >= dataset.height or col_off >= dataset.width:
data = numpy.ones((len(indexes),), dtype=dataset.dtypes[0]) * (dataset.nodata or 0)
if masked:
mask = [False if MaskFlags.all_valid in dataset.mask_flag_enums[i - 1] else True for i in indexes]
yield numpy.ma.array(data, mask=mask)
for row_off, col_off in zip(*rowcol(dt, *pts)):
if row_off < 0 or col_off < 0 or row_off >= height or col_off >= width:
yield nodata
else:
yield data
else:
window = Window(col_off, row_off, 1, 1)
data = read(indexes, window=window, masked=masked)
yield data[:, 0, 0]
window = Window(col_off, row_off, 1, 1)
data = read(indexes, window=window, masked=masked)
yield data[:, 0, 0]

Binary file not shown.

View File

@ -2030,6 +2030,54 @@ def test_reproject_error_propagation(http_error_server, caplog):
assert len([rec for rec in caplog.records if "Retrying again" in rec.message]) == 2
def test_reproject_to_specified_output_bands():
"""
Reproject multiple input rasters to a single output raster, joining their bands.
In this example, we concatenate a RGB raster with a mocked NIR image, while reprojecting and resampling both.
"""
with rasterio.open('tests/data/rgb1.tif') as src_rgb, \
rasterio.open('tests/data/rgb1_fake_nir_epsg3857.tif') as src_nir:
output_crs = CRS.from_epsg(4326)
output_transform, output_width, output_height = calculate_default_transform(
src_rgb.crs,
output_crs,
src_rgb.width,
src_rgb.height,
*src_rgb.bounds)
with rasterio.MemoryFile() as mem:
with mem.open(
driver="GTiff",
width=output_width,
height=output_height,
count=5,
transform=output_transform,
dtype="uint8",
nodata=0,
crs=output_crs,
) as out: # type: rasterio.io.DatasetWriter
reproject(
rasterio.band(src_rgb, src_rgb.indexes),
rasterio.band(out, src_rgb.indexes),
resampling=Resampling.nearest,
)
reproject(
rasterio.band(src_nir, 1),
rasterio.band(out, 4),
resampling=Resampling.nearest,
)
with mem.open() as out: # type: rasterio.DatasetReader
assert out.count == 5
for i in range(1, 5):
band_data = out.read(i)
assert (band_data != 0).any()
band_data = out.read(5)
assert (band_data == 0).all()
def test_rpcs_non_epsg4326():
with pytest.raises(RPCError):
with rasterio.open('tests/data/RGB.byte.rpc.vrt') as src:
@ -2040,4 +2088,4 @@ def test_rpcs_non_epsg4326():
rpcs=src_rpcs,
dst_crs="EPSG:4326",
resampling=Resampling.nearest,
)
)