Basic read and write of GeoTIFFs with plenty of tests.

And an example script, which also appears in the readme.
This commit is contained in:
Sean Gillies 2013-11-06 22:56:44 -07:00
parent a0c962b6d4
commit 5bcda751d0
11 changed files with 688 additions and 40 deletions

View File

@ -14,19 +14,30 @@ Example
Here's an example of the features fasterio aims to provide.
import numpy
import rasterio
# Read raster bands directly into provided Numpy arrays.
with rasterio.open('reflectance.tif') as src:
vis = src.read_band(2, numpy.zeros((src.shape), numpy.float))
nir = src.read_band(3, numpy.zeros((src.shape), numpy.float))
import subprocess
# Read raster bands directly to Numpy arrays.
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src:
r = src.read_band(0).astype(rasterio.float32)
g = src.read_band(1).astype(rasterio.float32)
b = src.read_band(2).astype(rasterio.float32)
# Combine arrays using the 'add' ufunc and then convert back to btyes.
total = (r + g + b)/3.0
total = total.astype(rasterio.ubyte)
ndvi = (nir-vis)/(nir+vis)
# Write the product as a raster band to a new file.
with rasterio.open('ndvi.tif', 'w') as dst:
dst.append_band(ndvi)
with rasterio.open(
'/tmp/total.tif', 'w',
driver='GTiff',
width=src.width, height=src.height, count=1,
crs=src.crs, transform=src.transform,
dtype=total.dtype) as dst:
dst.write_band(0, total)
info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif'])
print(info)
Dependencies
------------

26
examples/total.py Normal file
View File

@ -0,0 +1,26 @@
import rasterio
# Read raster bands directly to Numpy arrays.
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src:
r = src.read_band(0).astype(rasterio.float32)
g = src.read_band(1).astype(rasterio.float32)
b = src.read_band(2).astype(rasterio.float32)
# Combine arrays using the 'add' ufunc and then convert back to btyes.
total = (r + g + b)/3.0
total = total.astype(rasterio.ubyte)
# Write the product as a raster band to a new file.
with rasterio.open(
'/tmp/total.tif', 'w',
driver='GTiff',
width=src.width, height=src.height, count=1,
crs=src.crs, transform=src.transform,
dtype=total.dtype) as dst:
dst.write_band(0, total)
import subprocess
info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif'])
print(info)

View File

@ -4,28 +4,75 @@ import os
from six import string_types
from rasterio._io import RasterReader, RasterUpdateSession
from rasterio._io import RasterReader, RasterUpdater
import rasterio.dtypes
from rasterio.dtypes import (
ubyte, uint8, uint16, int16, uint32, int32, float32, float64)
def open(path, mode='r', driver=None):
"""."""
def open(
path, mode='r',
driver=None,
width=None, height=None,
count=None,
dtype=None,
crs=None, transform=None):
"""Open file at ``path`` in ``mode`` "r" (read), "r+" (read/write),
or "w" (write) and return a ``Reader`` or ``Updater`` object.
In write mode, a driver name such as "GTiff" or "JPEG" (see GDAL
docs or ``gdan_translate --help`` on the command line), ``width``
(number of pixels per line) and ``height`` (number of lines), the
``count`` number of bands in the new file must be specified.
Additionally, the data type for bands such as ``rasterio.ubyte`` for
8-bit bands or ``rasterio.uint16`` for 16-bit bands must be
specified using the ``dtype`` argument.
A coordinate reference system for raster datasets in write mode can
be defined by the ``crs`` argument. It takes Proj4 style mappings
like
{'proj': 'longlat', 'ellps': 'WGS84', 'datum': 'WGS84',
'no_defs': True}
A geo-transform matrix that maps pixel coordinates to coordinates in
the specified crs should be specified using the ``transform``
argument. This matrix is represented by a six-element sequence.
Item 0: the top left x value
Item 1: W-E pixel resolution
Item 2: rotation, 0 if the image is "north up"
Item 3: top left y value
Item 4: rotation, 0 if the image is "north up"
Item 5: N-S pixel resolution (usually a negative number)
"""
if not isinstance(path, string_types):
raise TypeError("invalid path: %r" % path)
if mode and not isinstance(mode, string_types):
raise TypeError("invalid mode: %r" % mode)
if driver and not isinstance(driver, string_types):
raise TypeError("invalid driver: %r" % driver)
if mode in ('a', 'r'):
if mode in ('r', 'r+'):
if not os.path.exists(path):
raise IOError("no such file or directory: %r" % path)
if mode == 'a':
s = RasterUpdateSession(path, mode, driver=None)
elif mode == 'r':
if mode == 'r':
s = RasterReader(path)
elif mode == 'r+':
raise NotImplemented("r+ mode not implemented")
# s = RasterUpdater(path, mode, driver=None)
elif mode == 'w':
s = RasterUpdateSession(path, mode, driver=driver)
s = RasterUpdater(
path, mode, driver,
width, height, count,
crs, transform, dtype)
else:
raise ValueError(
"mode string must be one of 'r', 'w', or 'a', not %s" % mode)
"mode string must be one of 'r', 'r+', or 'w', not %s" % mode)
s.start()
return s
def check_dtype(dt):
tp = getattr(dt, 'type', dt)
return tp in rasterio.dtypes.dtype_rev

