mirror of
https://github.com/rasterio/rasterio.git
synced 2025-12-08 17:36:12 +00:00
fillnodata now returns ndarray, instead of modifying in-place
This commit is contained in:
parent
9a480d55dd
commit
4758a8aad4
@ -5,47 +5,65 @@
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from rasterio._io cimport InMemoryRaster
|
||||
from rasterio import dtypes
|
||||
from rasterio._err import cpl_errs
|
||||
from rasterio cimport _gdal, _io
|
||||
|
||||
from rasterio._io cimport InMemoryRaster
|
||||
|
||||
def _fillnodata(image, mask, double max_search_distance=100.0, int smoothing_iterations=0):
|
||||
cdef void *hband
|
||||
cdef void *hmaskband
|
||||
cdef char **options = NULL
|
||||
cdef _io.RasterReader rdr
|
||||
cdef _io.RasterReader mrdr
|
||||
cdef InMemoryRaster mem_ds = None
|
||||
cdef InMemoryRaster mask_ds = None
|
||||
|
||||
cdef void *memdriver = _gdal.GDALGetDriverByName("MEM")
|
||||
cdef void *image_dataset
|
||||
cdef void *image_band
|
||||
cdef void *mask_dataset
|
||||
cdef void *mask_band
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
mem_ds = InMemoryRaster(image)
|
||||
hband = mem_ds.band
|
||||
# copy numpy ndarray into an in-memory dataset
|
||||
image_dataset = _gdal.GDALCreate(
|
||||
memdriver,
|
||||
"image",
|
||||
image.shape[1],
|
||||
image.shape[0],
|
||||
1,
|
||||
<_gdal.GDALDataType>dtypes.dtype_rev[image.dtype.name],
|
||||
NULL)
|
||||
image_band = _gdal.GDALGetRasterBand(image_dataset, 1)
|
||||
_io.io_auto(image, image_band, True)
|
||||
elif isinstance(image, tuple):
|
||||
rdr = image.ds
|
||||
hband = rdr.band(image.bidx)
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise ValueError("Invalid source image")
|
||||
|
||||
|
||||
if isinstance(mask, np.ndarray):
|
||||
# A boolean mask must be converted to uint8 for GDAL
|
||||
mask_ds = InMemoryRaster(mask.astype('uint8'))
|
||||
hmaskband = mask_ds.band
|
||||
mask_cast = mask.astype('uint8')
|
||||
mask_dataset = _gdal.GDALCreate(
|
||||
memdriver,
|
||||
"mask",
|
||||
mask.shape[1],
|
||||
mask.shape[0],
|
||||
1,
|
||||
<_gdal.GDALDataType>dtypes.dtype_rev['uint8'],
|
||||
NULL)
|
||||
mask_band = _gdal.GDALGetRasterBand(mask_dataset, 1)
|
||||
_io.io_auto(mask_cast, mask_band, True)
|
||||
elif isinstance(mask, tuple):
|
||||
if mask.shape != image.shape:
|
||||
raise ValueError("Mask must have same shape as image")
|
||||
mrdr = mask.ds
|
||||
hmaskband = mrdr.band(mask.bidx)
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
elif mask is None:
|
||||
mask_band = NULL
|
||||
else:
|
||||
hmaskband = NULL
|
||||
|
||||
result = _gdal.GDALFillNodata(hband, hmaskband, max_search_distance, 0, smoothing_iterations, options, NULL, NULL)
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
_io.io_auto(image, hband, False)
|
||||
|
||||
if mem_ds is not None:
|
||||
mem_ds.close()
|
||||
if mask_ds is not None:
|
||||
mask_ds.close()
|
||||
|
||||
raise ValueError("Invalid source image mask")
|
||||
|
||||
with cpl_errs:
|
||||
_gdal.GDALFillNodata(image_band, mask_band, max_search_distance, 0, smoothing_iterations, NULL, NULL, NULL)
|
||||
|
||||
# read the result into a numpy ndarray
|
||||
result = np.empty(image.shape, dtype=image.dtype)
|
||||
_io.io_auto(result, image_band, False)
|
||||
|
||||
_gdal.GDALClose(image_dataset)
|
||||
_gdal.GDALClose(mask_dataset)
|
||||
|
||||
return result
|
||||
|
||||
@ -4,21 +4,42 @@ from rasterio._fill import _fillnodata
|
||||
def fillnodata(image, mask=None, max_search_distance=100.0,
|
||||
smoothing_iterations=0):
|
||||
"""
|
||||
Fill nodata pixels by interpolation from the edges
|
||||
Fill selected raster regions by interpolation from the edges.
|
||||
|
||||
This algorithm will interpolate values for all designated nodata pixels
|
||||
(marked by zeros in `mask`). For each pixel a four direction conic search
|
||||
is done to find values to interpolate from (using inverse distance
|
||||
weighting). Once all values are interpolated, zero or more smoothing
|
||||
iterations (3x3 average filters on interpolated pixels) are applied to
|
||||
smooth out artifacts.
|
||||
|
||||
This algorithm is generally suitable for interpolating missing regions of
|
||||
fairly continuously varying rasters (such as elevation models for
|
||||
instance). It is also suitable for filling small holes and cracks in more
|
||||
irregularly varying images (like aerial photos). It is generally not so
|
||||
great for interpolating a raster from sparse point data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : numpy ndarray
|
||||
The band to be modified in place
|
||||
mask : numpy ndarray
|
||||
A mask band indicating pixels to be interpolated (zero valud)
|
||||
A mask band indicating which pixels to interpolate. Pixels to
|
||||
interpolate into are indicated by the value 0. Values of 1 indicate
|
||||
areas to use during interpolation. Must be same shape as image.
|
||||
max_search_distance : float, optional
|
||||
The maxmimum number of pixels to search in all directions to find
|
||||
values to interpolate from. The default is 100.
|
||||
smoothing_iterations : integer, optional
|
||||
The number of 3x3 smoothing filter passes to run. The default is 0.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : numpy ndarray
|
||||
The interpolated raster array
|
||||
"""
|
||||
max_search_distance = float(max_search_distance)
|
||||
smoothing_iterations = int(smoothing_iterations)
|
||||
with rasterio.drivers():
|
||||
ret = _fillnodata(image, mask, max_search_distance, smoothing_iterations)
|
||||
return ret
|
||||
|
||||
@ -11,11 +11,35 @@ logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
|
||||
def test_fillnodata():
|
||||
"""Test filling nodata values in an ndarray"""
|
||||
# create a 5x5 array, with some missing data
|
||||
a = numpy.ones([5, 5]) * 42
|
||||
a[1:4,1:4] = numpy.nan
|
||||
a = numpy.ones([3, 3]) * 42
|
||||
a[1][1] = 0
|
||||
# find the missing data
|
||||
mask = ~numpy.isnan(a)
|
||||
mask = ~(a == 0)
|
||||
# fill the missing data using interpolation from the edges
|
||||
ret = fillnodata(a, mask)
|
||||
assert(((numpy.ones([5, 5]) * 42) - a).sum() == 0)
|
||||
assert(ret is None) # inplace modification, should not return anything
|
||||
result = fillnodata(a, mask)
|
||||
assert(numpy.all((numpy.ones([3, 3]) * 42) == result))
|
||||
|
||||
def test_fillnodata_invalid_types():
|
||||
a = numpy.ones([3, 3])
|
||||
with pytest.raises(ValueError):
|
||||
fillnodata(None, a)
|
||||
with pytest.raises(ValueError):
|
||||
fillnodata(a, 42)
|
||||
|
||||
def test_fillnodata_mask_ones():
|
||||
# when mask is all ones, image should be unmodified
|
||||
a = numpy.ones([3, 3]) * 42
|
||||
a[1][1] = 0
|
||||
mask = numpy.ones([3, 3])
|
||||
result = fillnodata(a, mask)
|
||||
assert(numpy.all(a == result))
|
||||
|
||||
'''
|
||||
def test_fillnodata_smooth():
|
||||
a = numpy.array([[1,3,3,1],[2,0,0,2],[2,0,0,2],[1,3,3,1]], dtype=numpy.float64)
|
||||
mask = ~(a == 0)
|
||||
result = fillnodata(a, mask, max_search_distance=1, smoothing_iterations=0)
|
||||
assert(result[1][1] == 3)
|
||||
result = fillnodata(a, mask, max_search_distance=1, smoothing_iterations=1)
|
||||
assert(round(result[1][1], 1) == 2.2)
|
||||
'''
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user