Preserve masks when re-casting to float64.

Also test reducing a 3-band file to 1.
This commit is contained in:
Sean Gillies 2015-02-08 10:57:39 -07:00
parent d8f3d9da30
commit 66ae338d97
2 changed files with 52 additions and 30 deletions

View File

@ -18,17 +18,27 @@ from rasterio.rio.cli import cli
@click.argument('command')
@files_inout_arg
@click.option('--dtype',
type=click.Choice(
['uint8', 'uint16', 'int16', 'float32', 'float64']),
help="Output data type.")
type=click.Choice([
'ubyte', 'uint8', 'uint16', 'int16', 'uint32',
'int32', 'float32', 'float64']),
default='float64',
help="Output data type (default: float64).")
@click.pass_context
def calc(ctx, command, files, dtype):
"""A raster data calculator
Applies one or more commands to a set of input datasets and writes the
results to a new dataset.
Applies one or more commands to a set of input datasets and writes
the results to a new dataset.
Command syntax is a work in progress.
Command syntax is a work in progress. Currently:
* {n} represents the n-th input dataset (a 3-D array)
* {n,m} represents the m-th band of the n-th dataset (a 2-D array).
*
Calculations on all bands of a
dataset can be specified using
"""
import numpy as np
@ -43,9 +53,18 @@ def calc(ctx, command, files, dtype):
with rasterio.open(files[0]) as first:
kwargs = first.meta
kwargs['transform'] = kwargs.pop('affine')
kwargs['dtype'] = dtype
sources = [rasterio.open(path).read() for path in files]
# Using the class method instead of instance method.
# Latter raises
# TypeError: astype() got an unexpected keyword argument 'copy'
# Possibly something to do with the instance being a masked
# array.
sources = np.ma.asanyarray([np.ndarray.astype(
rasterio.open(path).read(),
'float64',
copy=False
) for path in files])
# TODO: implement a real parser for calc expressions,
# perhaps using numexpr's parser as a guide, instead
@ -63,27 +82,16 @@ def calc(ctx, command, files, dtype):
logger.debug("Translated cmd: %r", cmd)
results = eval(cmd)
# Using the class method instead of instance method.
# Latter raises
# TypeError: astype() got an unexpected keyword argument 'copy'
# Possibly something to do with the instance being a masked
# array.
results = np.ndarray.astype(
results, dtype or 'float64', copy=False)
results = np.ndarray.astype(results, dtype, copy=False)
# Write results.
if len(results.shape) == 3:
kwargs.update(
count=results.shape[0],
dtype=results.dtype.type)
kwargs['count'] = results.shape[0]
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)
elif len(results.shape) == 2:
kwargs.update(
count=1,
dtype=results.dtype.type)
kwargs['count'] = 1
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results, 1)
@ -92,26 +100,22 @@ def calc(ctx, command, files, dtype):
kwargs['count'] = len(parts)
results = []
#with rasterio.open(output, 'w', **kwargs) as dst:
for part in parts:
cmd = re.sub(
r'{(\d)\s*,\s*(\d)}',
lambda m: 'sources[%d][%d]' % (
lambda m: 'sources[%d,%d]' % (
int(m.group(1))-1, int(m.group(2))-1),
part)
logger.debug("Translated cmd: %r", cmd)
res = eval(cmd)
res = np.ndarray.astype(
res, dtype or 'float64', copy=False)
res = np.ndarray.astype(res, dtype, copy=False)
results.append(res)
results = np.asanyarray(results)
kwargs.update(
count=results.shape[0],
dtype=results.dtype.type)
kwargs['count'] = results.shape[0]
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)

View File

@ -44,7 +44,7 @@ def test_parts_calc(tmpdir):
result = runner.invoke(calc, [
'{1,1} + 125; {1,1}; {1,1}',
'--dtype', 'uint8',
'tests/data/shade.tif',
'tests/data/shade.tif',
outfile],
catch_exceptions=False)
assert result.exit_code == 0
@ -55,3 +55,21 @@ def test_parts_calc(tmpdir):
assert data[0].min() == 125
assert data[1].min() == 0
assert data[2].min() == 0
def test_parts_calc_2(tmpdir):
# Produce greyscale output from the RGB file.
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'({1,1} + {1,2} + {1,3})/3;',
'--dtype', 'uint8',
'tests/data/RGB.byte.tif',
outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.count == 1
assert src.meta['dtype'] == 'uint8'
data = src.read()
assert round(data.mean(), 1) == 60.3