Allow merge to open one dataset at a time (#1994)

And use this option in rio-merge.

Resolves #1831
This commit is contained in:
Sean Gillies 2020-09-09 15:38:52 -06:00 committed by GitHub
parent 993bb47ff0
commit 01dbe0471b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 61 deletions

View File

@ -4,6 +4,7 @@ Changes
1.1.6 (TBD)
-----------
- Allow merge.merge() to open one dataset at a time (#1831).
- Optimize CRS.__eq__() for CRS described by EPSG codes.
- Fix bug in ParsedPath.is_remote() reported in #1967.
- The reproject() method accepts objects that provide `__array__` in addition

View File

@ -1,13 +1,15 @@
"""Copy valid pixels from input files to an output file."""
from contextlib import contextmanager
import logging
import math
import warnings
import numpy as np
import rasterio
from rasterio import windows
from rasterio.compat import string_types
from rasterio.transform import Affine
@ -33,7 +35,7 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
Parameters
----------
datasets: list of dataset objects opened in 'r' mode
datasets : list of dataset objects opened in 'r' mode or filenames
source datasets to be merged.
bounds: tuple, optional
Bounds of the output image (left, bottom, right, top).
@ -94,23 +96,37 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
out_transform: affine.Affine()
Information for mapping pixel coordinates in `dest` to another
coordinate system
"""
first = datasets[0]
first_res = first.res
nodataval = first.nodatavals[0]
dt = first.dtypes[0]
"""
if method not in MERGE_METHODS and not callable(method):
raise ValueError('Unknown method {0}, must be one of {1} or callable'
.format(method, MERGE_METHODS))
# Determine output band count
if indexes is None:
src_count = first.count
elif isinstance(indexes, int):
src_count = indexes
# Create a dataset_opener object to use in several places in this function.
if isinstance(datasets[0], string_types):
dataset_opener = rasterio.open
else:
src_count = len(indexes)
@contextmanager
def nullcontext(obj):
try:
yield obj
finally:
pass
dataset_opener = nullcontext
with dataset_opener(datasets[0]) as first:
first_res = first.res
nodataval = first.nodatavals[0]
dt = first.dtypes[0]
if indexes is None:
src_count = first.count
elif isinstance(indexes, int):
src_count = indexes
else:
src_count = len(indexes)
if not output_count:
output_count = src_count
@ -122,8 +138,9 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
# scan input files
xs = []
ys = []
for src in datasets:
left, bottom, right, top = src.bounds
for dataset in datasets:
with dataset_opener(dataset) as src:
left, bottom, right, top = src.bounds
xs.extend([left, right])
ys.extend([bottom, top])
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
@ -218,36 +235,43 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
else:
raise ValueError(method)
for idx, src in enumerate(datasets):
# Real World (tm) use of boundless reads.
# This approach uses the maximum amount of memory to solve the
# problem. Making it more efficient is a TODO.
for idx, dataset in enumerate(datasets):
with dataset_opener(dataset) as src:
# Real World (tm) use of boundless reads.
# This approach uses the maximum amount of memory to solve the
# problem. Making it more efficient is a TODO.
# 1. Compute spatial intersection of destination and source
src_w, src_s, src_e, src_n = src.bounds
# 1. Compute spatial intersection of destination and source
src_w, src_s, src_e, src_n = src.bounds
int_w = src_w if src_w > dst_w else dst_w
int_s = src_s if src_s > dst_s else dst_s
int_e = src_e if src_e < dst_e else dst_e
int_n = src_n if src_n < dst_n else dst_n
int_w = src_w if src_w > dst_w else dst_w
int_s = src_s if src_s > dst_s else dst_s
int_e = src_e if src_e < dst_e else dst_e
int_n = src_n if src_n < dst_n else dst_n
# 2. Compute the source window
src_window = windows.from_bounds(
int_w, int_s, int_e, int_n, src.transform, precision=precision)
logger.debug("Src %s window: %r", src.name, src_window)
# 2. Compute the source window
src_window = windows.from_bounds(
int_w, int_s, int_e, int_n, src.transform, precision=precision
)
logger.debug("Src %s window: %r", src.name, src_window)
src_window = src_window.round_shape()
src_window = src_window.round_shape()
# 3. Compute the destination window
dst_window = windows.from_bounds(
int_w, int_s, int_e, int_n, output_transform, precision=precision)
# 3. Compute the destination window
dst_window = windows.from_bounds(
int_w, int_s, int_e, int_n, output_transform, precision=precision
)
# 4. Read data in source window into temp
trows, tcols = (
int(round(dst_window.height)), int(round(dst_window.width)))
temp_shape = (src_count, trows, tcols)
temp = src.read(out_shape=temp_shape, window=src_window,
boundless=False, masked=True, indexes=indexes)
# 4. Read data in source window into temp
trows, tcols = (int(round(dst_window.height)), int(round(dst_window.width)))
temp_shape = (src_count, trows, tcols)
temp = src.read(
out_shape=temp_shape,
window=src_window,
boundless=False,
masked=True,
indexes=indexes,
)
# 5. Copy elements of temp into dest
roff, coff = (

View File

@ -49,30 +49,35 @@ def merge(ctx, files, output, driver, bounds, res, nodata, bidx, overwrite,
output, files = resolve_inout(
files=files, output=output, overwrite=overwrite)
with ctx.obj['env']:
datasets = [rasterio.open(f) for f in files]
dest, output_transform = merge_tool(datasets, bounds=bounds, res=res,
nodata=nodata, precision=precision,
indexes=(bidx or None))
with ctx.obj["env"]:
dest, output_transform = merge_tool(
files,
bounds=bounds,
res=res,
nodata=nodata,
precision=precision,
indexes=(bidx or None),
)
profile = datasets[0].profile
profile['transform'] = output_transform
profile['height'] = dest.shape[1]
profile['width'] = dest.shape[2]
profile['driver'] = driver
profile['count'] = dest.shape[0]
with rasterio.open(files[0]) as first:
profile = first.profile
profile["transform"] = output_transform
profile["height"] = dest.shape[1]
profile["width"] = dest.shape[2]
profile["driver"] = driver
profile["count"] = dest.shape[0]
if nodata is not None:
profile['nodata'] = nodata
if nodata is not None:
profile["nodata"] = nodata
profile.update(**creation_options)
profile.update(**creation_options)
with rasterio.open(output, 'w', **profile) as dst:
dst.write(dest)
with rasterio.open(output, "w", **profile) as dst:
dst.write(dest)
# uses the colormap in the first input raster.
try:
colormap = datasets[0].colormap(1)
dst.write_colormap(1, colormap)
except ValueError:
pass
# uses the colormap in the first input raster.
try:
colormap = first.colormap(1)
dst.write_colormap(1, colormap)
except ValueError:
pass

View File

@ -556,3 +556,9 @@ def test_merge_precision(tmpdir, precision):
result = runner.invoke(main_group, ["merge", "-f", "AAIGrid"] + precision + inputs + [outputname])
assert result.exit_code == 0
assert open(outputname).read() == textwrap.dedent(expected)
def test_merge_filenames(tiffs):
inputs = [str(x) for x in tiffs.listdir()]
inputs.sort()
merge(inputs, res=2)