View File

@ -1,14 +1,46 @@
# GDAL function definitions.
#
cdef extern from "cpl_conv.h":
void CPLFree (void *ptr)
void CPLSetThreadLocalConfigOption (char *key, char *val)
cdef extern from "cpl_string.h":
char ** CSLSetNameValue (char **list, char *name, char *value)
void CSLDestroy (char **list)
cdef extern from "ogr_srs_api.h":
void OSRCleanup ()
void * OSRClone (void *srs)
void OSRDestroySpatialReference (void *srs)
int OSRExportToProj4 (void *srs, char **params)
int OSRExportToWkt (void *srs, char **params)
int OSRImportFromProj4 (void *srs, char *proj)
void * OSRNewSpatialReference (char *wkt)
void OSRRelease (void *srs)
cdef extern from "gdal.h":
void GDALAllRegister()
void * GDALGetDriverByName(const char *name)
void * GDALOpen(const char *filename, int access)
void GDALClose(void *ds)
void * GDALGetDatasetDriver(void *ds)
int GDALGetGeoTransform (void *ds, double *transform)
const char * GDALGetProjectionRef(void *ds)
int GDALGetRasterXSize(void *ds)
int GDALGetRasterYSize(void *ds)
int GDALGetRasterCount(void *ds)
void * GDALGetRasterBand(void *ds, int num)
int GDALSetGeoTransform (void *ds, double *transform)
int GDALSetProjection(void *ds, const char *wkt)
int GDALGetRasterDataType(void *band)
int GDALRasterIO(void *band, int access, int xoff, int yoff, int xsize, int ysize, void *buffer, int width, int height, int data_type, int poff, int loff)
void * GDALCreateCopy(void *driver, const char *filename, void *ds, int strict, char **options, void *progress_func, void *progress_data)
void * GDALGetDriverByName(const char *name)
void * GDALCreate(void *driver, const char *filename, int width, int height, int nbands, int dtype, const char **options)
void * GDALCreateCopy(void *driver, const char *filename, void *ds, int strict, char **options, void *progress_func, void *progress_data)
const char * GDALGetDriverShortName(void *driver)
const char * GDALGetDriverLongName(void *driver)

View File

