rasterio/tests/test_dataset_mask.py
2018-06-05 17:16:11 -04:00

215 lines
6.4 KiB
Python

import logging
import sys
import numpy as np
import pytest
try:
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
except ImportError:
plt = None
from affine import Affine
import rasterio
from rasterio.enums import Resampling
from rasterio.errors import NodataShadowWarning
from rasterio.crs import CRS
from .conftest import requires_gdal2
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
# Test arrays: 3x3 uint8 rasters holding 0 (off / invalid) or 255 (on / valid).
red = np.array([[0, 0, 0],
                [0, 1, 1],
                [1, 0, 1]], dtype='uint8') * 255

grn = np.array([[0, 0, 0],
                [1, 0, 1],
                [1, 0, 1]], dtype='uint8') * 255

blu = np.array([[0, 0, 0],
                [1, 1, 0],
                [1, 0, 1]], dtype='uint8') * 255

# Equivalent to red | grn | blu:
# data is valid wherever at least one of R, G or B has a value.
alp = np.array([[0, 0, 0],
                [1, 1, 1],
                [1, 0, 1]], dtype='uint8') * 255

# A mask may be produced by other tools and therefore
# differ from a strict interpretation of the RGB values.
msk = np.array([[0, 0, 0],
                [1, 1, 1],
                [1, 1, 1]], dtype='uint8') * 255

# Every pixel valid.
alldata = np.array([[1, 1, 1],
                    [1, 1, 1],
                    [1, 1, 1]], dtype='uint8') * 255

# Expected mask for the boundless window ((1, 4), (1, 4)).
alp_shift_lr = np.array([[1, 1, 0],
                         [0, 1, 0],
                         [0, 0, 0]], dtype='uint8') * 255

# Expected whole mask when read with out_shape=(1, 5, 5)
# and no explicit resampling argument.
resampmask = np.array([[0, 0, 0, 0, 0],
                       [0, 0, 0, 0, 0],
                       [1, 1, 1, 1, 1],
                       [1, 1, 0, 1, 1],
                       [1, 1, 0, 1, 1]], dtype='uint8') * 255

# Expected whole mask when read with out_shape=(1, 5, 5) and bilinear
# resampling, after mapping every non-zero value to 255.
resampave = np.array([[0, 0, 0, 0, 0],
                      [1, 1, 1, 1, 1],
                      [1, 1, 1, 1, 1],
                      [1, 1, 1, 1, 1],
                      [1, 1, 0, 1, 1]], dtype='uint8') * 255
def _write_tiff(path, profile, bands, mask=None):
    """Create a GeoTIFF at `path` from `bands`, optionally adding a mask.

    Parameters
    ----------
    path : str
        Destination file name.
    profile : dict
        Keyword arguments forwarded to ``rasterio.open`` in 'w' mode.
    bands : sequence of arrays
        Written in order to band indexes 1..len(bands).
    mask : array, optional
        When given, written with ``write_mask`` after all bands.
    """
    with rasterio.open(path, 'w', **profile) as dst:
        for bidx, band in enumerate(bands, start=1):
            dst.write(band, bidx)
        if mask is not None:
            dst.write_mask(mask)


@pytest.fixture(scope='function')
def tiffs(tmpdir):
    """Fill `tmpdir` with the 3x3 uint8 GeoTIFF variants used by the tests.

    Seven files are written, covering the combinations of RGB / RGBA
    bands, nodata value (absent, None or 0) and a mask written with
    ``write_mask``. The same `tmpdir` is returned so tests can open the
    files by name.
    """
    base = {
        'transform': Affine(5.0, 0.0, 0.0, 0.0, -5.0, 0.0),
        'crs': CRS({'init': 'epsg:4326'}),
        'driver': 'GTiff',
        'dtype': 'uint8',
        'height': 3,
        'width': 3}

    rgb = (red, grn, blu)
    rgba = (red, grn, blu, alp)

    def path(name):
        return str(tmpdir.join(name))

    # 1. RGB without nodata value
    _write_tiff(path('rgb_no_ndv.tif'), dict(base, count=3, nodata=None), rgb)

    # 2. RGB with nodata value
    _write_tiff(path('rgb_ndv.tif'), dict(base, count=3, nodata=0), rgb)

    # 3. RGBA without nodata value
    _write_tiff(path('rgba_no_ndv.tif'), dict(base, count=4, nodata=None), rgba)

    # 4. RGBA with nodata value
    _write_tiff(path('rgba_ndv.tif'), dict(base, count=4, nodata=0), rgba)

    # 5. RGB with msk (no 'nodata' key in the profile at all)
    _write_tiff(path('rgb_msk.tif'), dict(base, count=3), rgb, mask=msk)

    # 6. RGB with msk (internal): same as 5, but written while
    # GDAL_TIFF_INTERNAL_MASK=True is active in the rasterio Env.
    with rasterio.Env(GDAL_TIFF_INTERNAL_MASK=True):
        _write_tiff(path('rgb_msk_internal.tif'), dict(base, count=3), rgb,
                    mask=msk)

    # 7. RGBA with msk
    _write_tiff(path('rgba_msk.tif'), dict(base, count=4), rgba, mask=msk)

    return tmpdir
