From fbc3347f736f2397a94cf2f2fddecd0e64b49537 Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Wed, 7 Jan 2015 22:26:38 -0700 Subject: [PATCH] Logic and tests for merge without nodata. Closes #240 --- rasterio/rio/merge.py | 28 ++++++++++++-------- tests/test_rio_merge.py | 57 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 16 deletions(-) diff --git a/rasterio/rio/merge.py b/rasterio/rio/merge.py index 80322ebf..d6e647cb 100644 --- a/rasterio/rio/merge.py +++ b/rasterio/rio/merge.py @@ -36,31 +36,39 @@ def merge(ctx, files, driver): with rasterio.open(files[0]) as first: kwargs = first.meta kwargs['transform'] = kwargs.pop('affine') - dest = np.empty((first.count,) + first.shape, + dest = np.zeros((first.count,) + first.shape, dtype=first.dtypes[0]) + nodataval = 0.0 if os.path.exists(output): dst = rasterio.open(output, 'r+') - nodataval = dst.nodatavals[0] + nodataval = dst.nodatavals[0] or nodataval else: kwargs['driver'] == driver dst = rasterio.open(output, 'w', **kwargs) - nodataval = first.nodatavals[0] + nodataval = first.nodatavals[0] or nodataval - dest.fill(nodataval) + if nodataval: + dest.fill(nodataval) for fname in reversed(files): with rasterio.open(fname) as src: data = src.read() - np.copyto(dest, data, - where=np.logical_and( - dest==nodataval, data.mask==False)) + try: + where = np.logical_and( + dest==nodataval, data.mask==False) + except AttributeError: + where = dest==nodataval + np.copyto(dest, data, where=where) if dst.mode == 'r+': data = dst.read() - np.copyto(dest, data, - where=np.logical_and( - dest==nodataval, data.mask==False)) + try: + where = np.logical_and( + dest==nodataval, data.mask==False) + except AttributeError: + where = dest==nodataval + np.copyto(dest, data, where=where) dst.write(dest) dst.close() diff --git a/tests/test_rio_merge.py b/tests/test_rio_merge.py index 24ef19f8..03ed04aa 100644 --- a/tests/test_rio_merge.py +++ b/tests/test_rio_merge.py @@ -15,7 +15,7 @@ logging.basicConfig(stream=sys.stderr, level=logging.INFO) # Fixture to create test datasets within temporary directory @fixture(scope='function') -def test_data_dir(tmpdir): +def test_data_dir_1(tmpdir): kwargs = { "crs": {'init': 'epsg:4326'}, "transform": (-114, 0.2, 0, 46, 0, -0.1), @@ -28,12 +28,13 @@ def test_data_dir(tmpdir): } with rasterio.drivers(): - with rasterio.open(str(tmpdir.join('one.tif')), 'w', **kwargs) as dst: + + with rasterio.open(str(tmpdir.join('a.tif')), 'w', **kwargs) as dst: data = numpy.zeros((10, 10), dtype=rasterio.uint8) data[0:6, 0:6] = 255 dst.write_band(1, data) - with rasterio.open(str(tmpdir.join('two.tif')), 'w', **kwargs) as dst: + with rasterio.open(str(tmpdir.join('b.tif')), 'w', **kwargs) as dst: data = numpy.zeros((10, 10), dtype=rasterio.uint8) data[4:8, 4:8] = 254 dst.write_band(1, data) @@ -41,9 +42,53 @@ def test_data_dir(tmpdir): return tmpdir -def test_merge(test_data_dir): - outputname = str(test_data_dir.join('merged.tif')) - inputs = [str(x) for x in test_data_dir.listdir()] +@fixture(scope='function') +def test_data_dir_2(tmpdir): + kwargs = { + "crs": {'init': 'epsg:4326'}, + "transform": (-114, 0.2, 0, 46, 0, -0.1), + "count": 1, + "dtype": rasterio.uint8, + "driver": "GTiff", + "width": 10, + "height": 10 + } + + with rasterio.drivers(): + + with rasterio.open(str(tmpdir.join('a.tif')), 'w', **kwargs) as dst: + data = numpy.zeros((10, 10), dtype=rasterio.uint8) + data[0:6, 0:6] = 255 + dst.write_band(1, data) + + with rasterio.open(str(tmpdir.join('b.tif')), 'w', **kwargs) as dst: + data = numpy.zeros((10, 10), dtype=rasterio.uint8) + data[4:8, 4:8] = 254 + dst.write_band(1, data) + + return tmpdir + + +def test_merge_with_nodata(test_data_dir_1): + outputname = str(test_data_dir_1.join('merged.tif')) + inputs = [str(x) for x in test_data_dir_1.listdir()] + inputs.sort() + runner = CliRunner() + result = runner.invoke(merge, inputs + [outputname]) + assert result.exit_code == 0 + assert os.path.exists(outputname) + with rasterio.open(outputname) as out: + assert out.count == 1 + data = out.read_band(1, masked=False) + expected = numpy.zeros((10, 10), dtype=rasterio.uint8) + expected[0:6, 0:6] = 255 + expected[4:8, 4:8] = 254 + assert numpy.all(data == expected) + + +def test_merge_without_nodata(test_data_dir_2): + outputname = str(test_data_dir_2.join('merged.tif')) + inputs = [str(x) for x in test_data_dir_2.listdir()] inputs.sort() runner = CliRunner() result = runner.invoke(merge, inputs + [outputname])