Straighten out warped VRT nodata and masks

This commit is contained in:
Sean C. Gillies 2018-06-03 15:32:37 -06:00
parent 112004eeb0
commit b55c6faaaa
7 changed files with 178 additions and 71 deletions

1
.gitignore vendored
View File

@ -96,3 +96,4 @@ ignore/
MANIFEST
.ipynb_checkpoints
.pytest_cache
*.ipynb

View File

@ -1490,7 +1490,7 @@ cdef class DatasetWriterBase(DatasetReaderBase):
GDALSetColorEntry(hTable, i, &color)
# TODO: other color interpretations?
GDALSetRasterColorInterpretation(hBand, 1)
GDALSetRasterColorInterpretation(hBand, <GDALColorInterp>1)
GDALSetRasterColorTable(hBand, hTable)
GDALDestroyColorTable(hTable)

View File

@ -15,7 +15,7 @@ from rasterio._err import (
CPLE_AppDefinedError, CPLE_OpenFailedError)
from rasterio import dtypes
from rasterio.control import GroundControlPoint
from rasterio.enums import Resampling, MaskFlags
from rasterio.enums import Resampling, MaskFlags, ColorInterp
from rasterio.crs import CRS
from rasterio.errors import DriverRegistrationError, CRSError, RasterioIOError, RasterioDeprecationWarning
from rasterio.transform import Affine, from_bounds, guard_transform, tastes_like_gdal
@ -27,6 +27,7 @@ from rasterio._err cimport exc_wrap_pointer, exc_wrap_int
from rasterio._io cimport (
DatasetReaderBase, InMemoryRaster, in_dtype_range, io_auto)
from rasterio._features cimport GeomBuilder, OGRGeomBuilder
from rasterio._shim cimport delete_nodata_value
log = logging.getLogger(__name__)
@ -100,14 +101,14 @@ def _transform_geom(
cdef GDALWarpOptions * create_warp_options(
GDALResampleAlg resampling, object src_nodata, object dst_nodata,
int src_count, const char **options) except NULL:
int src_count, object dst_alpha, const char **options) except NULL:
"""Return a pointer to a GDALWarpOptions composed from input params
"""
# First, we make sure we have consistent source and destination
# nodata values. TODO: alpha bands.
if dst_nodata is None:
if dst_nodata is None and not dst_alpha:
if src_nodata is not None:
dst_nodata = src_nodata
else:
@ -143,7 +144,6 @@ cdef GDALWarpOptions * create_warp_options(
psWOptions.padfSrcNoDataReal[i] = float(src_nodata)
psWOptions.padfSrcNoDataImag[i] = 0.0
if dst_nodata is not None:
psWOptions.padfDstNoDataReal = <double*>CPLMalloc(src_count * sizeof(double))
psWOptions.padfDstNoDataImag = <double*>CPLMalloc(src_count * sizeof(double))
@ -152,6 +152,9 @@ cdef GDALWarpOptions * create_warp_options(
psWOptions.padfDstNoDataReal[i] = float(dst_nodata)
psWOptions.padfDstNoDataImag[i] = 0.0
if dst_alpha:
psWOptions.nDstAlphaBand = src_count + 1
# Important: set back into struct or values set above are lost
# This is because CSLSetNameValue returns a new list each time
psWOptions.papszWarpOptions = warp_extras
@ -179,6 +182,7 @@ def _reproject(
dst_transform=None,
dst_crs=None,
dst_nodata=None,
dst_alpha=False,
resampling=Resampling.nearest,
init_dest_nodata=True,
num_threads=1,
@ -518,7 +522,7 @@ def _reproject(
psWOptions = create_warp_options(
<GDALResampleAlg>resampling, src_nodata,
dst_nodata, src_count, <const char **>warp_extras)
dst_nodata, src_count, dst_alpha, <const char **>warp_extras)
psWOptions.pfnTransformer = pfnTransformer
psWOptions.pTransformerArg = hTransformArg
@ -640,12 +644,60 @@ def _calculate_default_transform(src_crs, dst_crs, width, height,
cdef class WarpedVRTReaderBase(DatasetReaderBase):
def __init__(self, src_dataset, src_crs=None, dst_crs=None,
def __init__(self, src_dataset, src_crs=None, dst_crs=None, crs=None,
resampling=Resampling.nearest, tolerance=0.125,
src_nodata=None, dst_nodata=None, dst_width=None,
dst_height=None, src_transform=None, dst_transform=None,
init_dest_nodata=True, **warp_extras):
src_nodata=None, dst_nodata=None, nodata=None,
dst_width=None, width=None, dst_height=None, height=None,
src_transform=None, dst_transform=None, transform=None,
init_dest_nodata=True, add_alpha=False, **warp_extras):
"""Make a virtual warped dataset
Parameters
----------
src_dataset : dataset object
The warp source.
src_crs : CRS or str, optional
Overrides the coordinate reference system of `src_dataset`.
src_transfrom : Affine, optional
Overrides the transform of `src_dataset`.
src_nodata : float, optional
Overrides the nodata value of `src_dataset`.
crs : CRS or str, optional
The coordinate reference system at the end of the warp
operation. Default: the crs of `src_dataset`. dst_crs is
a deprecated alias for this parameter.
transform : Affine, optional
The transform for the virtual dataset. Default: will be
computed from the attributes of `src_dataset`. dst_transform
is a deprecated alias for this parameter.
height, width: int, optional
The dimensions of the virtual dataset. Defaults: will be
computed from the attributes of `src_dataset`. dst_height
and dst_width are deprecated alias for these parameters.
nodata : float, optional
Nodata value for the virtual dataset. Default: the nodata
value of `src_dataset` or 0.0. dst_nodata is a deprecated
alias for this parameter.
resampling : Resampling, optional
Warp resampling algorithm. Default: `Resampling.nearest`.
tolerance : float, optional
The maximum error tolerance in input pixels when
approximating the warp transformation. Default: 0.125,
or one-eigth of a pixel.
add_alpha : bool, optional
Whether to add an alpha masking band to the virtual dataset.
Default: False. This option will cause deletion of the VRT
nodata value.
init_dest_nodata : bool, optional
Whether or not to initialize output to `nodata`. Default: True.
warp_extras : dict
GDAL extra warp options. See
http://www.gdal.org/structGDALWarpOptions.html.
Returns
-------
WarpedVRT
"""
self.mode = 'r'
self.options = {}
self._count = 0
@ -659,24 +711,71 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
self._gcps = None
self._read = False
# Deprecation of "dst" parameters.
if dst_nodata is not None:
warnings.warn(
"dst_nodata will be removed after 1.0, use nodata",
RasterioDeprecationWarning)
if nodata is None:
nodata = dst_nodata
if dst_width is not None:
warnings.warn(
"dst_width will be removed after 1.0, use width",
RasterioDeprecationWarning)
if width is None:
width = dst_width
if dst_height is not None:
warnings.warn(
"dst_height will be removed after 1.0, use height",
RasterioDeprecationWarning)
if height is None:
height = dst_height
if dst_transform is not None:
warnings.warn(
"dst_transform will be removed after 1.0, use transform",
RasterioDeprecationWarning)
if transform is None:
transform = dst_transform
if dst_crs is not None:
warnings.warn(
"dst_crs will be removed after 1.0, use crs",
RasterioDeprecationWarning)
if crs is None:
crs = dst_crs if dst_crs is not None else src_dataset.crs
# kwargs become warp options.
self.src_dataset = src_dataset
self.src_crs = src_crs
self.src_crs = CRS.from_user_input(src_crs) if src_crs else None
self.dst_crs = CRS.from_user_input(crs) if crs else None
self.src_transform = src_transform
self.name = "WarpedVRT({})".format(src_dataset.name)
self.dst_crs = CRS.from_user_input(dst_crs)
self.resampling = resampling
self.tolerance = tolerance
self.src_nodata = self.src_dataset.nodata if src_nodata is None else src_nodata
self.dst_nodata = self.src_nodata if dst_nodata is None else dst_nodata
self.dst_width = dst_width
self.dst_height = dst_height
self.dst_transform = dst_transform
self.dst_nodata = self.src_nodata if nodata is None else nodata
self.dst_width = width
self.dst_height = height
self.dst_transform = transform
self.warp_extras = warp_extras.copy()
if init_dest_nodata is True and 'init_dest' not in warp_extras:
self.warp_extras['init_dest'] = 'NO_DATA'
# If we're adding an alpha band, set the nodata value to None
# so it doesn't shadow our alpha band.
self.dst_alpha = add_alpha
if self.dst_alpha:
self.dst_nodata = None
cdef GDALDriverH driver = NULL
cdef GDALDatasetH hds = NULL
cdef GDALDatasetH hds_warped = NULL
@ -688,8 +787,8 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
cdef GDALWarpOptions *psWOptions = NULL
cdef float c_tolerance = tolerance
cdef GDALResampleAlg c_resampling = resampling
cdef int c_width = dst_width or 0
cdef int c_height = dst_height or 0
cdef int c_width = self.dst_width or 0
cdef int c_height = self.dst_height or 0
cdef double src_gt[6]
cdef double dst_gt[6]
cdef void *hTransformArg = NULL
@ -734,18 +833,6 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
log.debug("Exported CRS to WKT.")
# Flag if the source dataset has a dataset mask, and
# get the block size. This code is adapted from GDAL's
# VRT builder.
if MaskFlags.per_dataset in src_dataset.mask_flag_enums:
has_dataset_mask = True
else:
has_dataset_mask = False
band = GDALGetRasterBand(hds, 1)
hmask = GDALGetMaskBand(hband)
GDALGetBlockSize(hmask, &mask_block_xsize, &mask_block_ysize)
log.debug("Warp_extras: %r", self.warp_extras)
for key, val in self.warp_extras.items():
@ -756,11 +843,11 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
psWOptions = create_warp_options(
<GDALResampleAlg>c_resampling, self.src_nodata,
self.dst_nodata,
GDALGetRasterCount(hds), <const char **>c_warp_extras)
self.dst_nodata, GDALGetRasterCount(hds), self.dst_alpha,
<const char **>c_warp_extras)
try:
if dst_width and dst_height and dst_transform:
if self.dst_width and self.dst_height and self.dst_transform:
# set up transform args (otherwise handled in
# GDALAutoCreateWarpedVRT)
try:
@ -799,11 +886,6 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
c_tolerance, psWOptions)
self._hds = exc_wrap_pointer(hds_warped)
# Add the mask band if appropriate.
if has_dataset_mask:
GDALCreateDatasetMaskBand(hds_warped, MaskFlags.per_dataset.value)
hmask = GDALGetMaskBand(GDALGetRasterBand(hds_warped, 1))
except CPLE_OpenFailedError as err:
raise RasterioIOError(err.errmsg)
finally:
@ -811,11 +893,21 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
CSLDestroy(c_warp_extras)
GDALDestroyWarpOptions(psWOptions)
if self.dst_nodata is None:
for i in self.indexes:
delete_nodata_value(self.band(i))
else:
for i in self.indexes:
GDALSetRasterNoDataValue(self.band(i), self.dst_nodata)
if self.dst_alpha:
GDALSetRasterColorInterpretation(self.band(4), <GDALColorInterp>6)
self._set_attrs_from_dataset_handle()
# This attribute will be used by read().
self._nodatavals = [
self.src_nodata for i in self.src_dataset.indexes]
self.dst_nodata for i in self.src_dataset.indexes]
def get_crs(self):
warnings.warn("get_crs() will be removed in 1.0", RasterioDeprecationWarning)

View File

@ -146,6 +146,23 @@ cdef extern from "gdal.h" nogil:
GRIORA_Mode
GRIORA_Gauss
ctypedef enum GDALColorInterp:
GCI_Undefined
GCI_GrayIndex
GCI_PaletteIndex
GCI_RedBand
GCI_GreenBand
GCI_BlueBand
GCI_AlphaBand
GCI_HueBand
GCI_SaturationBand
GCI_LightnessBand
GCI_CyanBand
GCI_YCbCr_YBand
GCI_YCbCr_CbBand
GCI_YCbCr_CrBand
GCI_Max
ctypedef struct GDALColorEntry:
short c1
short c2
@ -222,7 +239,7 @@ cdef extern from "gdal.h" nogil:
void GDALDestroyColorTable(GDALColorTableH table)
int GDALGetColorEntryCount(GDALColorTableH table)
int GDALGetRasterColorInterpretation(GDALRasterBandH band)
int GDALSetRasterColorInterpretation(GDALRasterBandH band, int)
int GDALSetRasterColorInterpretation(GDALRasterBandH band, GDALColorInterp)
int GDALGetMaskFlags(GDALRasterBandH band)
int GDALCreateDatasetMaskBand(GDALDatasetH hds, int flags)
void *GDALGetMaskBand(GDALRasterBandH band)

View File

@ -7,7 +7,7 @@ from rasterio.transform import TransformMethodsMixin
class WarpedVRT(WarpedVRTReaderBase, WindowMethodsMixin,
TransformMethodsMixin):
"""Creates a virtual warped dataset.
"""A virtual warped dataset.
Abstracts the details of raster warping and allows access to data
that is reprojected when read.
@ -16,11 +16,8 @@ class WarpedVRT(WarpedVRTReaderBase, WindowMethodsMixin,
Attributes
----------
src_dataset : dataset
The dataset object to be virtually warped.
dst_crs : CRS or str
The warp operation's destination coordinate reference system.
resampling : int
One of the values from rasterio.enums.Resampling. The default is
`Resampling.nearest`.
@ -31,23 +28,10 @@ class WarpedVRT(WarpedVRTReaderBase, WindowMethodsMixin,
The source nodata value. Pixels with this value will not be
used for interpolation. If not set, it will be default to the
nodata value of the source image, if available.
src_crs : CRS or str, optional
Source image CRS to set or overwrite
src_transform : affine.Affine(), optional
Source image affine transform to set or overwrite
dst_nodata: int or float, optional
The nodata value used to initialize the destination; it will
remain in all areas not covered by the reprojected source.
Defaults to the value of src_nodata, or 0 (gdal default).
dst_width : int, optional
Target width in pixels. dst_height and dst_transform must also be
provided.
dst_height : int, optional
Target height in pixels. dst_width and dst_transform must also be
provided.
dst_transform: affine.Affine(), optional
Target affine transformation. Required if width and height are
provided.
warp_extras : dict
GDAL extra warp options. See
http://www.gdal.org/structGDALWarpOptions.html.
@ -56,7 +40,7 @@ class WarpedVRT(WarpedVRTReaderBase, WindowMethodsMixin,
--------
>>> with rasterio.open('tests/data/RGB.byte.tif') as src:
... with WarpedVRT(src, dst_crs='EPSG:3857') as vrt:
... with WarpedVRT(src, crs='EPSG:3857') as vrt:
... data = vrt.read()
"""

View File

@ -276,7 +276,7 @@ def reproject(source, destination, src_transform=None, gcps=None,
# Call the function in our extension module.
_reproject(source, destination, src_transform, gcps, src_crs, src_nodata,
dst_transform, dst_crs, dst_nodata, resampling,
dst_transform, dst_crs, dst_nodata, False, resampling,
init_dest_nodata, **kwargs)

View File

@ -9,7 +9,7 @@ import pytest
import rasterio
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.enums import Resampling, MaskFlags
from rasterio import shutil as rio_shutil
from rasterio.vrt import WarpedVRT
from rasterio.warp import transform_bounds
@ -38,19 +38,34 @@ def _copy_update_profile(path_in, path_out, **kwargs):
def test_warped_vrt(path_rgb_byte_tif):
"""A VirtualVRT has the expected VRT properties."""
with rasterio.open(path_rgb_byte_tif) as src:
vrt = WarpedVRT(src, dst_crs=DST_CRS)
vrt = WarpedVRT(src, crs=DST_CRS)
assert vrt.dst_crs == CRS.from_string(DST_CRS)
assert vrt.src_nodata == 0.0
assert vrt.dst_nodata == 0.0
assert vrt.tolerance == 0.125
assert vrt.resampling == Resampling.nearest
assert vrt.warp_extras == {'init_dest': 'NO_DATA'}
assert vrt.mask_flag_enums == ([MaskFlags.nodata], ) * 3
def test_warped_vrt_dst_alpha(path_rgb_byte_tif):
"""A VirtualVRT has the expected VRT properties."""
with rasterio.open(path_rgb_byte_tif) as src:
vrt = WarpedVRT(src, crs=DST_CRS, add_alpha=True)
assert vrt.dst_crs == CRS.from_string(DST_CRS)
assert vrt.src_nodata == 0.0
assert vrt.dst_nodata is None
assert vrt.tolerance == 0.125
assert vrt.resampling == Resampling.nearest
assert vrt.warp_extras == {'init_dest': 'NO_DATA'}
assert vrt.count == 4
assert vrt.mask_flag_enums == ([MaskFlags.per_dataset, MaskFlags.alpha], ) * 3 + ([MaskFlags.all_valid], )
def test_warped_vrt_source(path_rgb_byte_tif):
"""A VirtualVRT has the expected source dataset."""
with rasterio.open(path_rgb_byte_tif) as src:
vrt = WarpedVRT(src, dst_crs=DST_CRS)
vrt = WarpedVRT(src, crs=DST_CRS)
assert vrt.src_dataset == src
@ -62,16 +77,16 @@ def test_warped_vrt_set_src_crs(path_rgb_byte_tif, tmpdir):
original_crs = src.crs
with rasterio.open(path_crs_unset) as src:
with pytest.raises(Exception):
with WarpedVRT(src, dst_crs=DST_CRS) as vrt:
with WarpedVRT(src, crs=DST_CRS) as vrt:
pass
with WarpedVRT(src, src_crs=original_crs, dst_crs=DST_CRS) as vrt:
with WarpedVRT(src, src_crs=original_crs, crs=DST_CRS) as vrt:
assert vrt.src_crs == original_crs
def test_wrap_file(path_rgb_byte_tif):
"""A VirtualVRT has the expected dataset properties."""
with rasterio.open(path_rgb_byte_tif) as src:
vrt = WarpedVRT(src, dst_crs=DST_CRS)
vrt = WarpedVRT(src, crs=DST_CRS)
assert vrt.crs == CRS.from_string(DST_CRS)
assert tuple(round(x, 1) for x in vrt.bounds) == (
-8789636.7, 2700460.0, -8524406.4, 2943560.2)
@ -95,9 +110,7 @@ def test_warped_vrt_dimensions(path_rgb_byte_tif):
dst_transform = affine.Affine(
resolution, 0.0, extent[0],
0.0, -resolution, extent[1])
vrt = WarpedVRT(src, dst_crs=DST_CRS,
dst_width=size, dst_height=size,
dst_transform=dst_transform)
vrt = WarpedVRT(src, crs=DST_CRS, width=size, height=size, transform=dst_transform)
assert vrt.dst_crs == CRS.from_string(DST_CRS)
assert vrt.src_nodata == 0.0
assert vrt.dst_nodata == 0.0
@ -111,7 +124,7 @@ def test_warped_vrt_dimensions(path_rgb_byte_tif):
def test_warp_extras(path_rgb_byte_tif):
"""INIT_DEST warp extra is passed through."""
with rasterio.open(path_rgb_byte_tif) as src:
with WarpedVRT(src, dst_crs=DST_CRS, init_dest=255) as vrt:
with WarpedVRT(src, crs=DST_CRS, init_dest=255) as vrt:
rgb = vrt.read()
assert (rgb[:, 0, 0] == 255).all()
@ -123,7 +136,7 @@ def test_wrap_s3():
"""A warp wrapper's dataset has the expected properties"""
L8TIF = "s3://landsat-pds/L8/139/045/LC81390452014295LGN00/LC81390452014295LGN00_B1.TIF"
with rasterio.open(L8TIF) as src:
with WarpedVRT(src, dst_crs=DST_CRS, src_nodata=0, dst_nodata=0) as vrt:
with WarpedVRT(src, crs=DST_CRS, src_nodata=0, nodata=0) as vrt:
assert vrt.crs == DST_CRS
assert tuple(round(x, 1) for x in vrt.bounds) == (
9556764.6, 2345109.3, 9804595.9, 2598509.1)
@ -137,7 +150,7 @@ def test_wrap_s3():
def test_warped_vrt_nodata_read(path_rgb_byte_tif):
"""A read from a VirtualVRT respects dst_nodata."""
with rasterio.open(path_rgb_byte_tif) as src:
with WarpedVRT(src, dst_crs=DST_CRS, src_nodata=0) as vrt:
with WarpedVRT(src, crs=DST_CRS, src_nodata=0) as vrt:
data = vrt.read(1, masked=True)
assert data.mask.any()
mask = vrt.dataset_mask()