def test_no_ndv(tiffs):
    """RGB file written with nodata=None: the dataset mask equals `alldata`."""
    path = str(tiffs.join('rgb_no_ndv.tif'))
    with rasterio.open(path) as src:
        mask = src.dataset_mask()
    assert np.array_equal(mask, alldata)
def test_rgb_ndv(tiffs):
    """RGB file written with nodata=0: the dataset mask equals `alp`."""
    path = str(tiffs.join('rgb_ndv.tif'))
    with rasterio.open(path) as src:
        mask = src.dataset_mask()
    assert np.array_equal(mask, alp)
def test_rgba_no_ndv(tiffs):
    """RGBA file written with nodata=None: the dataset mask equals `alp`."""
    path = str(tiffs.join('rgba_no_ndv.tif'))
    with rasterio.open(path) as src:
        mask = src.dataset_mask()
    assert np.array_equal(mask, alp)
def test_rgba_ndv(tiffs):
    """RGBA file written with nodata=0.

    Reading the dataset mask must emit NodataShadowWarning, and the
    result must still equal `alp`.
    """
    path = str(tiffs.join('rgba_ndv.tif'))
    with rasterio.open(path) as src:
        with pytest.warns(NodataShadowWarning):
            mask = src.dataset_mask()
    assert np.array_equal(mask, alp)
def test_rgb_msk(tiffs):
    """RGB file with a written mask: dataset mask and band masks equal `msk`."""
    path = str(tiffs.join('rgb_msk.tif'))
    with rasterio.open(path) as src:
        dataset_mask = src.dataset_mask()
        band_masks = src.read_masks()
    assert np.array_equal(dataset_mask, msk)
    # each band's mask is also equal
    assert all(np.array_equal(band_mask, msk) for band_mask in band_masks)
def test_rgb_msk_int(tiffs):
    """RGB file whose mask was written under GDAL_TIFF_INTERNAL_MASK=True."""
    path = str(tiffs.join('rgb_msk_internal.tif'))
    with rasterio.open(path) as src:
        mask = src.dataset_mask()
    assert np.array_equal(mask, msk)
def test_rgba_msk(tiffs):
    """RGBA file with a written mask: the dataset mask equals `msk`."""
    path = str(tiffs.join('rgba_msk.tif'))
    with rasterio.open(path) as src:
        mask = src.dataset_mask()
    # mask takes precedence over alpha
    assert np.array_equal(mask, msk)
def test_kwargs(tiffs):
    """Check which keyword arguments `dataset_mask` accepts and rejects.

    `window`/`boundless`, `out_shape` and `out` must be accepted and
    produce the expected masks; `indexes` must raise TypeError.
    """
    path = str(tiffs.join('rgb_ndv.tif'))
    with rasterio.open(path) as src:
        # window and boundless are passed along
        windowed = src.dataset_mask(
            window=((1, 4), (1, 4)), boundless=True)
        assert np.array_equal(alp_shift_lr, windowed)

        # reading into a requested output shape
        reshaped = src.dataset_mask(out_shape=(1, 5, 5))
        assert np.array_equal(resampmask, reshaped)

        # reading into a caller-supplied output array
        out = np.zeros((1, 5, 5), dtype=np.uint8)
        filled = src.dataset_mask(out=out)
        assert np.array_equal(resampmask, filled)

        # band indexes are not supported
        with pytest.raises(TypeError):
            src.dataset_mask(indexes=1)
@requires_gdal2(reason="GDAL 2+ required for resampling")
def test_kwargs_resampling(tiffs):
    """`resampling` is accepted: a bilinear 5x5 read matches `resampave`."""
    path = str(tiffs.join('rgb_ndv.tif'))
    with rasterio.open(path) as src:
        resampled = src.dataset_mask(
            out_shape=(1, 5, 5), resampling=Resampling.bilinear)
    # Map every non-zero resampled value to 255 before comparing.
    binary = (resampled != 0).astype(np.uint8) * 255
    assert np.array_equal(resampave, binary)