mirror of
https://github.com/rasterio/rasterio.git
synced 2025-12-08 17:36:12 +00:00
Allow merge to open one dataset at a time (#1994)
And use this option in rio-merge. Resolves #1831
This commit is contained in:
parent
993bb47ff0
commit
01dbe0471b
@ -4,6 +4,7 @@ Changes
|
||||
1.1.6 (TBD)
|
||||
-----------
|
||||
|
||||
- Allow merge.merge() to open one dataset at a time (#1831).
|
||||
- Optimize CRS.__eq__() for CRS described by EPSG codes.
|
||||
- Fix bug in ParsedPath.is_remote() reported in #1967.
|
||||
- The reproject() method accepts objects that provide `__array__` in addition
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
"""Copy valid pixels from input files to an output file."""
|
||||
|
||||
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
import rasterio
|
||||
from rasterio import windows
|
||||
from rasterio.compat import string_types
|
||||
from rasterio.transform import Affine
|
||||
|
||||
|
||||
@ -33,7 +35,7 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
|
||||
|
||||
Parameters
|
||||
----------
|
||||
datasets: list of dataset objects opened in 'r' mode
|
||||
datasets : list of dataset objects opened in 'r' mode or filenames
|
||||
source datasets to be merged.
|
||||
bounds: tuple, optional
|
||||
Bounds of the output image (left, bottom, right, top).
|
||||
@ -94,23 +96,37 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
|
||||
out_transform: affine.Affine()
|
||||
Information for mapping pixel coordinates in `dest` to another
|
||||
coordinate system
|
||||
"""
|
||||
first = datasets[0]
|
||||
first_res = first.res
|
||||
nodataval = first.nodatavals[0]
|
||||
dt = first.dtypes[0]
|
||||
|
||||
"""
|
||||
if method not in MERGE_METHODS and not callable(method):
|
||||
raise ValueError('Unknown method {0}, must be one of {1} or callable'
|
||||
.format(method, MERGE_METHODS))
|
||||
|
||||
# Determine output band count
|
||||
if indexes is None:
|
||||
src_count = first.count
|
||||
elif isinstance(indexes, int):
|
||||
src_count = indexes
|
||||
# Create a dataset_opener object to use in several places in this function.
|
||||
if isinstance(datasets[0], string_types):
|
||||
dataset_opener = rasterio.open
|
||||
else:
|
||||
src_count = len(indexes)
|
||||
|
||||
@contextmanager
|
||||
def nullcontext(obj):
|
||||
try:
|
||||
yield obj
|
||||
finally:
|
||||
pass
|
||||
|
||||
dataset_opener = nullcontext
|
||||
|
||||
with dataset_opener(datasets[0]) as first:
|
||||
first_res = first.res
|
||||
nodataval = first.nodatavals[0]
|
||||
dt = first.dtypes[0]
|
||||
|
||||
if indexes is None:
|
||||
src_count = first.count
|
||||
elif isinstance(indexes, int):
|
||||
src_count = indexes
|
||||
else:
|
||||
src_count = len(indexes)
|
||||
|
||||
if not output_count:
|
||||
output_count = src_count
|
||||
@ -122,8 +138,9 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
|
||||
# scan input files
|
||||
xs = []
|
||||
ys = []
|
||||
for src in datasets:
|
||||
left, bottom, right, top = src.bounds
|
||||
for dataset in datasets:
|
||||
with dataset_opener(dataset) as src:
|
||||
left, bottom, right, top = src.bounds
|
||||
xs.extend([left, right])
|
||||
ys.extend([bottom, top])
|
||||
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
|
||||
@ -218,36 +235,43 @@ def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10
|
||||
else:
|
||||
raise ValueError(method)
|
||||
|
||||
for idx, src in enumerate(datasets):
|
||||
# Real World (tm) use of boundless reads.
|
||||
# This approach uses the maximum amount of memory to solve the
|
||||
# problem. Making it more efficient is a TODO.
|
||||
for idx, dataset in enumerate(datasets):
|
||||
with dataset_opener(dataset) as src:
|
||||
# Real World (tm) use of boundless reads.
|
||||
# This approach uses the maximum amount of memory to solve the
|
||||
# problem. Making it more efficient is a TODO.
|
||||
|
||||
# 1. Compute spatial intersection of destination and source
|
||||
src_w, src_s, src_e, src_n = src.bounds
|
||||
# 1. Compute spatial intersection of destination and source
|
||||
src_w, src_s, src_e, src_n = src.bounds
|
||||
|
||||
int_w = src_w if src_w > dst_w else dst_w
|
||||
int_s = src_s if src_s > dst_s else dst_s
|
||||
int_e = src_e if src_e < dst_e else dst_e
|
||||
int_n = src_n if src_n < dst_n else dst_n
|
||||
int_w = src_w if src_w > dst_w else dst_w
|
||||
int_s = src_s if src_s > dst_s else dst_s
|
||||
int_e = src_e if src_e < dst_e else dst_e
|
||||
int_n = src_n if src_n < dst_n else dst_n
|
||||
|
||||
# 2. Compute the source window
|
||||
src_window = windows.from_bounds(
|
||||
int_w, int_s, int_e, int_n, src.transform, precision=precision)
|
||||
logger.debug("Src %s window: %r", src.name, src_window)
|
||||
# 2. Compute the source window
|
||||
src_window = windows.from_bounds(
|
||||
int_w, int_s, int_e, int_n, src.transform, precision=precision
|
||||
)
|
||||
logger.debug("Src %s window: %r", src.name, src_window)
|
||||
|
||||
src_window = src_window.round_shape()
|
||||
src_window = src_window.round_shape()
|
||||
|
||||
# 3. Compute the destination window
|
||||
dst_window = windows.from_bounds(
|
||||
int_w, int_s, int_e, int_n, output_transform, precision=precision)
|
||||
# 3. Compute the destination window
|
||||
dst_window = windows.from_bounds(
|
||||
int_w, int_s, int_e, int_n, output_transform, precision=precision
|
||||
)
|
||||
|
||||
# 4. Read data in source window into temp
|
||||
trows, tcols = (
|
||||
int(round(dst_window.height)), int(round(dst_window.width)))
|
||||
temp_shape = (src_count, trows, tcols)
|
||||
temp = src.read(out_shape=temp_shape, window=src_window,
|
||||
boundless=False, masked=True, indexes=indexes)
|
||||
# 4. Read data in source window into temp
|
||||
trows, tcols = (int(round(dst_window.height)), int(round(dst_window.width)))
|
||||
temp_shape = (src_count, trows, tcols)
|
||||
temp = src.read(
|
||||
out_shape=temp_shape,
|
||||
window=src_window,
|
||||
boundless=False,
|
||||
masked=True,
|
||||
indexes=indexes,
|
||||
)
|
||||
|
||||
# 5. Copy elements of temp into dest
|
||||
roff, coff = (
|
||||
|
||||
@ -49,30 +49,35 @@ def merge(ctx, files, output, driver, bounds, res, nodata, bidx, overwrite,
|
||||
output, files = resolve_inout(
|
||||
files=files, output=output, overwrite=overwrite)
|
||||
|
||||
with ctx.obj['env']:
|
||||
datasets = [rasterio.open(f) for f in files]
|
||||
dest, output_transform = merge_tool(datasets, bounds=bounds, res=res,
|
||||
nodata=nodata, precision=precision,
|
||||
indexes=(bidx or None))
|
||||
with ctx.obj["env"]:
|
||||
dest, output_transform = merge_tool(
|
||||
files,
|
||||
bounds=bounds,
|
||||
res=res,
|
||||
nodata=nodata,
|
||||
precision=precision,
|
||||
indexes=(bidx or None),
|
||||
)
|
||||
|
||||
profile = datasets[0].profile
|
||||
profile['transform'] = output_transform
|
||||
profile['height'] = dest.shape[1]
|
||||
profile['width'] = dest.shape[2]
|
||||
profile['driver'] = driver
|
||||
profile['count'] = dest.shape[0]
|
||||
with rasterio.open(files[0]) as first:
|
||||
profile = first.profile
|
||||
profile["transform"] = output_transform
|
||||
profile["height"] = dest.shape[1]
|
||||
profile["width"] = dest.shape[2]
|
||||
profile["driver"] = driver
|
||||
profile["count"] = dest.shape[0]
|
||||
|
||||
if nodata is not None:
|
||||
profile['nodata'] = nodata
|
||||
if nodata is not None:
|
||||
profile["nodata"] = nodata
|
||||
|
||||
profile.update(**creation_options)
|
||||
profile.update(**creation_options)
|
||||
|
||||
with rasterio.open(output, 'w', **profile) as dst:
|
||||
dst.write(dest)
|
||||
with rasterio.open(output, "w", **profile) as dst:
|
||||
dst.write(dest)
|
||||
|
||||
# uses the colormap in the first input raster.
|
||||
try:
|
||||
colormap = datasets[0].colormap(1)
|
||||
dst.write_colormap(1, colormap)
|
||||
except ValueError:
|
||||
pass
|
||||
# uses the colormap in the first input raster.
|
||||
try:
|
||||
colormap = first.colormap(1)
|
||||
dst.write_colormap(1, colormap)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@ -556,3 +556,9 @@ def test_merge_precision(tmpdir, precision):
|
||||
result = runner.invoke(main_group, ["merge", "-f", "AAIGrid"] + precision + inputs + [outputname])
|
||||
assert result.exit_code == 0
|
||||
assert open(outputname).read() == textwrap.dedent(expected)
|
||||
|
||||
|
||||
def test_merge_filenames(tiffs):
|
||||
inputs = [str(x) for x in tiffs.listdir()]
|
||||
inputs.sort()
|
||||
merge(inputs, res=2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user