PR 2301 and follow up (#2318)

* Add in memory raster that subclasses DatasetBase.

* Remove unused variables.

* Add r+ to modes setting georeferencing.

* Fix dtype argument.

* Use InMemoryRasterArray

* Use InMemoryRasterArray in warp.

* Eliminate unnecessary copy.

* Add missing word.

* Use InMemoryRasterArray in fillnodata

* Add array interface method.

* Resolve fillnodata test failure.

* Remove unnecessary copy.

* Cleaner array handling.

Make sure _array is always an array, but only copy when needed/wanted.

* Rename InMemoryRaster to MemoryDataset.

* Add internal use only comment to MemoryDataset.

* Follow ups on #2301

* Fix parameter type in docstring and whitespace

Co-authored-by: Ryan Grout <ryan@ryangrout.org>
This commit is contained in:
Sean Gillies 2021-10-19 20:09:04 -06:00 committed by GitHub
parent fa4b5ae804
commit ee49f462f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 94 additions and 234 deletions

View File

@ -9,6 +9,8 @@ Changes
New features:
- The InMemoryRaster class in rasterio._io has been removed and replaced by a
more direct and efficient wrapper around numpy arrays (#2301).
- Add support for PROJJSON based interchange for CRS (#2212).
CRS.to_dict(projjson=True) returns a PROJJSON style dict and CRS.from_dict()
will accept a PROJJSON style dict. PROJJSON text is accepted by

View File

@ -412,7 +412,7 @@ cdef class DatasetBase:
if err == GDALError.failure and not self._has_gcps_or_rpcs():
warnings.warn(
("Dataset has no geotransform, gcps, or rpcs. "
"The identity matrix be returned."),
"The identity matrix will be returned."),
NotGeoreferencedWarning)
return [gt[i] for i in range(6)]

View File

@ -9,7 +9,7 @@ from rasterio.dtypes import _getnpdtype
from rasterio.enums import MergeAlg
from rasterio._err cimport exc_wrap_int, exc_wrap_pointer
from rasterio._io cimport DatasetReaderBase, InMemoryRaster, io_auto
from rasterio._io cimport DatasetReaderBase, MemoryDataset, io_auto
log = logging.getLogger(__name__)
@ -54,8 +54,8 @@ def _shapes(image, mask, connectivity, transform):
cdef OGRLayerH layer = NULL
cdef OGRFieldDefnH fielddefn = NULL
cdef char **options = NULL
cdef InMemoryRaster mem_ds = None
cdef InMemoryRaster mask_ds = None
cdef MemoryDataset mem_ds = None
cdef MemoryDataset mask_ds = None
cdef ShapeIterator shape_iter = None
cdef int fieldtp
@ -74,7 +74,7 @@ def _shapes(image, mask, connectivity, transform):
try:
if dtypes.is_ndarray(image):
mem_ds = InMemoryRaster(image=image, transform=transform)
mem_ds = MemoryDataset(image, transform=transform)
band = mem_ds.band(1)
elif isinstance(image, tuple):
rdr = image.ds
@ -92,7 +92,7 @@ def _shapes(image, mask, connectivity, transform):
if dtypes.is_ndarray(mask):
# A boolean mask must be converted to uint8 for GDAL
mask_ds = InMemoryRaster(image=mask.astype('uint8'),
mask_ds = MemoryDataset(mask.astype('uint8'),
transform=transform)
maskband = mask_ds.band(1)
elif isinstance(mask, tuple):
@ -172,9 +172,9 @@ def _sieve(image, size, out, mask, connectivity):
cdef int retval
cdef int rows
cdef int cols
cdef InMemoryRaster in_mem_ds = None
cdef InMemoryRaster out_mem_ds = None
cdef InMemoryRaster mask_mem_ds = None
cdef MemoryDataset in_mem_ds = None
cdef MemoryDataset out_mem_ds = None
cdef MemoryDataset mask_mem_ds = None
cdef GDALRasterBandH in_band = NULL
cdef GDALRasterBandH out_band = NULL
cdef GDALRasterBandH mask_band = NULL
@ -206,7 +206,7 @@ def _sieve(image, size, out, mask, connectivity):
try:
if dtypes.is_ndarray(image):
in_mem_ds = InMemoryRaster(image=image)
in_mem_ds = MemoryDataset(image)
in_band = in_mem_ds.band(1)
elif isinstance(image, tuple):
rdr = image.ds
@ -216,7 +216,7 @@ def _sieve(image, size, out, mask, connectivity):
if dtypes.is_ndarray(out):
log.debug("out array: %r", out)
out_mem_ds = InMemoryRaster(image=out)
out_mem_ds = MemoryDataset(out)
out_band = out_mem_ds.band(1)
elif isinstance(out, tuple):
udr = out.ds
@ -234,7 +234,7 @@ def _sieve(image, size, out, mask, connectivity):
if dtypes.is_ndarray(mask):
# A boolean mask must be converted to uint8 for GDAL
mask_mem_ds = InMemoryRaster(image=mask.astype('uint8'))
mask_mem_ds = MemoryDataset(mask.astype('uint8'))
mask_band = mask_mem_ds.band(1)
elif isinstance(mask, tuple):
@ -319,7 +319,8 @@ def _rasterize(shapes, image, transform, all_touched, merge_alg):
cdef OGRGeometryH *geoms = NULL
cdef char **options = NULL
cdef double *pixel_values = NULL
cdef InMemoryRaster mem = None
cdef MemoryDataset mem = None
cdef int *band_ids = NULL
try:
if all_touched:
@ -346,20 +347,21 @@ def _rasterize(shapes, image, transform, all_touched, merge_alg):
geometry, i, value)
# TODO: is a vsimem file more memory efficient?
with InMemoryRaster(image=image, transform=transform) as mem:
with MemoryDataset(image, transform=transform) as mem:
band_ids = <int *>CPLMalloc(mem.count*sizeof(int))
for i in range(mem.count):
band_ids[i] = i + 1
exc_wrap_int(
GDALRasterizeGeometries(
mem.handle(), 1, mem.band_ids, num_geoms, geoms, NULL,
mem.handle(), 1, band_ids, num_geoms, geoms, NULL,
NULL, pixel_values, options, NULL, NULL))
# Read in-memory data back into image
image = mem.read()
finally:
for i in range(num_geoms):
_deleteOgrGeom(geoms[i])
CPLFree(geoms)
CPLFree(pixel_values)
CPLFree(band_ids)
if options:
CSLDestroy(options)

View File

@ -4,8 +4,9 @@
include "gdal.pxi"
import numpy as np
from rasterio._err cimport exc_wrap_int
from rasterio._io cimport InMemoryRaster
from rasterio._io cimport MemoryDataset
def _fillnodata(image, mask, double max_search_distance=100.0,
@ -13,24 +14,24 @@ def _fillnodata(image, mask, double max_search_distance=100.0,
cdef GDALRasterBandH image_band = NULL
cdef GDALRasterBandH mask_band = NULL
cdef char **alg_options = NULL
cdef InMemoryRaster image_dataset = None
cdef InMemoryRaster mask_dataset = None
cdef MemoryDataset image_dataset = None
cdef MemoryDataset mask_dataset = None
try:
# copy numpy ndarray into an in-memory dataset.
image_dataset = InMemoryRaster(image)
image_dataset = MemoryDataset(image)
image_band = image_dataset.band(1)
if mask is not None:
mask_cast = mask.astype('uint8')
mask_dataset = InMemoryRaster(mask_cast)
mask_dataset = MemoryDataset(mask_cast)
mask_band = mask_dataset.band(1)
alg_options = CSLSetNameValue(alg_options, "TEMP_FILE_DRIVER", "MEM")
exc_wrap_int(
GDALFillNodata(image_band, mask_band, max_search_distance, 0,
smoothing_iterations, alg_options, NULL, NULL))
return image_dataset.read()
return np.asarray(image_dataset)
finally:
if image_dataset is not None:
image_dataset.close()

View File

@ -21,16 +21,8 @@ cdef class BufferedDatasetWriterBase(DatasetWriterBase):
pass
cdef class InMemoryRaster:
cdef GDALDatasetH _hds
cdef double gdal_transform[6]
cdef int* band_ids
cdef np.ndarray _image
cdef object crs
cdef object transform # this is an Affine object.
cdef GDALDatasetH handle(self) except NULL
cdef GDALRasterBandH band(self, int) except NULL
cdef class MemoryDataset(DatasetWriterBase):
cdef np.ndarray _array
cdef class MemoryFileBase:

View File

@ -25,7 +25,7 @@ from rasterio.errors import (
NotGeoreferencedWarning, NodataShadowWarning, WindowError,
UnsupportedOperation, OverviewCreationError, RasterBlockError, InvalidArrayError
)
from rasterio.dtypes import is_ndarray, _is_complex_int, _getnpdtype
from rasterio.dtypes import is_ndarray, _is_complex_int, _getnpdtype, _gdal_typename
from rasterio.sample import sample_gen
from rasterio.transform import Affine
from rasterio.path import parse_path, UnparsedPath
@ -1904,208 +1904,70 @@ cdef class DatasetWriterBase(DatasetReaderBase):
self.update_tags(ns='RPC', **rpcs)
self._rpcs = None
cdef class InMemoryRaster:
"""
Class that manages a single-band in memory GDAL raster dataset. Data type
is determined from the data type of the input numpy 2D array (image), and
must be one of the data types supported by GDAL
(see rasterio.dtypes.dtype_rev). Data are populated at create time from
the 2D array passed in.
Use the 'with' pattern to instantiate this class for automatic closing
of the memory dataset.
cdef class MemoryDataset(DatasetWriterBase):
def __init__(self, arr, transform=None, gcps=None, rpcs=None, crs=None, copy=False):
"""Dataset wrapped around in-memory array.
This class includes attributes that are intended to be passed into GDAL
functions:
self.dataset
self.band
self.band_ids (single element array with band ID of this dataset's band)
self.transform (GDAL compatible transform array)
This class is intended for internal use only within rasterio to
support IO with GDAL, where a Dataset object is needed.
This class is only intended for internal use within rasterio to support
IO with GDAL. Other memory based operations should use numpy arrays.
"""
def __cinit__(self):
self._hds = NULL
self.band_ids = NULL
self._image = None
self.crs = None
self.transform = None
MemoryDataset supports the NumPy array interface.
Parameters
----------
arr : ndarray
Array to use for dataset
transform : Transform
Dataset transform
gcps : list
List of GroundControlPoints, a CRS
rpcs : list
Dataset rational polynomial coefficients
crs : CRS
Dataset coordinate reference system
copy : bool, optional
Create an internal copy of the array. If set to False,
caller must make sure that arr is valid while this object
lives.
def __init__(self, image=None, dtype='uint8', count=1, width=None,
height=None, transform=None, gcps=None, rpcs=None, crs=None):
"""
Create in-memory raster dataset, and fill its bands with the
arrays in image.
self._array = np.array(arr, copy=copy)
dtype = self._array.dtype
An empty in-memory raster with no memory allocated to bands,
e.g. for use in _calculate_default_transform(), can be created
by passing dtype, count, width, and height instead.
if self._array.ndim == 2:
count = 1
height, width = arr.shape
elif self._array.ndim == 3:
count, height, width = arr.shape
else:
raise ValueError("arr must be 2D or 3D array")
:param image: 2D numpy array. Must be of supported data type
(see rasterio.dtypes.dtype_rev)
:param transform: Affine transform object
"""
cdef int i = 0 # avoids Cython warning in for loop below
cdef char *srcwkt = NULL
cdef OGRSpatialReferenceH osr = NULL
cdef GDALDriverH mdriver = NULL
cdef GDAL_GCP *gcplist = NULL
cdef char **options = NULL
cdef char **papszMD = NULL
arr_info = self._array.__array_interface__
info = {
"DATAPOINTER": arr_info["data"][0],
"PIXELS": width,
"LINES": height,
"BANDS": count,
"DATATYPE": _gdal_typename(arr.dtype.name)
}
dataset_options = ",".join(f"{name}={val}" for name, val in info.items())
datasetname = f"MEM:::{dataset_options}"
if image is not None:
if image.ndim == 3:
count, height, width = image.shape
elif image.ndim == 2:
count = 1
height, width = image.shape
dtype = image.dtype.name
with warnings.catch_warnings():
warnings.simplefilter("ignore")
super().__init__(parse_path(datasetname), "r+")
if crs is not None:
self.crs = crs
if transform is not None:
self.transform = transform
if gcps is not None and crs is not None:
self.gcps = (gcps, crs)
if rpcs is not None:
self.rpcs = rpcs
if height is None or height == 0:
raise ValueError("height must be > 0")
if width is None or width == 0:
raise ValueError("width must be > 0")
self.band_ids = <int *>CPLMalloc(count*sizeof(int))
for i in range(1, count + 1):
self.band_ids[i-1] = i
try:
memdriver = exc_wrap_pointer(GDALGetDriverByName("MEM"))
except Exception:
raise DriverRegistrationError(
"'MEM' driver not found. Check that this call is contained "
"in a `with rasterio.Env()` or `with rasterio.open()` "
"block.")
if _getnpdtype(dtype) == _getnpdtype("int8"):
options = CSLSetNameValue(options, 'PIXELTYPE', 'SIGNEDBYTE')
datasetname = str(uuid4()).encode('utf-8')
self._hds = exc_wrap_pointer(
GDALCreate(memdriver, <const char *>datasetname, width, height,
count, <GDALDataType>dtypes.dtype_rev[dtype], options))
if transform is not None:
self.transform = transform
gdal_transform = transform.to_gdal()
for i in range(6):
self.gdal_transform[i] = gdal_transform[i]
exc_wrap_int(GDALSetGeoTransform(self._hds, self.gdal_transform))
if crs:
osr = _osr_from_crs(crs)
try:
OSRExportToWkt(osr, &srcwkt)
exc_wrap_int(GDALSetProjection(self._hds, srcwkt))
log.debug("Set CRS on temp dataset: %s", srcwkt)
finally:
CPLFree(srcwkt)
_safe_osr_release(osr)
elif gcps and crs:
try:
gcplist = <GDAL_GCP *>CPLMalloc(len(gcps) * sizeof(GDAL_GCP))
for i, obj in enumerate(gcps):
ident = str(i).encode('utf-8')
info = "".encode('utf-8')
gcplist[i].pszId = ident
gcplist[i].pszInfo = info
gcplist[i].dfGCPPixel = obj.col
gcplist[i].dfGCPLine = obj.row
gcplist[i].dfGCPX = obj.x
gcplist[i].dfGCPY = obj.y
gcplist[i].dfGCPZ = obj.z or 0.0
osr = _osr_from_crs(crs)
OSRExportToWkt(osr, &srcwkt)
exc_wrap_int(GDALSetGCPs(self._hds, len(gcps), gcplist, srcwkt))
finally:
CPLFree(gcplist)
CPLFree(srcwkt)
_safe_osr_release(osr)
elif rpcs:
try:
if hasattr(rpcs, 'to_gdal'):
rpcs = rpcs.to_gdal()
for key, val in rpcs.items():
key = key.upper().encode('utf-8')
val = str(val).encode('utf-8')
papszMD = CSLSetNameValue(
papszMD, <const char *>key, <const char *>val)
exc_wrap_int(GDALSetMetadata(self._hds, papszMD, "RPC"))
finally:
CSLDestroy(papszMD)
if options != NULL:
CSLDestroy(options)
if image is not None:
self.write(image)
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
self.close()
def __dealloc__(self):
if self.band_ids != NULL:
CPLFree(self.band_ids)
self.band_ids = NULL
cdef GDALDatasetH handle(self) except NULL:
"""Return the object's GDAL dataset handle"""
return self._hds
cdef GDALRasterBandH band(self, int bidx) except NULL:
"""Return a GDAL raster band handle"""
cdef GDALRasterBandH band = NULL
try:
band = exc_wrap_pointer(GDALGetRasterBand(self._hds, bidx))
except CPLE_IllegalArgError as exc:
raise IndexError(str(exc))
# Don't get here.
if band == NULL:
raise ValueError("NULL band")
return band
def close(self):
if self._hds != NULL:
GDALClose(self._hds)
self._hds = NULL
def read(self):
if self._image is None:
raise RasterioIOError("You need to write data before you can read the data.")
try:
if self._image.ndim == 2:
io_auto(self._image, self.band(1), False)
else:
io_auto(self._image, self._hds, False)
except CPLE_BaseError as cplerr:
raise RasterioIOError("Read or write failed. {}".format(cplerr))
return self._image
def write(self, np.ndarray image):
self._image = image
try:
if image.ndim == 2:
io_auto(self._image, self.band(1), True)
else:
io_auto(self._image, self._hds, True)
except CPLE_BaseError as cplerr:
raise RasterioIOError("Read or write failed. {}".format(cplerr))
def __array__(self):
return self._array
cdef class BufferedDatasetWriterBase(DatasetWriterBase):

View File

@ -37,7 +37,7 @@ from libc.math cimport HUGE_VAL
from rasterio._base cimport _osr_from_crs, get_driver_name, _safe_osr_release
from rasterio._err cimport exc_wrap_pointer, exc_wrap_int
from rasterio._io cimport (
DatasetReaderBase, InMemoryRaster, in_dtype_range, io_auto)
DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto)
from rasterio._features cimport GeomBuilder, OGRGeomBuilder
@ -360,8 +360,8 @@ def _reproject(
in_transform = in_transform.translation(eps, eps)
return in_transform
cdef InMemoryRaster mem_raster = None
cdef InMemoryRaster src_mem = None
cdef MemoryDataset mem_raster = None
cdef MemoryDataset src_mem = None
try:
@ -381,11 +381,12 @@ def _reproject(
source = source.reshape(1, *source.shape)
src_count = source.shape[0]
src_bidx = range(1, src_count + 1)
src_mem = InMemoryRaster(image=source,
src_mem = MemoryDataset(source,
transform=format_transform(src_transform),
gcps=gcps,
rpcs=rpcs,
crs=src_crs)
crs=src_crs,
copy=True)
src_dataset = src_mem.handle()
# If the source is a rasterio MultiBand, no copy necessary.
@ -429,7 +430,7 @@ def _reproject(
raise ValueError("Invalid destination shape")
dst_bidx = src_bidx
mem_raster = InMemoryRaster(image=destination, transform=format_transform(dst_transform), crs=dst_crs)
mem_raster = MemoryDataset(destination, transform=format_transform(dst_transform), crs=dst_crs)
dst_dataset = mem_raster.handle()
if dst_alpha: