From 66ae338d973d4c657ea933bcbd5b8291395f25f5 Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Sun, 8 Feb 2015 10:57:39 -0700 Subject: [PATCH] Preserve masks when re-casting to float64. Also test reducing a 3-band file to 1. --- rasterio/rio/calc.py | 62 ++++++++++++++++++++++-------------------- tests/test_rio_calc.py | 20 +++++++++++++- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/rasterio/rio/calc.py b/rasterio/rio/calc.py index a074ceb5..ca9d1785 100644 --- a/rasterio/rio/calc.py +++ b/rasterio/rio/calc.py @@ -18,17 +18,27 @@ from rasterio.rio.cli import cli @click.argument('command') @files_inout_arg @click.option('--dtype', - type=click.Choice( - ['uint8', 'uint16', 'int16', 'float32', 'float64']), - help="Output data type.") + type=click.Choice([ + 'ubyte', 'uint8', 'uint16', 'int16', 'uint32', + 'int32', 'float32', 'float64']), + default='float64', + help="Output data type (default: float64).") @click.pass_context def calc(ctx, command, files, dtype): """A raster data calculator - Applies one or more commands to a set of input datasets and writes the - results to a new dataset. + Applies one or more commands to a set of input datasets and writes + the results to a new dataset. - Command syntax is a work in progress. + Command syntax is a work in progress. Currently: + + * {n} represents the n-th input dataset (a 3-D array) + * {n,m} represents the m-th band of the n-th dataset (a 2-D array). + * + + + Calculations on all bands of a + dataset can be specified using """ import numpy as np @@ -43,9 +53,18 @@ def calc(ctx, command, files, dtype): with rasterio.open(files[0]) as first: kwargs = first.meta kwargs['transform'] = kwargs.pop('affine') + kwargs['dtype'] = dtype - sources = [rasterio.open(path).read() for path in files] - + # Using the class method instead of instance method. + # Latter raises + # TypeError: astype() got an unexpected keyword argument 'copy' + # Possibly something to do with the instance being a masked + # array. + sources = np.ma.asanyarray([np.ndarray.astype( + rasterio.open(path).read(), + 'float64', + copy=False + ) for path in files]) # TODO: implement a real parser for calc expressions, # perhaps using numexpr's parser as a guide, instead @@ -63,27 +82,16 @@ def calc(ctx, command, files, dtype): logger.debug("Translated cmd: %r", cmd) results = eval(cmd) - - # Using the class method instead of instance method. - # Latter raises - # TypeError: astype() got an unexpected keyword argument 'copy' - # Possibly something to do with the instance being a masked - # array. - results = np.ndarray.astype( - results, dtype or 'float64', copy=False) + results = np.ndarray.astype(results, dtype, copy=False) # Write results. if len(results.shape) == 3: - kwargs.update( - count=results.shape[0], - dtype=results.dtype.type) + kwargs['count'] = results.shape[0] with rasterio.open(output, 'w', **kwargs) as dst: dst.write(results) elif len(results.shape) == 2: - kwargs.update( - count=1, - dtype=results.dtype.type) + kwargs['count'] = 1 with rasterio.open(output, 'w', **kwargs) as dst: dst.write(results, 1) @@ -92,26 +100,22 @@ def calc(ctx, command, files, dtype): kwargs['count'] = len(parts) results = [] - #with rasterio.open(output, 'w', **kwargs) as dst: for part in parts: cmd = re.sub( r'{(\d)\s*,\s*(\d)}', - lambda m: 'sources[%d][%d]' % ( + lambda m: 'sources[%d,%d]' % ( int(m.group(1))-1, int(m.group(2))-1), part) logger.debug("Translated cmd: %r", cmd) res = eval(cmd) - res = np.ndarray.astype( - res, dtype or 'float64', copy=False) + res = np.ndarray.astype(res, dtype, copy=False) results.append(res) results = np.asanyarray(results) - kwargs.update( - count=results.shape[0], - dtype=results.dtype.type) + kwargs['count'] = results.shape[0] with rasterio.open(output, 'w', **kwargs) as dst: dst.write(results) diff --git a/tests/test_rio_calc.py b/tests/test_rio_calc.py index cfb16788..a535eb08 100644 --- a/tests/test_rio_calc.py +++ b/tests/test_rio_calc.py @@ -44,7 +44,7 @@ def test_parts_calc(tmpdir): result = runner.invoke(calc, [ '{1,1} + 125; {1,1}; {1,1}', '--dtype', 'uint8', - 'tests/data/shade.tif', + 'tests/data/shade.tif', outfile], catch_exceptions=False) assert result.exit_code == 0 @@ -55,3 +55,21 @@ def test_parts_calc(tmpdir): assert data[0].min() == 125 assert data[1].min() == 0 assert data[2].min() == 0 + + +def test_parts_calc_2(tmpdir): + # Produce greyscale output from the RGB file. + outfile = str(tmpdir.join('out.tif')) + runner = CliRunner() + result = runner.invoke(calc, [ + '({1,1} + {1,2} + {1,3})/3;', + '--dtype', 'uint8', + 'tests/data/RGB.byte.tif', + outfile], + catch_exceptions=False) + assert result.exit_code == 0 + with rasterio.open(outfile) as src: + assert src.count == 1 + assert src.meta['dtype'] == 'uint8' + data = src.read() + assert round(data.mean(), 1) == 60.3