From fe02e31c48d5dcfeabc3feed2861a28d87af9400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Santos?= Date: Sun, 24 Apr 2022 03:47:54 +0100 Subject: [PATCH] BUG: Fix output file dtype in merge (#2450) * Fix output file dtype in merge Signed-off-by: Joao * Fix missing dtype in merge click call * Update driver in creation_options in merge call --- rasterio/merge.py | 1 + rasterio/rio/merge.py | 21 +++++++++++++++++++-- tests/test_rio_merge.py | 15 +++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/rasterio/merge.py b/rasterio/merge.py index d49768be..dcd5bed5 100644 --- a/rasterio/merge.py +++ b/rasterio/merge.py @@ -261,6 +261,7 @@ def merge( out_profile["height"] = output_height out_profile["width"] = output_width out_profile["count"] = output_count + out_profile["dtype"] = dt if nodata is not None: out_profile["nodata"] = nodata diff --git a/rasterio/rio/merge.py b/rasterio/rio/merge.py index 82df93a7..d21eb0b9 100644 --- a/rasterio/rio/merge.py +++ b/rasterio/rio/merge.py @@ -29,6 +29,7 @@ def deprecated_precision(*args): default='nearest', help="Resampling method.", show_default=True) @options.nodata_opt +@options.dtype_opt @options.bidx_mult_opt @options.overwrite_opt @click.option( @@ -40,8 +41,21 @@ def deprecated_precision(*args): ) @options.creation_options @click.pass_context -def merge(ctx, files, output, driver, bounds, res, resampling, - nodata, bidx, overwrite, precision, creation_options): +def merge( + ctx, + files, + output, + driver, + bounds, + res, + resampling, + nodata, + dtype, + bidx, + overwrite, + precision, + creation_options, +): """Copy valid pixels from input files to an output file. All files must have the same number of bands, data type, and @@ -68,6 +82,8 @@ def merge(ctx, files, output, driver, bounds, res, resampling, files=files, output=output, overwrite=overwrite) resampling = Resampling[resampling] + if driver: + creation_options.update(driver=driver) with ctx.obj["env"]: merge_tool( @@ -75,6 +91,7 @@ def merge(ctx, files, output, driver, bounds, res, resampling, bounds=bounds, res=res, nodata=nodata, + dtype=dtype, indexes=(bidx or None), resampling=resampling, dst_path=output, diff --git a/tests/test_rio_merge.py b/tests/test_rio_merge.py index 98835df7..89a56dda 100644 --- a/tests/test_rio_merge.py +++ b/tests/test_rio_merge.py @@ -103,6 +103,21 @@ def test_data_dir_3(tmpdir): return tmpdir +def test_rio_merge_dtype(test_data_dir_1, runner): + outputname = str(test_data_dir_1.join("merged.tif")) + inputs = [str(x) for x in test_data_dir_1.listdir()] + inputs.sort() + + result = runner.invoke( + main_group, ["merge", "--dtype", "uint16"] + inputs + [outputname] + ) + assert result.exit_code == 0 + assert os.path.exists(outputname) + + with rasterio.open(outputname) as out: + assert all(dt == "uint16" for dt in out.dtypes) + + def test_merge_with_colormap(test_data_dir_1, runner): outputname = str(test_data_dir_1.join('merged.tif')) inputs = [str(x) for x in test_data_dir_1.listdir()]