@ -1,10 +1,25 @@
import logging
import os
import os.path
import numpy as np
cimport numpy as np
ctypedef np.uint8_t DTYPE_UBYTE_t
ctypedef np.uint16_t DTYPE_UINT16_t
ctypedef np.int16_t DTYPE_INT16_t
ctypedef np.uint32_t DTYPE_UINT32_t
ctypedef np.int32_t DTYPE_INT32_t
ctypedef np.float32_t DTYPE_FLOAT32_t
ctypedef np.float64_t DTYPE_FLOAT64_t
from rasterio cimport _gdal
from rasterio import dtypes
log = logging.getLogger('rasterio')
class NullHandler(logging.Handler):
def emit(self, record):
pass
log.addHandler(NullHandler())
cdef int registered = 0
@ -12,21 +27,100 @@ cdef void register():
_gdal.GDALAllRegister()
registered = 1
cdef int io_ubyte(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_UBYTE_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0, 0], width, height, 1, 0, 0)
cdef int io_uint16(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_UINT16_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0,0], width, height, 2, 0, 0)
cdef int io_int16(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_INT16_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0,0], width, height, 3, 0, 0)
cdef int io_uint32(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_UINT32_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0,0], width, height, 4, 0, 0)
cdef int io_int32(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_INT32_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0,0], width, height, 5, 0, 0)
cdef int io_float32(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_FLOAT32_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0,0], width, height, 6, 0, 0)
cdef int io_float64(
void *hband,
int mode,
int width,
int height,
np.ndarray[DTYPE_FLOAT64_t, ndim=2, mode='c'] buffer):
return _gdal.GDALRasterIO(
hband, mode, 0, 0, width, height,
&buffer[0,0], width, height, 7, 0, 0)
cdef class RasterReader:
# Read-only access to raster data and metadata.
cdef void *_hds
cdef int _count
cdef readonly object name
cdef readonly object mode
cdef readonly object width, height
cdef readonly object shape
cdef public object driver
cdef public object _dtypes
cdef public object _closed
cdef public object _crs
cdef public object _transform
def __cinit__(self, path):
def __init__(self, path):
self.name = path
self.mode = 'r'
self._hds = NULL
self._count = 0
self._closed = True
self._dtypes = []
def __dealloc__(self):
self.stop()
@ -42,12 +136,75 @@ cdef class RasterReader:
register()
cdef const char *fname = self.name
self._hds = _gdal.GDALOpen(fname, 0)
if not self._hds:
raise ValueError("Null dataset")
cdef void *drv
cdef const char *drv_name
drv = _gdal.GDALGetDatasetDriver(self._hds)
drv_name = _gdal.GDALGetDriverShortName(drv)
self.driver = drv_name
self._count = _gdal.GDALGetRasterCount(self._hds)
self.width = _gdal.GDALGetRasterXSize(self._hds)
self.height = _gdal.GDALGetRasterYSize(self._hds)
self.shape = (self.height, self.width)
self._transform = self.read_transform()
self._crs = self.read_crs()
self._closed = False
def read_crs(self):
cdef char *proj_c = NULL
if self._hds is NULL:
raise ValueError("Null dataset")
#cdef const char *wkt = _gdal.GDALGetProjectionRef(self._hds)
cdef void *osr = _gdal.OSRNewSpatialReference(
_gdal.GDALGetProjectionRef(self._hds))
log.debug("Got coordinate system")
crs = {}
if osr is not NULL:
_gdal.OSRExportToProj4(osr, &proj_c)
if proj_c is NULL:
raise ValueError("Null projection")
proj_b = proj_c
log.debug("Params: %s", proj_b)
value = proj_b.decode()
value = value.strip()
for param in value.split():
kv = param.split("=")
if len(kv) == 2:
k, v = kv
try:
v = float(v)
if v % 1 == 0:
v = int(v)
except ValueError:
# Leave v as a string
pass
elif len(kv) == 1:
k, v = kv[0], True
else:
raise ValueError("Unexpected proj parameter %s" % param)
k = k.lstrip("+")
crs[k] = v
_gdal.CPLFree(proj_c)
_gdal.OSRDestroySpatialReference(osr)
else:
log.debug("Projection not found (cogr_crs was NULL)")
return crs
def read_transform(self):
if self._hds is NULL:
raise ValueError("Null dataset")
cdef double gt[6]
_gdal.GDALGetGeoTransform(self._hds, gt)
transform = [0]*6
for i in range(6):
transform[i] = gt[i]
return transform
def stop(self):
if self._hds is not NULL:
_gdal.GDALClose(self._hds)
@ -57,35 +214,224 @@ cdef class RasterReader:
self.stop()
self._closed = True
@property
def count(self):
if self._hds is not NULL:
self._count = _gdal.GDALGetRasterCount(self._hds)
return self._count
@property
def closed(self):
return self._closed
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
@property
def closed(self):
return self._closed
@property
def count(self):
if not self._count:
if not self._hds:
raise ValueError("Can't read closed raster file")
self._count = _gdal.GDALGetRasterCount(self._hds)
return self._count
@property
def dtypes(self):
"""Returns an ordered list of all band data types."""
cdef void *hband = NULL
if not self._dtypes:
if not self._hds:
raise ValueError("Can't read closed raster file")
for i in range(self._count):
hband = _gdal.GDALGetRasterBand(self._hds, i+1)
self._dtypes.append(
dtypes.dtype_fwd[_gdal.GDALGetRasterDataType(hband)])
return self._dtypes
def get_crs(self):
if not self._crs:
self._crs = self.read_crs()
return self._crs
crs = property(get_crs)
def get_transform(self):
if not self._transform:
self._transform = self.read_transform()
return self._transform
transform = property(get_transform)
def read_band(self, i, out=None):
"""Read the ith band into an `out` array if provided, otherwise
return a new array."""
cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1)
if not self._hds:
raise ValueError("Can't read closed raster file")
if out is not None and out.dtype != self.dtypes[i]:
raise ValueError("Band and output array dtypes do not match")
if out is not None and out.shape != self.shape:
raise ValueError("Band and output shapes do not match")
dtype = self.dtypes[i]
if out is None:
out = np.zeros(self.shape, np.ubyte)
cdef np.ndarray[DTYPE_UBYTE_t, ndim=2, mode="c"] im = out
_gdal.GDALRasterIO(
hband, 0, 0, 0, self.width, self.height,
&im[0, 0], self.width, self.height, 1, 0, 0)
out = np.zeros(self.shape, dtype)
cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1)
if dtype == dtypes.ubyte:
retval = io_ubyte(hband, 0, self.width, self.height, out)
elif dtype == dtypes.uint16:
retval = io_uint16(hband, 0, self.width, self.height, out)
elif dtype == dtypes.int16:
retval = io_int16(hband, 0, self.width, self.height, out)
elif dtype == dtypes.uint32:
retval = io_uint32(hband, 0, self.width, self.height, out)
elif dtype == dtypes.int32:
retval = io_int32(hband, 0, self.width, self.height, out)
elif dtype == dtypes.float32:
retval = io_float32(hband, 0, self.width, self.height, out)
elif dtype == dtypes.float64:
retval = io_float64(hband, 0, self.width, self.height, out)
else:
raise ValueError("Invalid dtype")
# TODO: handle errors (by retval).
return out
cdef class RasterUpdateSession:
pass
cdef class RasterUpdater(RasterReader):
# Read-write access to raster data and metadata.
# TODO: the r+ mode.
cdef readonly object _init_dtype
def __init__(
self, path, mode, driver=None,
width=None, height=None, count=None,
crs=None, transform=None, dtype=None):
self.name = path
self.mode = mode
self.driver = driver
self.width = width
self.height = height
self._count = count
self._init_dtype = dtype
self._hds = NULL
self._count = count
self._crs = crs
self._transform = transform
self._closed = True
self._dtypes = []
def __repr__(self):
return "<%s RasterUpdater '%s' at %s>" % (
self.closed and 'closed' or 'open',
self.name,
hex(id(self)))
def start(self):
cdef const char *drv_name
cdef void *drv
if not registered:
register()
cdef const char *fname = self.name
if self.mode == 'w':
# Delete existing file, create.
if os.path.exists(self.name):
os.unlink(self.name)
drv_name = self.driver
drv = _gdal.GDALGetDriverByName(drv_name)
# Find the equivalent GDAL data type or raise an exception
# We've mapped numpy scalar types to GDAL types so see
# if we can crosswalk those.
if hasattr(self._init_dtype, 'type'):
tp = self._init_dtype.type
if tp not in dtypes.dtype_rev:
raise ValueError(
"Unsupported dtype: %s" % self._init_dtype)
else:
gdal_dtype = dtypes.dtype_rev.get(tp)
else:
gdal_dtype = dtypes.dtype_rev.get(self._init_dtype)
self._hds = _gdal.GDALCreate(
drv, fname, self.width, self.height, self._count,
gdal_dtype,
NULL)
if self._transform:
self.write_transform(self._transform)
if self._crs:
self.write_crs(self._crs)
elif self.mode == 'a':
self._hds = _gdal.GDALOpen(fname, 1)
self._count = _gdal.GDALGetRasterCount(self._hds)
self.width = _gdal.GDALGetRasterXSize(self._hds)
self.height = _gdal.GDALGetRasterYSize(self._hds)
self.shape = (self.height, self.width)
self._closed = False
def get_crs(self):
if not self._crs:
self._crs = self.read_crs()
return self._crs
def write_crs(self, crs):
if self._hds is NULL:
raise ValueError("Can't read closed raster file")
cdef void *osr = _gdal.OSRNewSpatialReference(NULL)
if osr is NULL:
raise ValueError("Null spatial reference")
params = []
for k, v in crs.items():
if v is True or (k == 'no_defs' and v):
params.append("+%s" % k)
else:
params.append("+%s=%s" % (k, v))
proj = " ".join(params)
proj_b = proj.encode()
cdef const char *proj_c = proj_b
_gdal.OSRImportFromProj4(osr, proj_c)
cdef char *wkt
_gdal.OSRExportToWkt(osr, &wkt)
_gdal.GDALSetProjection(self._hds, wkt)
_gdal.CPLFree(wkt)
_gdal.OSRDestroySpatialReference(osr)
self._crs = crs
crs = property(get_crs, write_crs)
def write_transform(self, transform):
if self._hds is NULL:
raise ValueError("Can't read closed raster file")
cdef double gt[6]
for i in range(6):
gt[i] = transform[i]
retval = _gdal.GDALSetGeoTransform(self._hds, gt)
self._transform = transform
def get_transform(self):
if not self._transform:
self._transform = self.read_transform()
return self._transform
transform = property(get_transform, write_transform)
def write_band(self, i, src):
"""Write the src array into the ith band."""
if not self._hds:
raise ValueError("Can't read closed raster file")
if src is not None and src.dtype != self.dtypes[i]:
raise ValueError("Band and srcput array dtypes do not match")
if src is not None and src.shape != self.shape:
raise ValueError("Band and srcput shapes do not match")
dtype = self.dtypes[i]
cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1)
if dtype == dtypes.ubyte:
retval = io_ubyte(hband, 1, self.width, self.height, src)
elif dtype == dtypes.uint16:
retval = io_uint16(hband, 1, self.width, self.height, src)
elif dtype == dtypes.int16:
retval = io_int16(hband, 1, self.width, self.height, src)
elif dtype == dtypes.uint32:
retval = io_uint32(hband, 1, self.width, self.height, src)
elif dtype == dtypes.int32:
retval = io_int32(hband, 1, self.width, self.height, src)
elif dtype == dtypes.float32:
retval = io_float32(hband, 1, self.width, self.height, src)
elif dtype == dtypes.float64:
retval = io_float64(hband, 1, self.width, self.height, src)
else:
raise ValueError("Invalid dtype")
# TODO: handle errors (by retval).

27
rasterio/dtypes.py Normal file
View File

@ -0,0 +1,27 @@
import numpy
ubyte = uint8 = numpy.uint8
uint16 = numpy.uint16
int16 = numpy.int16
uint32 = numpy.uint32
int32 = numpy.int32
float32 = numpy.float32
float64 = numpy.float64
# Not supported:
# GDT_CInt16 = 8, GDT_CInt32 = 9, GDT_CFloat32 = 10, GDT_CFloat64 = 11
dtype_fwd = {
0: None, # GDT_Unknown
1: ubyte, # GDT_Byte
2: uint16, # GDT_UInt16
3: int16, # GDT_Int16
4: uint32, # GDT_UInt32
5: int32, # GDT_Int32
6: float32, # GDT_Float32
7: float64 } # GDT_Float64
dtype_rev = {v: k for k, v in dtype_fwd.items()}
dtype_rev[uint8] = 1

View File

@ -0,0 +1 @@
#

View File

@ -0,0 +1,26 @@
<PAMDataset>
<PAMRasterBand band="1">
<Metadata>
<MDI key="STATISTICS_MAXIMUM">255</MDI>
<MDI key="STATISTICS_MEAN">29.947726688477</MDI>
<MDI key="STATISTICS_MINIMUM">0</MDI>
<MDI key="STATISTICS_STDDEV">52.340921626611</MDI>
</Metadata>
</PAMRasterBand>
<PAMRasterBand band="2">
<Metadata>
<MDI key="STATISTICS_MAXIMUM">255</MDI>
<MDI key="STATISTICS_MEAN">44.516147889382</MDI>
<MDI key="STATISTICS_MINIMUM">0</MDI>
<MDI key="STATISTICS_STDDEV">56.934318291876</MDI>
</Metadata>
</PAMRasterBand>
<PAMRasterBand band="3">
<Metadata>
<MDI key="STATISTICS_MAXIMUM">255</MDI>
<MDI key="STATISTICS_MEAN">48.113056354743</MDI>
<MDI key="STATISTICS_MINIMUM">0</MDI>
<MDI key="STATISTICS_STDDEV">60.112778509941</MDI>
</Metadata>
</PAMRasterBand>
</PAMDataset>

View File

@ -0,0 +1,9 @@
import numpy
import rasterio
def test_np_dt_uint8():
assert rasterio.check_dtype(numpy.dtype(numpy.uint8))
def test_dt_ubyte():
assert rasterio.check_dtype(numpy.dtype(rasterio.ubyte))

View File

@ -1,16 +1,26 @@
import unittest
import numpy
import rasterio
class ReaderContextTest(unittest.TestCase):
def test_context(self):
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
self.assertEqual(s.name, 'rasterio/tests/data/RGB.byte.tif')
self.assertEqual(s.driver, 'GTiff')
self.assertEqual(s.closed, False)
self.assertEqual(s.count, 3)
self.assertEqual(s.width, 791)
self.assertEqual(s.height, 718)
self.assertEqual(s.shape, (718, 791))
self.assertEqual(s.dtypes, [rasterio.ubyte]*3)
self.assertEqual(s.crs['proj'], 'utm')
self.assertEqual(s.crs['zone'], 18)
self.assertEqual(
s.transform,
[101985.0, 300.0379266750948, 0.0,
2826915.0, 0.0, -300.041782729805])
self.assertEqual(
repr(s),
"<open RasterReader 'rasterio/tests/data/RGB.byte.tif' "
@ -20,8 +30,32 @@ class ReaderContextTest(unittest.TestCase):
self.assertEqual(s.width, 791)
self.assertEqual(s.height, 718)
self.assertEqual(s.shape, (718, 791))
self.assertEqual(s.dtypes, [rasterio.ubyte]*3)
self.assertEqual(s.crs['proj'], 'utm')
self.assertEqual(s.crs['zone'], 18)
self.assertEqual(
s.transform,
[101985.0, 300.0379266750948, 0.0,
2826915.0, 0.0, -300.041782729805])
self.assertEqual(
repr(s),
"<closed RasterReader 'rasterio/tests/data/RGB.byte.tif' "
"at %s>" % hex(id(s)))
def test_read_ubyte(self):
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
a = s.read_band(0)
self.assertEqual(a.dtype, rasterio.ubyte)
def test_read_ubyte_out(self):
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
a = numpy.zeros((718, 791), dtype=rasterio.ubyte)
a = s.read_band(0, a)
self.assertEqual(a.dtype, rasterio.ubyte)
def test_read_out_dtype_fail(self):
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
a = numpy.zeros((718, 791), dtype=rasterio.float32)
self.assertRaises(ValueError, s.read_band, 0, a)
def test_read_out_shape_fail(self):
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
a = numpy.zeros((42, 42), dtype=rasterio.ubyte)
self.assertRaises(ValueError, s.read_band, 0, a)

View File

@ -0,0 +1,89 @@
import os.path
import unittest
import shutil
import subprocess
import tempfile
import numpy
import rasterio
class WriterContextTest(unittest.TestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tempdir)
def test_context(self):
name = os.path.join(self.tempdir, "test_context.tif")
with rasterio.open(
name, 'w',
driver='GTiff', width=100, height=100, count=1,
dtype=rasterio.ubyte) as s:
self.assertEqual(s.name, name)
self.assertEqual(s.driver, 'GTiff')
self.assertEqual(s.closed, False)
self.assertEqual(s.count, 1)
self.assertEqual(s.width, 100)
self.assertEqual(s.height, 100)
self.assertEqual(s.shape, (100, 100))
self.assertEqual(
repr(s),
"<open RasterUpdater '%s' at %s>" % (name, hex(id(s))))
self.assertEqual(s.closed, True)
self.assertEqual(s.count, 1)
self.assertEqual(s.width, 100)
self.assertEqual(s.height, 100)
self.assertEqual(s.shape, (100, 100))
self.assertEqual(
repr(s),
"<closed RasterUpdater '%s' at %s>" % (name, hex(id(s))))
info = subprocess.check_output(["gdalinfo", name])
self.assert_("GTiff" in info)
self.assert_(
"Size is 100, 100" in info)
self.assert_(
"Band 1 Block=100x81 Type=Byte, ColorInterp=Gray" in info)
def test_write_ubyte(self):
name = os.path.join(self.tempdir, "test_write_ubyte.tif")
a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127
with rasterio.open(
name, 'w',
driver='GTiff', width=100, height=100, count=1,
dtype=a.dtype) as s:
s.write_band(0, a)
info = subprocess.check_output(["gdalinfo", "-stats", name])
self.assert_(
"Minimum=127.000, Maximum=127.000, "
"Mean=127.000, StdDev=0.000" in info,
info)
def test_write_float(self):
name = os.path.join(self.tempdir, "test_write_float.tif")
a = numpy.ones((100, 100), dtype=rasterio.float32) * 42.0
with rasterio.open(
name, 'w',
driver='GTiff', width=100, height=100, count=2,
dtype=rasterio.float32) as s:
self.assertEqual(s.dtypes, [rasterio.float32]*2)
s.write_band(0, a)
s.write_band(1, a)
info = subprocess.check_output(["gdalinfo", "-stats", name])
self.assert_(
"Minimum=42.000, Maximum=42.000, "
"Mean=42.000, StdDev=0.000" in info,
info)
def test_write_crs_transform(self):
name = os.path.join(self.tempdir, "test_write_crs_transform.tif")
a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127
with rasterio.open(
name, 'w',
driver='GTiff', width=100, height=100, count=1,
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
'proj': 'utm', 'zone': 18},
transform=[101985.0, 300.0379266750948, 0.0,
2826915.0, 0.0, -300.041782729805],
dtype=rasterio.ubyte) as s:
s.write_band(0, a)
info = subprocess.check_output(["gdalinfo", name])
self.assert_('PROJCS["UTM Zone 18, Northern Hemisphere",' in info)
self.assert_("(300.037926675094809,-300.041782729804993)" in info)