mirror of
https://github.com/rasterio/rasterio.git
synced 2026-02-01 14:34:43 +00:00
PR 2301 and follow up (#2318)
* Add in memory raster that subclasses DatasetBase. * Remove unused variables. * Add r+ to modes setting georeferencing. * Fix dtype argument. * Use InMemoryRasterArray * Use InMemoryRasterArray in warp. * Eliminate unnecessary copy. * Add missing word. * Use InMemoryRasterArray in fillnodata * Add array interface method. * Resolve fillnodata test failure. * Remove unnecessary copy. * Cleaner array handling. Make sure _array is always an array, but only copy when needed/wanted. * Rename InMemoryRaster to MemoryDataset. * Add internal use only comment to MemoryDataset. * Follow ups on #2301 * Fix parameter type in docstring and whitespace Co-authored-by: Ryan Grout <ryan@ryangrout.org>
This commit is contained in:
parent
fa4b5ae804
commit
ee49f462f9
@ -9,6 +9,8 @@ Changes
|
||||
|
||||
New features:
|
||||
|
||||
- The InMemoryRaster class in rasterio._io has been removed and replaced by a
|
||||
more direct and efficient wrapper around numpy arrays (#2301).
|
||||
- Add support for PROJJSON based interchange for CRS (#2212).
|
||||
CRS.to_dict(projjson=True) returns a PROJJSON style dict and CRS.from_dict()
|
||||
will accept a PROJJSON style dict. PROJJSON text is accepted by
|
||||
|
||||
@ -412,7 +412,7 @@ cdef class DatasetBase:
|
||||
if err == GDALError.failure and not self._has_gcps_or_rpcs():
|
||||
warnings.warn(
|
||||
("Dataset has no geotransform, gcps, or rpcs. "
|
||||
"The identity matrix be returned."),
|
||||
"The identity matrix will be returned."),
|
||||
NotGeoreferencedWarning)
|
||||
|
||||
return [gt[i] for i in range(6)]
|
||||
|
||||
@ -9,7 +9,7 @@ from rasterio.dtypes import _getnpdtype
|
||||
from rasterio.enums import MergeAlg
|
||||
|
||||
from rasterio._err cimport exc_wrap_int, exc_wrap_pointer
|
||||
from rasterio._io cimport DatasetReaderBase, InMemoryRaster, io_auto
|
||||
from rasterio._io cimport DatasetReaderBase, MemoryDataset, io_auto
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -54,8 +54,8 @@ def _shapes(image, mask, connectivity, transform):
|
||||
cdef OGRLayerH layer = NULL
|
||||
cdef OGRFieldDefnH fielddefn = NULL
|
||||
cdef char **options = NULL
|
||||
cdef InMemoryRaster mem_ds = None
|
||||
cdef InMemoryRaster mask_ds = None
|
||||
cdef MemoryDataset mem_ds = None
|
||||
cdef MemoryDataset mask_ds = None
|
||||
cdef ShapeIterator shape_iter = None
|
||||
cdef int fieldtp
|
||||
|
||||
@ -74,7 +74,7 @@ def _shapes(image, mask, connectivity, transform):
|
||||
try:
|
||||
|
||||
if dtypes.is_ndarray(image):
|
||||
mem_ds = InMemoryRaster(image=image, transform=transform)
|
||||
mem_ds = MemoryDataset(image, transform=transform)
|
||||
band = mem_ds.band(1)
|
||||
elif isinstance(image, tuple):
|
||||
rdr = image.ds
|
||||
@ -92,7 +92,7 @@ def _shapes(image, mask, connectivity, transform):
|
||||
|
||||
if dtypes.is_ndarray(mask):
|
||||
# A boolean mask must be converted to uint8 for GDAL
|
||||
mask_ds = InMemoryRaster(image=mask.astype('uint8'),
|
||||
mask_ds = MemoryDataset(mask.astype('uint8'),
|
||||
transform=transform)
|
||||
maskband = mask_ds.band(1)
|
||||
elif isinstance(mask, tuple):
|
||||
@ -172,9 +172,9 @@ def _sieve(image, size, out, mask, connectivity):
|
||||
cdef int retval
|
||||
cdef int rows
|
||||
cdef int cols
|
||||
cdef InMemoryRaster in_mem_ds = None
|
||||
cdef InMemoryRaster out_mem_ds = None
|
||||
cdef InMemoryRaster mask_mem_ds = None
|
||||
cdef MemoryDataset in_mem_ds = None
|
||||
cdef MemoryDataset out_mem_ds = None
|
||||
cdef MemoryDataset mask_mem_ds = None
|
||||
cdef GDALRasterBandH in_band = NULL
|
||||
cdef GDALRasterBandH out_band = NULL
|
||||
cdef GDALRasterBandH mask_band = NULL
|
||||
@ -206,7 +206,7 @@ def _sieve(image, size, out, mask, connectivity):
|
||||
try:
|
||||
|
||||
if dtypes.is_ndarray(image):
|
||||
in_mem_ds = InMemoryRaster(image=image)
|
||||
in_mem_ds = MemoryDataset(image)
|
||||
in_band = in_mem_ds.band(1)
|
||||
elif isinstance(image, tuple):
|
||||
rdr = image.ds
|
||||
@ -216,7 +216,7 @@ def _sieve(image, size, out, mask, connectivity):
|
||||
|
||||
if dtypes.is_ndarray(out):
|
||||
log.debug("out array: %r", out)
|
||||
out_mem_ds = InMemoryRaster(image=out)
|
||||
out_mem_ds = MemoryDataset(out)
|
||||
out_band = out_mem_ds.band(1)
|
||||
elif isinstance(out, tuple):
|
||||
udr = out.ds
|
||||
@ -234,7 +234,7 @@ def _sieve(image, size, out, mask, connectivity):
|
||||
|
||||
if dtypes.is_ndarray(mask):
|
||||
# A boolean mask must be converted to uint8 for GDAL
|
||||
mask_mem_ds = InMemoryRaster(image=mask.astype('uint8'))
|
||||
mask_mem_ds = MemoryDataset(mask.astype('uint8'))
|
||||
mask_band = mask_mem_ds.band(1)
|
||||
|
||||
elif isinstance(mask, tuple):
|
||||
@ -319,7 +319,8 @@ def _rasterize(shapes, image, transform, all_touched, merge_alg):
|
||||
cdef OGRGeometryH *geoms = NULL
|
||||
cdef char **options = NULL
|
||||
cdef double *pixel_values = NULL
|
||||
cdef InMemoryRaster mem = None
|
||||
cdef MemoryDataset mem = None
|
||||
cdef int *band_ids = NULL
|
||||
|
||||
try:
|
||||
if all_touched:
|
||||
@ -346,20 +347,21 @@ def _rasterize(shapes, image, transform, all_touched, merge_alg):
|
||||
geometry, i, value)
|
||||
|
||||
# TODO: is a vsimem file more memory efficient?
|
||||
with InMemoryRaster(image=image, transform=transform) as mem:
|
||||
with MemoryDataset(image, transform=transform) as mem:
|
||||
band_ids = <int *>CPLMalloc(mem.count*sizeof(int))
|
||||
for i in range(mem.count):
|
||||
band_ids[i] = i + 1
|
||||
exc_wrap_int(
|
||||
GDALRasterizeGeometries(
|
||||
mem.handle(), 1, mem.band_ids, num_geoms, geoms, NULL,
|
||||
mem.handle(), 1, band_ids, num_geoms, geoms, NULL,
|
||||
NULL, pixel_values, options, NULL, NULL))
|
||||
|
||||
# Read in-memory data back into image
|
||||
image = mem.read()
|
||||
|
||||
finally:
|
||||
for i in range(num_geoms):
|
||||
_deleteOgrGeom(geoms[i])
|
||||
CPLFree(geoms)
|
||||
CPLFree(pixel_values)
|
||||
CPLFree(band_ids)
|
||||
if options:
|
||||
CSLDestroy(options)
|
||||
|
||||
|
||||
@ -4,8 +4,9 @@
|
||||
|
||||
include "gdal.pxi"
|
||||
|
||||
import numpy as np
|
||||
from rasterio._err cimport exc_wrap_int
|
||||
from rasterio._io cimport InMemoryRaster
|
||||
from rasterio._io cimport MemoryDataset
|
||||
|
||||
|
||||
def _fillnodata(image, mask, double max_search_distance=100.0,
|
||||
@ -13,24 +14,24 @@ def _fillnodata(image, mask, double max_search_distance=100.0,
|
||||
cdef GDALRasterBandH image_band = NULL
|
||||
cdef GDALRasterBandH mask_band = NULL
|
||||
cdef char **alg_options = NULL
|
||||
cdef InMemoryRaster image_dataset = None
|
||||
cdef InMemoryRaster mask_dataset = None
|
||||
cdef MemoryDataset image_dataset = None
|
||||
cdef MemoryDataset mask_dataset = None
|
||||
|
||||
try:
|
||||
# copy numpy ndarray into an in-memory dataset.
|
||||
image_dataset = InMemoryRaster(image)
|
||||
image_dataset = MemoryDataset(image)
|
||||
image_band = image_dataset.band(1)
|
||||
|
||||
if mask is not None:
|
||||
mask_cast = mask.astype('uint8')
|
||||
mask_dataset = InMemoryRaster(mask_cast)
|
||||
mask_dataset = MemoryDataset(mask_cast)
|
||||
mask_band = mask_dataset.band(1)
|
||||
|
||||
alg_options = CSLSetNameValue(alg_options, "TEMP_FILE_DRIVER", "MEM")
|
||||
exc_wrap_int(
|
||||
GDALFillNodata(image_band, mask_band, max_search_distance, 0,
|
||||
smoothing_iterations, alg_options, NULL, NULL))
|
||||
return image_dataset.read()
|
||||
return np.asarray(image_dataset)
|
||||
finally:
|
||||
if image_dataset is not None:
|
||||
image_dataset.close()
|
||||
|
||||
@ -21,16 +21,8 @@ cdef class BufferedDatasetWriterBase(DatasetWriterBase):
|
||||
pass
|
||||
|
||||
|
||||
cdef class InMemoryRaster:
|
||||
cdef GDALDatasetH _hds
|
||||
cdef double gdal_transform[6]
|
||||
cdef int* band_ids
|
||||
cdef np.ndarray _image
|
||||
cdef object crs
|
||||
cdef object transform # this is an Affine object.
|
||||
|
||||
cdef GDALDatasetH handle(self) except NULL
|
||||
cdef GDALRasterBandH band(self, int) except NULL
|
||||
cdef class MemoryDataset(DatasetWriterBase):
|
||||
cdef np.ndarray _array
|
||||
|
||||
|
||||
cdef class MemoryFileBase:
|
||||
|
||||
250
rasterio/_io.pyx
250
rasterio/_io.pyx
@ -25,7 +25,7 @@ from rasterio.errors import (
|
||||
NotGeoreferencedWarning, NodataShadowWarning, WindowError,
|
||||
UnsupportedOperation, OverviewCreationError, RasterBlockError, InvalidArrayError
|
||||
)
|
||||
from rasterio.dtypes import is_ndarray, _is_complex_int, _getnpdtype
|
||||
from rasterio.dtypes import is_ndarray, _is_complex_int, _getnpdtype, _gdal_typename
|
||||
from rasterio.sample import sample_gen
|
||||
from rasterio.transform import Affine
|
||||
from rasterio.path import parse_path, UnparsedPath
|
||||
@ -1904,208 +1904,70 @@ cdef class DatasetWriterBase(DatasetReaderBase):
|
||||
self.update_tags(ns='RPC', **rpcs)
|
||||
self._rpcs = None
|
||||
|
||||
cdef class InMemoryRaster:
|
||||
"""
|
||||
Class that manages a single-band in memory GDAL raster dataset. Data type
|
||||
is determined from the data type of the input numpy 2D array (image), and
|
||||
must be one of the data types supported by GDAL
|
||||
(see rasterio.dtypes.dtype_rev). Data are populated at create time from
|
||||
the 2D array passed in.
|
||||
|
||||
Use the 'with' pattern to instantiate this class for automatic closing
|
||||
of the memory dataset.
|
||||
cdef class MemoryDataset(DatasetWriterBase):
|
||||
def __init__(self, arr, transform=None, gcps=None, rpcs=None, crs=None, copy=False):
|
||||
"""Dataset wrapped around in-memory array.
|
||||
|
||||
This class includes attributes that are intended to be passed into GDAL
|
||||
functions:
|
||||
self.dataset
|
||||
self.band
|
||||
self.band_ids (single element array with band ID of this dataset's band)
|
||||
self.transform (GDAL compatible transform array)
|
||||
This class is intended for internal use only within rasterio to
|
||||
support IO with GDAL, where a Dataset object is needed.
|
||||
|
||||
This class is only intended for internal use within rasterio to support
|
||||
IO with GDAL. Other memory based operations should use numpy arrays.
|
||||
"""
|
||||
def __cinit__(self):
|
||||
self._hds = NULL
|
||||
self.band_ids = NULL
|
||||
self._image = None
|
||||
self.crs = None
|
||||
self.transform = None
|
||||
MemoryDataset supports the NumPy array interface.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arr : ndarray
|
||||
Array to use for dataset
|
||||
transform : Transform
|
||||
Dataset transform
|
||||
gcps : list
|
||||
List of GroundControlPoints, a CRS
|
||||
rpcs : list
|
||||
Dataset rational polynomial coefficients
|
||||
crs : CRS
|
||||
Dataset coordinate reference system
|
||||
copy : bool, optional
|
||||
Create an internal copy of the array. If set to False,
|
||||
caller must make sure that arr is valid while this object
|
||||
lives.
|
||||
|
||||
def __init__(self, image=None, dtype='uint8', count=1, width=None,
|
||||
height=None, transform=None, gcps=None, rpcs=None, crs=None):
|
||||
"""
|
||||
Create in-memory raster dataset, and fill its bands with the
|
||||
arrays in image.
|
||||
self._array = np.array(arr, copy=copy)
|
||||
dtype = self._array.dtype
|
||||
|
||||
An empty in-memory raster with no memory allocated to bands,
|
||||
e.g. for use in _calculate_default_transform(), can be created
|
||||
by passing dtype, count, width, and height instead.
|
||||
if self._array.ndim == 2:
|
||||
count = 1
|
||||
height, width = arr.shape
|
||||
elif self._array.ndim == 3:
|
||||
count, height, width = arr.shape
|
||||
else:
|
||||
raise ValueError("arr must be 2D or 3D array")
|
||||
|
||||
:param image: 2D numpy array. Must be of supported data type
|
||||
(see rasterio.dtypes.dtype_rev)
|
||||
:param transform: Affine transform object
|
||||
"""
|
||||
cdef int i = 0 # avoids Cython warning in for loop below
|
||||
cdef char *srcwkt = NULL
|
||||
cdef OGRSpatialReferenceH osr = NULL
|
||||
cdef GDALDriverH mdriver = NULL
|
||||
cdef GDAL_GCP *gcplist = NULL
|
||||
cdef char **options = NULL
|
||||
cdef char **papszMD = NULL
|
||||
arr_info = self._array.__array_interface__
|
||||
info = {
|
||||
"DATAPOINTER": arr_info["data"][0],
|
||||
"PIXELS": width,
|
||||
"LINES": height,
|
||||
"BANDS": count,
|
||||
"DATATYPE": _gdal_typename(arr.dtype.name)
|
||||
}
|
||||
dataset_options = ",".join(f"{name}={val}" for name, val in info.items())
|
||||
datasetname = f"MEM:::{dataset_options}"
|
||||
|
||||
if image is not None:
|
||||
if image.ndim == 3:
|
||||
count, height, width = image.shape
|
||||
elif image.ndim == 2:
|
||||
count = 1
|
||||
height, width = image.shape
|
||||
dtype = image.dtype.name
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
super().__init__(parse_path(datasetname), "r+")
|
||||
if crs is not None:
|
||||
self.crs = crs
|
||||
if transform is not None:
|
||||
self.transform = transform
|
||||
if gcps is not None and crs is not None:
|
||||
self.gcps = (gcps, crs)
|
||||
if rpcs is not None:
|
||||
self.rpcs = rpcs
|
||||
|
||||
if height is None or height == 0:
|
||||
raise ValueError("height must be > 0")
|
||||
|
||||
if width is None or width == 0:
|
||||
raise ValueError("width must be > 0")
|
||||
|
||||
self.band_ids = <int *>CPLMalloc(count*sizeof(int))
|
||||
for i in range(1, count + 1):
|
||||
self.band_ids[i-1] = i
|
||||
|
||||
try:
|
||||
memdriver = exc_wrap_pointer(GDALGetDriverByName("MEM"))
|
||||
except Exception:
|
||||
raise DriverRegistrationError(
|
||||
"'MEM' driver not found. Check that this call is contained "
|
||||
"in a `with rasterio.Env()` or `with rasterio.open()` "
|
||||
"block.")
|
||||
|
||||
if _getnpdtype(dtype) == _getnpdtype("int8"):
|
||||
options = CSLSetNameValue(options, 'PIXELTYPE', 'SIGNEDBYTE')
|
||||
|
||||
datasetname = str(uuid4()).encode('utf-8')
|
||||
self._hds = exc_wrap_pointer(
|
||||
GDALCreate(memdriver, <const char *>datasetname, width, height,
|
||||
count, <GDALDataType>dtypes.dtype_rev[dtype], options))
|
||||
|
||||
if transform is not None:
|
||||
self.transform = transform
|
||||
gdal_transform = transform.to_gdal()
|
||||
for i in range(6):
|
||||
self.gdal_transform[i] = gdal_transform[i]
|
||||
exc_wrap_int(GDALSetGeoTransform(self._hds, self.gdal_transform))
|
||||
if crs:
|
||||
osr = _osr_from_crs(crs)
|
||||
try:
|
||||
OSRExportToWkt(osr, &srcwkt)
|
||||
exc_wrap_int(GDALSetProjection(self._hds, srcwkt))
|
||||
log.debug("Set CRS on temp dataset: %s", srcwkt)
|
||||
finally:
|
||||
CPLFree(srcwkt)
|
||||
_safe_osr_release(osr)
|
||||
|
||||
elif gcps and crs:
|
||||
try:
|
||||
gcplist = <GDAL_GCP *>CPLMalloc(len(gcps) * sizeof(GDAL_GCP))
|
||||
for i, obj in enumerate(gcps):
|
||||
ident = str(i).encode('utf-8')
|
||||
info = "".encode('utf-8')
|
||||
gcplist[i].pszId = ident
|
||||
gcplist[i].pszInfo = info
|
||||
gcplist[i].dfGCPPixel = obj.col
|
||||
gcplist[i].dfGCPLine = obj.row
|
||||
gcplist[i].dfGCPX = obj.x
|
||||
gcplist[i].dfGCPY = obj.y
|
||||
gcplist[i].dfGCPZ = obj.z or 0.0
|
||||
|
||||
osr = _osr_from_crs(crs)
|
||||
OSRExportToWkt(osr, &srcwkt)
|
||||
exc_wrap_int(GDALSetGCPs(self._hds, len(gcps), gcplist, srcwkt))
|
||||
finally:
|
||||
CPLFree(gcplist)
|
||||
CPLFree(srcwkt)
|
||||
_safe_osr_release(osr)
|
||||
elif rpcs:
|
||||
try:
|
||||
if hasattr(rpcs, 'to_gdal'):
|
||||
rpcs = rpcs.to_gdal()
|
||||
for key, val in rpcs.items():
|
||||
key = key.upper().encode('utf-8')
|
||||
val = str(val).encode('utf-8')
|
||||
papszMD = CSLSetNameValue(
|
||||
papszMD, <const char *>key, <const char *>val)
|
||||
exc_wrap_int(GDALSetMetadata(self._hds, papszMD, "RPC"))
|
||||
finally:
|
||||
CSLDestroy(papszMD)
|
||||
|
||||
if options != NULL:
|
||||
CSLDestroy(options)
|
||||
|
||||
if image is not None:
|
||||
self.write(image)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self.close()
|
||||
|
||||
def __dealloc__(self):
|
||||
if self.band_ids != NULL:
|
||||
CPLFree(self.band_ids)
|
||||
self.band_ids = NULL
|
||||
|
||||
cdef GDALDatasetH handle(self) except NULL:
|
||||
"""Return the object's GDAL dataset handle"""
|
||||
return self._hds
|
||||
|
||||
cdef GDALRasterBandH band(self, int bidx) except NULL:
|
||||
"""Return a GDAL raster band handle"""
|
||||
cdef GDALRasterBandH band = NULL
|
||||
|
||||
try:
|
||||
band = exc_wrap_pointer(GDALGetRasterBand(self._hds, bidx))
|
||||
except CPLE_IllegalArgError as exc:
|
||||
raise IndexError(str(exc))
|
||||
|
||||
# Don't get here.
|
||||
if band == NULL:
|
||||
raise ValueError("NULL band")
|
||||
|
||||
return band
|
||||
|
||||
def close(self):
|
||||
if self._hds != NULL:
|
||||
GDALClose(self._hds)
|
||||
self._hds = NULL
|
||||
|
||||
def read(self):
|
||||
|
||||
if self._image is None:
|
||||
raise RasterioIOError("You need to write data before you can read the data.")
|
||||
|
||||
try:
|
||||
if self._image.ndim == 2:
|
||||
io_auto(self._image, self.band(1), False)
|
||||
else:
|
||||
io_auto(self._image, self._hds, False)
|
||||
|
||||
except CPLE_BaseError as cplerr:
|
||||
raise RasterioIOError("Read or write failed. {}".format(cplerr))
|
||||
|
||||
return self._image
|
||||
|
||||
def write(self, np.ndarray image):
|
||||
self._image = image
|
||||
|
||||
try:
|
||||
if image.ndim == 2:
|
||||
io_auto(self._image, self.band(1), True)
|
||||
else:
|
||||
io_auto(self._image, self._hds, True)
|
||||
|
||||
except CPLE_BaseError as cplerr:
|
||||
raise RasterioIOError("Read or write failed. {}".format(cplerr))
|
||||
def __array__(self):
|
||||
return self._array
|
||||
|
||||
|
||||
cdef class BufferedDatasetWriterBase(DatasetWriterBase):
|
||||
|
||||
@ -37,7 +37,7 @@ from libc.math cimport HUGE_VAL
|
||||
from rasterio._base cimport _osr_from_crs, get_driver_name, _safe_osr_release
|
||||
from rasterio._err cimport exc_wrap_pointer, exc_wrap_int
|
||||
from rasterio._io cimport (
|
||||
DatasetReaderBase, InMemoryRaster, in_dtype_range, io_auto)
|
||||
DatasetReaderBase, MemoryDataset, in_dtype_range, io_auto)
|
||||
from rasterio._features cimport GeomBuilder, OGRGeomBuilder
|
||||
|
||||
|
||||
@ -360,8 +360,8 @@ def _reproject(
|
||||
in_transform = in_transform.translation(eps, eps)
|
||||
return in_transform
|
||||
|
||||
cdef InMemoryRaster mem_raster = None
|
||||
cdef InMemoryRaster src_mem = None
|
||||
cdef MemoryDataset mem_raster = None
|
||||
cdef MemoryDataset src_mem = None
|
||||
|
||||
try:
|
||||
|
||||
@ -381,11 +381,12 @@ def _reproject(
|
||||
source = source.reshape(1, *source.shape)
|
||||
src_count = source.shape[0]
|
||||
src_bidx = range(1, src_count + 1)
|
||||
src_mem = InMemoryRaster(image=source,
|
||||
src_mem = MemoryDataset(source,
|
||||
transform=format_transform(src_transform),
|
||||
gcps=gcps,
|
||||
rpcs=rpcs,
|
||||
crs=src_crs)
|
||||
crs=src_crs,
|
||||
copy=True)
|
||||
src_dataset = src_mem.handle()
|
||||
|
||||
# If the source is a rasterio MultiBand, no copy necessary.
|
||||
@ -429,7 +430,7 @@ def _reproject(
|
||||
raise ValueError("Invalid destination shape")
|
||||
dst_bidx = src_bidx
|
||||
|
||||
mem_raster = InMemoryRaster(image=destination, transform=format_transform(dst_transform), crs=dst_crs)
|
||||
mem_raster = MemoryDataset(destination, transform=format_transform(dst_transform), crs=dst_crs)
|
||||
dst_dataset = mem_raster.handle()
|
||||
|
||||
if dst_alpha:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user