From ee49f462f94fd6057c9057d5fb0cda22fca5c418 Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Tue, 19 Oct 2021 20:09:04 -0600 Subject: [PATCH] PR 2301 and follow up (#2318) * Add in memory raster that subclasses DatasetBase. * Remove unused variables. * Add r+ to modes setting georeferencing. * Fix dtype argument. * Use InMemoryRasterArray * Use InMemoryRasterArray in warp. * Eliminate unnecessary copy. * Add missing word. * Use InMemoryRasterArray in fillnodata * Add array interface method. * Resolve fillnodata test failure. * Remove unnecessary copy. * Cleaner array handling. Make sure _array is always an array, but only copy when needed/wanted. * Rename InMemoryRaster to MemoryDataset. * Add internal use only comment to MemoryDataset. * Follow ups on #2301 * Fix parameter type in docstring and whitespace Co-authored-by: Ryan Grout --- CHANGES.txt | 2 + rasterio/_base.pyx | 2 +- rasterio/_features.pyx | 36 +++--- rasterio/_fill.pyx | 13 ++- rasterio/_io.pxd | 12 +- rasterio/_io.pyx | 250 +++++++++-------------------------------- rasterio/_warp.pyx | 13 ++- 7 files changed, 94 insertions(+), 234 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index e94eed9b..76c27f63 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -9,6 +9,8 @@ Changes New features: +- The InMemoryRaster class in rasterio._io has been removed and replaced by a + more direct and efficient wrapper around numpy arrays (#2301). - Add support for PROJJSON based interchange for CRS (#2212). CRS.to_dict(projjson=True) returns a PROJJSON style dict and CRS.from_dict() will accept a PROJJSON style dict. PROJJSON text is accepted by diff --git a/rasterio/_base.pyx b/rasterio/_base.pyx index b227fc7b..4509be6b 100644 --- a/rasterio/_base.pyx +++ b/rasterio/_base.pyx @@ -412,7 +412,7 @@ cdef class DatasetBase: if err == GDALError.failure and not self._has_gcps_or_rpcs(): warnings.warn( ("Dataset has no geotransform, gcps, or rpcs. " - "The identity matrix be returned."), + "The identity matrix will be returned."), NotGeoreferencedWarning) return [gt[i] for i in range(6)] diff --git a/rasterio/_features.pyx b/rasterio/_features.pyx index 3ebc2246..b88d93be 100644 --- a/rasterio/_features.pyx +++ b/rasterio/_features.pyx @@ -9,7 +9,7 @@ from rasterio.dtypes import _getnpdtype from rasterio.enums import MergeAlg from rasterio._err cimport exc_wrap_int, exc_wrap_pointer -from rasterio._io cimport DatasetReaderBase, InMemoryRaster, io_auto +from rasterio._io cimport DatasetReaderBase, MemoryDataset, io_auto log = logging.getLogger(__name__) @@ -54,8 +54,8 @@ def _shapes(image, mask, connectivity, transform): cdef OGRLayerH layer = NULL cdef OGRFieldDefnH fielddefn = NULL cdef char **options = NULL - cdef InMemoryRaster mem_ds = None - cdef InMemoryRaster mask_ds = None + cdef MemoryDataset mem_ds = None + cdef MemoryDataset mask_ds = None cdef ShapeIterator shape_iter = None cdef int fieldtp @@ -74,7 +74,7 @@ def _shapes(image, mask, connectivity, transform): try: if dtypes.is_ndarray(image): - mem_ds = InMemoryRaster(image=image, transform=transform) + mem_ds = MemoryDataset(image, transform=transform) band = mem_ds.band(1) elif isinstance(image, tuple): rdr = image.ds @@ -92,7 +92,7 @@ def _shapes(image, mask, connectivity, transform): if dtypes.is_ndarray(mask): # A boolean mask must be converted to uint8 for GDAL - mask_ds = InMemoryRaster(image=mask.astype('uint8'), + mask_ds = MemoryDataset(mask.astype('uint8'), transform=transform) maskband = mask_ds.band(1) elif isinstance(mask, tuple): @@ -172,9 +172,9 @@ def _sieve(image, size, out, mask, connectivity): cdef int retval cdef int rows cdef int cols - cdef InMemoryRaster in_mem_ds = None - cdef InMemoryRaster out_mem_ds = None - cdef InMemoryRaster mask_mem_ds = None + cdef MemoryDataset in_mem_ds = None + cdef MemoryDataset out_mem_ds = None + cdef MemoryDataset mask_mem_ds = None cdef GDALRasterBandH in_band = NULL cdef GDALRasterBandH out_band = NULL cdef GDALRasterBandH mask_band = NULL @@ -206,7 +206,7 @@ def _sieve(image, size, out, mask, connectivity): try: if dtypes.is_ndarray(image): - in_mem_ds = InMemoryRaster(image=image) + in_mem_ds = MemoryDataset(image) in_band = in_mem_ds.band(1) elif isinstance(image, tuple): rdr = image.ds @@ -216,7 +216,7 @@ def _sieve(image, size, out, mask, connectivity): if dtypes.is_ndarray(out): log.debug("out array: %r", out) - out_mem_ds = InMemoryRaster(image=out) + out_mem_ds = MemoryDataset(out) out_band = out_mem_ds.band(1) elif isinstance(out, tuple): udr = out.ds @@ -234,7 +234,7 @@ def _sieve(image, size, out, mask, connectivity): if dtypes.is_ndarray(mask): # A boolean mask must be converted to uint8 for GDAL - mask_mem_ds = InMemoryRaster(image=mask.astype('uint8')) + mask_mem_ds = MemoryDataset(mask.astype('uint8')) mask_band = mask_mem_ds.band(1) elif isinstance(mask, tuple): @@ -319,7 +319,8 @@ def _rasterize(shapes, image, transform, all_touched, merge_alg): cdef OGRGeometryH *geoms = NULL cdef char **options = NULL cdef double *pixel_values = NULL - cdef InMemoryRaster mem = None + cdef MemoryDataset mem = None + cdef int *band_ids = NULL try: if all_touched: @@ -346,20 +347,21 @@ def _rasterize(shapes, image, transform, all_touched, merge_alg): geometry, i, value) # TODO: is a vsimem file more memory efficient? - with InMemoryRaster(image=image, transform=transform) as mem: + with MemoryDataset(image, transform=transform) as mem: + band_ids = CPLMalloc(mem.count*sizeof(int)) + for i in range(mem.count): + band_ids[i] = i + 1 exc_wrap_int( GDALRasterizeGeometries( - mem.handle(), 1, mem.band_ids, num_geoms, geoms, NULL, + mem.handle(), 1, band_ids, num_geoms, geoms, NULL, NULL, pixel_values, options, NULL, NULL)) - # Read in-memory data back into image - image = mem.read() - finally: for i in range(num_geoms): _deleteOgrGeom(geoms[i]) CPLFree(geoms) CPLFree(pixel_values) + CPLFree(band_ids) if options: CSLDestroy(options) diff --git a/rasterio/_fill.pyx b/rasterio/_fill.pyx index 4a3eb332..3a864c13 100644 --- a/rasterio/_fill.pyx +++ b/rasterio/_fill.pyx @@ -4,8 +4,9 @@ include "gdal.pxi" +import numpy as np from rasterio._err cimport exc_wrap_int -from rasterio._io cimport InMemoryRaster +from rasterio._io cimport MemoryDataset def _fillnodata(image, mask, double max_search_distance=100.0, @@ -13,24 +14,24 @@ def _fillnodata(image, mask, double max_search_distance=100.0, cdef GDALRasterBandH image_band = NULL cdef GDALRasterBandH mask_band = NULL cdef char **alg_options = NULL - cdef InMemoryRaster image_dataset = None - cdef InMemoryRaster mask_dataset = None + cdef MemoryDataset image_dataset = None + cdef MemoryDataset mask_dataset = None try: # copy numpy ndarray into an in-memory dataset. - image_dataset = InMemoryRaster(image) + image_dataset = MemoryDataset(image) image_band = image_dataset.band(1) if mask is not None: mask_cast = mask.astype('uint8') - mask_dataset = InMemoryRaster(mask_cast) + mask_dataset = MemoryDataset(mask_cast) mask_band = mask_dataset.band(1) alg_options = CSLSetNameValue(alg_options, "TEMP_FILE_DRIVER", "MEM") exc_wrap_int( GDALFillNodata(image_band, mask_band, max_search_distance, 0, smoothing_iterations, alg_options, NULL, NULL)) - return image_dataset.read() + return np.asarray(image_dataset) finally: if image_dataset is not None: image_dataset.close() diff --git a/rasterio/_io.pxd b/rasterio/_io.pxd index 56d35cc2..9b330aa8 100644 --- a/rasterio/_io.pxd +++ b/rasterio/_io.pxd @@ -21,16 +21,8 @@ cdef class BufferedDatasetWriterBase(DatasetWriterBase): pass -cdef class InMemoryRaster: - cdef GDALDatasetH _hds - cdef double gdal_transform[6] - cdef int* band_ids - cdef np.ndarray _image - cdef object crs - cdef object transform # this is an Affine object. - - cdef GDALDatasetH handle(self) except NULL - cdef GDALRasterBandH band(self, int) except NULL +cdef class MemoryDataset(DatasetWriterBase): + cdef np.ndarray _array cdef class MemoryFileBase: diff --git a/rasterio/_io.pyx b/rasterio/_io.pyx index 6aa50074..643b09a4 100644 --- a/rasterio/_io.pyx +++ b/rasterio/_io.pyx @@ -25,7 +25,7 @@ from rasterio.errors import ( NotGeoreferencedWarning, NodataShadowWarning, WindowError, UnsupportedOperation, OverviewCreationError, RasterBlockError, InvalidArrayError ) -from rasterio.dtypes import is_ndarray, _is_complex_int, _getnpdtype +from rasterio.dtypes import is_ndarray, _is_complex_int, _getnpdtype, _gdal_typename from rasterio.sample import sample_gen from rasterio.transform import Affine from rasterio.path import parse_path, UnparsedPath @@ -1904,208 +1904,70 @@ cdef class DatasetWriterBase(DatasetReaderBase): self.update_tags(ns='RPC', **rpcs) self._rpcs = None -cdef class InMemoryRaster: - """ - Class that manages a single-band in memory GDAL raster dataset. Data type - is determined from the data type of the input numpy 2D array (image), and - must be one of the data types supported by GDAL - (see rasterio.dtypes.dtype_rev). Data are populated at create time from - the 2D array passed in. - Use the 'with' pattern to instantiate this class for automatic closing - of the memory dataset. +cdef class MemoryDataset(DatasetWriterBase): + def __init__(self, arr, transform=None, gcps=None, rpcs=None, crs=None, copy=False): + """Dataset wrapped around in-memory array. - This class includes attributes that are intended to be passed into GDAL - functions: - self.dataset - self.band - self.band_ids (single element array with band ID of this dataset's band) - self.transform (GDAL compatible transform array) + This class is intended for internal use only within rasterio to + support IO with GDAL, where a Dataset object is needed. - This class is only intended for internal use within rasterio to support - IO with GDAL. Other memory based operations should use numpy arrays. - """ - def __cinit__(self): - self._hds = NULL - self.band_ids = NULL - self._image = None - self.crs = None - self.transform = None + MemoryDataset supports the NumPy array interface. + + Parameters + ---------- + arr : ndarray + Array to use for dataset + transform : Transform + Dataset transform + gcps : list + List of GroundControlPoints, a CRS + rpcs : list + Dataset rational polynomial coefficients + crs : CRS + Dataset coordinate reference system + copy : bool, optional + Create an internal copy of the array. If set to False, + caller must make sure that arr is valid while this object + lives. - def __init__(self, image=None, dtype='uint8', count=1, width=None, - height=None, transform=None, gcps=None, rpcs=None, crs=None): """ - Create in-memory raster dataset, and fill its bands with the - arrays in image. + self._array = np.array(arr, copy=copy) + dtype = self._array.dtype - An empty in-memory raster with no memory allocated to bands, - e.g. for use in _calculate_default_transform(), can be created - by passing dtype, count, width, and height instead. + if self._array.ndim == 2: + count = 1 + height, width = arr.shape + elif self._array.ndim == 3: + count, height, width = arr.shape + else: + raise ValueError("arr must be 2D or 3D array") - :param image: 2D numpy array. Must be of supported data type - (see rasterio.dtypes.dtype_rev) - :param transform: Affine transform object - """ - cdef int i = 0 # avoids Cython warning in for loop below - cdef char *srcwkt = NULL - cdef OGRSpatialReferenceH osr = NULL - cdef GDALDriverH mdriver = NULL - cdef GDAL_GCP *gcplist = NULL - cdef char **options = NULL - cdef char **papszMD = NULL + arr_info = self._array.__array_interface__ + info = { + "DATAPOINTER": arr_info["data"][0], + "PIXELS": width, + "LINES": height, + "BANDS": count, + "DATATYPE": _gdal_typename(arr.dtype.name) + } + dataset_options = ",".join(f"{name}={val}" for name, val in info.items()) + datasetname = f"MEM:::{dataset_options}" - if image is not None: - if image.ndim == 3: - count, height, width = image.shape - elif image.ndim == 2: - count = 1 - height, width = image.shape - dtype = image.dtype.name + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + super().__init__(parse_path(datasetname), "r+") + if crs is not None: + self.crs = crs + if transform is not None: + self.transform = transform + if gcps is not None and crs is not None: + self.gcps = (gcps, crs) + if rpcs is not None: + self.rpcs = rpcs - if height is None or height == 0: - raise ValueError("height must be > 0") - - if width is None or width == 0: - raise ValueError("width must be > 0") - - self.band_ids = CPLMalloc(count*sizeof(int)) - for i in range(1, count + 1): - self.band_ids[i-1] = i - - try: - memdriver = exc_wrap_pointer(GDALGetDriverByName("MEM")) - except Exception: - raise DriverRegistrationError( - "'MEM' driver not found. Check that this call is contained " - "in a `with rasterio.Env()` or `with rasterio.open()` " - "block.") - - if _getnpdtype(dtype) == _getnpdtype("int8"): - options = CSLSetNameValue(options, 'PIXELTYPE', 'SIGNEDBYTE') - - datasetname = str(uuid4()).encode('utf-8') - self._hds = exc_wrap_pointer( - GDALCreate(memdriver, datasetname, width, height, - count, dtypes.dtype_rev[dtype], options)) - - if transform is not None: - self.transform = transform - gdal_transform = transform.to_gdal() - for i in range(6): - self.gdal_transform[i] = gdal_transform[i] - exc_wrap_int(GDALSetGeoTransform(self._hds, self.gdal_transform)) - if crs: - osr = _osr_from_crs(crs) - try: - OSRExportToWkt(osr, &srcwkt) - exc_wrap_int(GDALSetProjection(self._hds, srcwkt)) - log.debug("Set CRS on temp dataset: %s", srcwkt) - finally: - CPLFree(srcwkt) - _safe_osr_release(osr) - - elif gcps and crs: - try: - gcplist = CPLMalloc(len(gcps) * sizeof(GDAL_GCP)) - for i, obj in enumerate(gcps): - ident = str(i).encode('utf-8') - info = "".encode('utf-8') - gcplist[i].pszId = ident - gcplist[i].pszInfo = info - gcplist[i].dfGCPPixel = obj.col - gcplist[i].dfGCPLine = obj.row - gcplist[i].dfGCPX = obj.x - gcplist[i].dfGCPY = obj.y - gcplist[i].dfGCPZ = obj.z or 0.0 - - osr = _osr_from_crs(crs) - OSRExportToWkt(osr, &srcwkt) - exc_wrap_int(GDALSetGCPs(self._hds, len(gcps), gcplist, srcwkt)) - finally: - CPLFree(gcplist) - CPLFree(srcwkt) - _safe_osr_release(osr) - elif rpcs: - try: - if hasattr(rpcs, 'to_gdal'): - rpcs = rpcs.to_gdal() - for key, val in rpcs.items(): - key = key.upper().encode('utf-8') - val = str(val).encode('utf-8') - papszMD = CSLSetNameValue( - papszMD, key, val) - exc_wrap_int(GDALSetMetadata(self._hds, papszMD, "RPC")) - finally: - CSLDestroy(papszMD) - - if options != NULL: - CSLDestroy(options) - - if image is not None: - self.write(image) - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - self.close() - - def __dealloc__(self): - if self.band_ids != NULL: - CPLFree(self.band_ids) - self.band_ids = NULL - - cdef GDALDatasetH handle(self) except NULL: - """Return the object's GDAL dataset handle""" - return self._hds - - cdef GDALRasterBandH band(self, int bidx) except NULL: - """Return a GDAL raster band handle""" - cdef GDALRasterBandH band = NULL - - try: - band = exc_wrap_pointer(GDALGetRasterBand(self._hds, bidx)) - except CPLE_IllegalArgError as exc: - raise IndexError(str(exc)) - - # Don't get here. - if band == NULL: - raise ValueError("NULL band") - - return band - - def close(self): - if self._hds != NULL: - GDALClose(self._hds) - self._hds = NULL - - def read(self): - - if self._image is None: - raise RasterioIOError("You need to write data before you can read the data.") - - try: - if self._image.ndim == 2: - io_auto(self._image, self.band(1), False) - else: - io_auto(self._image, self._hds, False) - - except CPLE_BaseError as cplerr: - raise RasterioIOError("Read or write failed. {}".format(cplerr)) - - return self._image - - def write(self, np.ndarray image): - self._image = image - - try: - if image.ndim == 2: - io_auto(self._image, self.band(1), True) - else: - io_auto(self._image, self._hds, True) - - except CPLE_BaseError as cplerr: - raise RasterioIOError("Read or write failed. {}".format(cplerr)) + def __array__(self): + return self._array cdef class BufferedDatasetWriterBase(DatasetWriterBase): diff --git a/rasterio/_warp.pyx b/rasterio/_warp.pyx index 9695fe38..5a16a72d 100644 --- a/rasterio/_warp.pyx +++ b/rasterio/_warp.pyx @@ -37,7 +37,7 @@ from libc.math cimport HUGE_VAL from rasterio._base cimport _osr_from_crs, get_driver_name, _safe_osr_release from rasterio._err cimport exc_wrap_pointer, exc_wrap_int from rasterio._io cimport ( - DatasetReaderBase, InMemoryRaster, in_dtype_range, io_auto) + DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto) from rasterio._features cimport GeomBuilder, OGRGeomBuilder @@ -360,8 +360,8 @@ def _reproject( in_transform = in_transform.translation(eps, eps) return in_transform - cdef InMemoryRaster mem_raster = None - cdef InMemoryRaster src_mem = None + cdef MemoryDataset mem_raster = None + cdef MemoryDataset src_mem = None try: @@ -381,11 +381,12 @@ def _reproject( source = source.reshape(1, *source.shape) src_count = source.shape[0] src_bidx = range(1, src_count + 1) - src_mem = InMemoryRaster(image=source, + src_mem = MemoryDataset(source, transform=format_transform(src_transform), gcps=gcps, rpcs=rpcs, - crs=src_crs) + crs=src_crs, + copy=True) src_dataset = src_mem.handle() # If the source is a rasterio MultiBand, no copy necessary. @@ -429,7 +430,7 @@ def _reproject( raise ValueError("Invalid destination shape") dst_bidx = src_bidx - mem_raster = InMemoryRaster(image=destination, transform=format_transform(dst_transform), crs=dst_crs) + mem_raster = MemoryDataset(destination, transform=format_transform(dst_transform), crs=dst_crs) dst_dataset = mem_raster.handle() if dst_alpha: