From 5bcda751d0cca19c2916ddcc7cf18e6a87e832c2 Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Wed, 6 Nov 2013 22:56:44 -0700 Subject: [PATCH] Basic read and write of GeoTIFFs with plenty of tests. And an example script, which also appears in the readme. --- README.md | 31 +- examples/total.py | 26 ++ rasterio/__init__.py | 65 +++- rasterio/_gdal.pxd | 36 ++- rasterio/_io.pyx | 384 +++++++++++++++++++++-- rasterio/dtypes.py | 27 ++ rasterio/tests/__init__.py | 1 + rasterio/tests/data/RGB.byte.tif.aux.xml | 26 ++ rasterio/tests/test_dtypes.py | 9 + rasterio/tests/test_read.py | 34 ++ rasterio/tests/test_write.py | 89 ++++++ 11 files changed, 688 insertions(+), 40 deletions(-) create mode 100644 examples/total.py create mode 100644 rasterio/dtypes.py create mode 100644 rasterio/tests/__init__.py create mode 100644 rasterio/tests/data/RGB.byte.tif.aux.xml create mode 100644 rasterio/tests/test_dtypes.py create mode 100644 rasterio/tests/test_write.py diff --git a/README.md b/README.md index 43c7e153..3aefedf7 100644 --- a/README.md +++ b/README.md @@ -14,19 +14,30 @@ Example Here's an example of the features fasterio aims to provide. - import numpy import rasterio - - # Read raster bands directly into provided Numpy arrays. - with rasterio.open('reflectance.tif') as src: - vis = src.read_band(2, numpy.zeros((src.shape), numpy.float)) - nir = src.read_band(3, numpy.zeros((src.shape), numpy.float)) + import subprocess + + # Read raster bands directly to Numpy arrays. + with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src: + r = src.read_band(0).astype(rasterio.float32) + g = src.read_band(1).astype(rasterio.float32) + b = src.read_band(2).astype(rasterio.float32) + + # Combine arrays using the 'add' ufunc and then convert back to btyes. + total = (r + g + b)/3.0 + total = total.astype(rasterio.ubyte) - ndvi = (nir-vis)/(nir+vis) - # Write the product as a raster band to a new file. - with rasterio.open('ndvi.tif', 'w') as dst: - dst.append_band(ndvi) + with rasterio.open( + '/tmp/total.tif', 'w', + driver='GTiff', + width=src.width, height=src.height, count=1, + crs=src.crs, transform=src.transform, + dtype=total.dtype) as dst: + dst.write_band(0, total) + + info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif']) + print(info) Dependencies ------------ diff --git a/examples/total.py b/examples/total.py new file mode 100644 index 00000000..39c864a9 --- /dev/null +++ b/examples/total.py @@ -0,0 +1,26 @@ +import rasterio + +# Read raster bands directly to Numpy arrays. +with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src: + r = src.read_band(0).astype(rasterio.float32) + g = src.read_band(1).astype(rasterio.float32) + b = src.read_band(2).astype(rasterio.float32) + +# Combine arrays using the 'add' ufunc and then convert back to btyes. +total = (r + g + b)/3.0 +total = total.astype(rasterio.ubyte) + +# Write the product as a raster band to a new file. +with rasterio.open( + '/tmp/total.tif', 'w', + driver='GTiff', + width=src.width, height=src.height, count=1, + crs=src.crs, transform=src.transform, + dtype=total.dtype) as dst: + dst.write_band(0, total) + +import subprocess +info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif']) +print(info) + + diff --git a/rasterio/__init__.py b/rasterio/__init__.py index e6ec91f1..85f8eda2 100644 --- a/rasterio/__init__.py +++ b/rasterio/__init__.py @@ -4,28 +4,75 @@ import os from six import string_types -from rasterio._io import RasterReader, RasterUpdateSession +from rasterio._io import RasterReader, RasterUpdater +import rasterio.dtypes +from rasterio.dtypes import ( + ubyte, uint8, uint16, int16, uint32, int32, float32, float64) -def open(path, mode='r', driver=None): - """.""" +def open( + path, mode='r', + driver=None, + width=None, height=None, + count=None, + dtype=None, + crs=None, transform=None): + """Open file at ``path`` in ``mode`` "r" (read), "r+" (read/write), + or "w" (write) and return a ``Reader`` or ``Updater`` object. + + In write mode, a driver name such as "GTiff" or "JPEG" (see GDAL + docs or ``gdan_translate --help`` on the command line), ``width`` + (number of pixels per line) and ``height`` (number of lines), the + ``count`` number of bands in the new file must be specified. + Additionally, the data type for bands such as ``rasterio.ubyte`` for + 8-bit bands or ``rasterio.uint16`` for 16-bit bands must be + specified using the ``dtype`` argument. + + A coordinate reference system for raster datasets in write mode can + be defined by the ``crs`` argument. It takes Proj4 style mappings + like + + {'proj': 'longlat', 'ellps': 'WGS84', 'datum': 'WGS84', + 'no_defs': True} + + A geo-transform matrix that maps pixel coordinates to coordinates in + the specified crs should be specified using the ``transform`` + argument. This matrix is represented by a six-element sequence. + + Item 0: the top left x value + Item 1: W-E pixel resolution + Item 2: rotation, 0 if the image is "north up" + Item 3: top left y value + Item 4: rotation, 0 if the image is "north up" + Item 5: N-S pixel resolution (usually a negative number) + """ if not isinstance(path, string_types): raise TypeError("invalid path: %r" % path) if mode and not isinstance(mode, string_types): raise TypeError("invalid mode: %r" % mode) if driver and not isinstance(driver, string_types): raise TypeError("invalid driver: %r" % driver) - if mode in ('a', 'r'): + if mode in ('r', 'r+'): if not os.path.exists(path): raise IOError("no such file or directory: %r" % path) - if mode == 'a': - s = RasterUpdateSession(path, mode, driver=None) - elif mode == 'r': + + if mode == 'r': s = RasterReader(path) + elif mode == 'r+': + raise NotImplemented("r+ mode not implemented") + # s = RasterUpdater(path, mode, driver=None) elif mode == 'w': - s = RasterUpdateSession(path, mode, driver=driver) + s = RasterUpdater( + path, mode, driver, + width, height, count, + crs, transform, dtype) else: raise ValueError( - "mode string must be one of 'r', 'w', or 'a', not %s" % mode) + "mode string must be one of 'r', 'r+', or 'w', not %s" % mode) + s.start() return s +def check_dtype(dt): + tp = getattr(dt, 'type', dt) + return tp in rasterio.dtypes.dtype_rev + diff --git a/rasterio/_gdal.pxd b/rasterio/_gdal.pxd index 5abfb92a..d2ca695d 100644 --- a/rasterio/_gdal.pxd +++ b/rasterio/_gdal.pxd @@ -1,14 +1,46 @@ # GDAL function definitions. # + +cdef extern from "cpl_conv.h": + void CPLFree (void *ptr) + void CPLSetThreadLocalConfigOption (char *key, char *val) + +cdef extern from "cpl_string.h": + char ** CSLSetNameValue (char **list, char *name, char *value) + void CSLDestroy (char **list) + +cdef extern from "ogr_srs_api.h": + void OSRCleanup () + void * OSRClone (void *srs) + void OSRDestroySpatialReference (void *srs) + int OSRExportToProj4 (void *srs, char **params) + int OSRExportToWkt (void *srs, char **params) + int OSRImportFromProj4 (void *srs, char *proj) + void * OSRNewSpatialReference (char *wkt) + void OSRRelease (void *srs) + cdef extern from "gdal.h": void GDALAllRegister() + + void * GDALGetDriverByName(const char *name) void * GDALOpen(const char *filename, int access) + void GDALClose(void *ds) + void * GDALGetDatasetDriver(void *ds) + int GDALGetGeoTransform (void *ds, double *transform) + const char * GDALGetProjectionRef(void *ds) int GDALGetRasterXSize(void *ds) int GDALGetRasterYSize(void *ds) int GDALGetRasterCount(void *ds) void * GDALGetRasterBand(void *ds, int num) + int GDALSetGeoTransform (void *ds, double *transform) + int GDALSetProjection(void *ds, const char *wkt) + + int GDALGetRasterDataType(void *band) int GDALRasterIO(void *band, int access, int xoff, int yoff, int xsize, int ysize, void *buffer, int width, int height, int data_type, int poff, int loff) - void * GDALCreateCopy(void *driver, const char *filename, void *ds, int strict, char **options, void *progress_func, void *progress_data) - void * GDALGetDriverByName(const char *name) + + void * GDALCreate(void *driver, const char *filename, int width, int height, int nbands, int dtype, const char **options) + void * GDALCreateCopy(void *driver, const char *filename, void *ds, int strict, char **options, void *progress_func, void *progress_data) + const char * GDALGetDriverShortName(void *driver) + const char * GDALGetDriverLongName(void *driver) diff --git a/rasterio/_io.pyx b/rasterio/_io.pyx index 547e4db7..444c00bd 100644 --- a/rasterio/_io.pyx +++ b/rasterio/_io.pyx @@ -1,10 +1,25 @@ import logging +import os +import os.path import numpy as np cimport numpy as np ctypedef np.uint8_t DTYPE_UBYTE_t +ctypedef np.uint16_t DTYPE_UINT16_t +ctypedef np.int16_t DTYPE_INT16_t +ctypedef np.uint32_t DTYPE_UINT32_t +ctypedef np.int32_t DTYPE_INT32_t +ctypedef np.float32_t DTYPE_FLOAT32_t +ctypedef np.float64_t DTYPE_FLOAT64_t from rasterio cimport _gdal +from rasterio import dtypes + +log = logging.getLogger('rasterio') +class NullHandler(logging.Handler): + def emit(self, record): + pass +log.addHandler(NullHandler()) cdef int registered = 0 @@ -12,21 +27,100 @@ cdef void register(): _gdal.GDALAllRegister() registered = 1 +cdef int io_ubyte( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_UBYTE_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0, 0], width, height, 1, 0, 0) + +cdef int io_uint16( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_UINT16_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0,0], width, height, 2, 0, 0) + +cdef int io_int16( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_INT16_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0,0], width, height, 3, 0, 0) + +cdef int io_uint32( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_UINT32_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0,0], width, height, 4, 0, 0) + +cdef int io_int32( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_INT32_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0,0], width, height, 5, 0, 0) + +cdef int io_float32( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_FLOAT32_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0,0], width, height, 6, 0, 0) + +cdef int io_float64( + void *hband, + int mode, + int width, + int height, + np.ndarray[DTYPE_FLOAT64_t, ndim=2, mode='c'] buffer): + return _gdal.GDALRasterIO( + hband, mode, 0, 0, width, height, + &buffer[0,0], width, height, 7, 0, 0) + + cdef class RasterReader: + # Read-only access to raster data and metadata. cdef void *_hds cdef int _count cdef readonly object name + cdef readonly object mode cdef readonly object width, height cdef readonly object shape + cdef public object driver + cdef public object _dtypes cdef public object _closed + cdef public object _crs + cdef public object _transform - def __cinit__(self, path): + def __init__(self, path): self.name = path + self.mode = 'r' self._hds = NULL self._count = 0 self._closed = True + self._dtypes = [] def __dealloc__(self): self.stop() @@ -42,12 +136,75 @@ cdef class RasterReader: register() cdef const char *fname = self.name self._hds = _gdal.GDALOpen(fname, 0) + if not self._hds: + raise ValueError("Null dataset") + + cdef void *drv + cdef const char *drv_name + drv = _gdal.GDALGetDatasetDriver(self._hds) + drv_name = _gdal.GDALGetDriverShortName(drv) + self.driver = drv_name + self._count = _gdal.GDALGetRasterCount(self._hds) self.width = _gdal.GDALGetRasterXSize(self._hds) self.height = _gdal.GDALGetRasterYSize(self._hds) self.shape = (self.height, self.width) + + self._transform = self.read_transform() + self._crs = self.read_crs() + self._closed = False + def read_crs(self): + cdef char *proj_c = NULL + if self._hds is NULL: + raise ValueError("Null dataset") + #cdef const char *wkt = _gdal.GDALGetProjectionRef(self._hds) + cdef void *osr = _gdal.OSRNewSpatialReference( + _gdal.GDALGetProjectionRef(self._hds)) + log.debug("Got coordinate system") + crs = {} + if osr is not NULL: + _gdal.OSRExportToProj4(osr, &proj_c) + if proj_c is NULL: + raise ValueError("Null projection") + proj_b = proj_c + log.debug("Params: %s", proj_b) + value = proj_b.decode() + value = value.strip() + for param in value.split(): + kv = param.split("=") + if len(kv) == 2: + k, v = kv + try: + v = float(v) + if v % 1 == 0: + v = int(v) + except ValueError: + # Leave v as a string + pass + elif len(kv) == 1: + k, v = kv[0], True + else: + raise ValueError("Unexpected proj parameter %s" % param) + k = k.lstrip("+") + crs[k] = v + _gdal.CPLFree(proj_c) + _gdal.OSRDestroySpatialReference(osr) + else: + log.debug("Projection not found (cogr_crs was NULL)") + return crs + + def read_transform(self): + if self._hds is NULL: + raise ValueError("Null dataset") + cdef double gt[6] + _gdal.GDALGetGeoTransform(self._hds, gt) + transform = [0]*6 + for i in range(6): + transform[i] = gt[i] + return transform + def stop(self): if self._hds is not NULL: _gdal.GDALClose(self._hds) @@ -57,35 +214,224 @@ cdef class RasterReader: self.stop() self._closed = True - @property - def count(self): - if self._hds is not NULL: - self._count = _gdal.GDALGetRasterCount(self._hds) - return self._count - - @property - def closed(self): - return self._closed - def __enter__(self): return self def __exit__(self, type, value, traceback): self.close() + + @property + def closed(self): + return self._closed + @property + def count(self): + if not self._count: + if not self._hds: + raise ValueError("Can't read closed raster file") + self._count = _gdal.GDALGetRasterCount(self._hds) + return self._count + + @property + def dtypes(self): + """Returns an ordered list of all band data types.""" + cdef void *hband = NULL + if not self._dtypes: + if not self._hds: + raise ValueError("Can't read closed raster file") + for i in range(self._count): + hband = _gdal.GDALGetRasterBand(self._hds, i+1) + self._dtypes.append( + dtypes.dtype_fwd[_gdal.GDALGetRasterDataType(hband)]) + return self._dtypes + + def get_crs(self): + if not self._crs: + self._crs = self.read_crs() + return self._crs + crs = property(get_crs) + + def get_transform(self): + if not self._transform: + self._transform = self.read_transform() + return self._transform + transform = property(get_transform) + def read_band(self, i, out=None): """Read the ith band into an `out` array if provided, otherwise return a new array.""" - cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1) + if not self._hds: + raise ValueError("Can't read closed raster file") + if out is not None and out.dtype != self.dtypes[i]: + raise ValueError("Band and output array dtypes do not match") + if out is not None and out.shape != self.shape: + raise ValueError("Band and output shapes do not match") + dtype = self.dtypes[i] if out is None: - out = np.zeros(self.shape, np.ubyte) - cdef np.ndarray[DTYPE_UBYTE_t, ndim=2, mode="c"] im = out - _gdal.GDALRasterIO( - hband, 0, 0, 0, self.width, self.height, - &im[0, 0], self.width, self.height, 1, 0, 0) + out = np.zeros(self.shape, dtype) + cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1) + if dtype == dtypes.ubyte: + retval = io_ubyte(hband, 0, self.width, self.height, out) + elif dtype == dtypes.uint16: + retval = io_uint16(hband, 0, self.width, self.height, out) + elif dtype == dtypes.int16: + retval = io_int16(hband, 0, self.width, self.height, out) + elif dtype == dtypes.uint32: + retval = io_uint32(hband, 0, self.width, self.height, out) + elif dtype == dtypes.int32: + retval = io_int32(hband, 0, self.width, self.height, out) + elif dtype == dtypes.float32: + retval = io_float32(hband, 0, self.width, self.height, out) + elif dtype == dtypes.float64: + retval = io_float64(hband, 0, self.width, self.height, out) + else: + raise ValueError("Invalid dtype") + # TODO: handle errors (by retval). return out -cdef class RasterUpdateSession: - pass +cdef class RasterUpdater(RasterReader): + # Read-write access to raster data and metadata. + # TODO: the r+ mode. + cdef readonly object _init_dtype + + def __init__( + self, path, mode, driver=None, + width=None, height=None, count=None, + crs=None, transform=None, dtype=None): + self.name = path + self.mode = mode + self.driver = driver + self.width = width + self.height = height + self._count = count + self._init_dtype = dtype + self._hds = NULL + self._count = count + self._crs = crs + self._transform = transform + self._closed = True + self._dtypes = [] + + def __repr__(self): + return "<%s RasterUpdater '%s' at %s>" % ( + self.closed and 'closed' or 'open', + self.name, + hex(id(self))) + + def start(self): + cdef const char *drv_name + cdef void *drv + if not registered: + register() + cdef const char *fname = self.name + if self.mode == 'w': + # Delete existing file, create. + if os.path.exists(self.name): + os.unlink(self.name) + drv_name = self.driver + drv = _gdal.GDALGetDriverByName(drv_name) + + # Find the equivalent GDAL data type or raise an exception + # We've mapped numpy scalar types to GDAL types so see + # if we can crosswalk those. + if hasattr(self._init_dtype, 'type'): + tp = self._init_dtype.type + if tp not in dtypes.dtype_rev: + raise ValueError( + "Unsupported dtype: %s" % self._init_dtype) + else: + gdal_dtype = dtypes.dtype_rev.get(tp) + else: + gdal_dtype = dtypes.dtype_rev.get(self._init_dtype) + self._hds = _gdal.GDALCreate( + drv, fname, self.width, self.height, self._count, + gdal_dtype, + NULL) + if self._transform: + self.write_transform(self._transform) + if self._crs: + self.write_crs(self._crs) + elif self.mode == 'a': + self._hds = _gdal.GDALOpen(fname, 1) + self._count = _gdal.GDALGetRasterCount(self._hds) + self.width = _gdal.GDALGetRasterXSize(self._hds) + self.height = _gdal.GDALGetRasterYSize(self._hds) + self.shape = (self.height, self.width) + self._closed = False + + def get_crs(self): + if not self._crs: + self._crs = self.read_crs() + return self._crs + + def write_crs(self, crs): + if self._hds is NULL: + raise ValueError("Can't read closed raster file") + cdef void *osr = _gdal.OSRNewSpatialReference(NULL) + if osr is NULL: + raise ValueError("Null spatial reference") + params = [] + for k, v in crs.items(): + if v is True or (k == 'no_defs' and v): + params.append("+%s" % k) + else: + params.append("+%s=%s" % (k, v)) + proj = " ".join(params) + proj_b = proj.encode() + cdef const char *proj_c = proj_b + _gdal.OSRImportFromProj4(osr, proj_c) + cdef char *wkt + _gdal.OSRExportToWkt(osr, &wkt) + _gdal.GDALSetProjection(self._hds, wkt) + _gdal.CPLFree(wkt) + _gdal.OSRDestroySpatialReference(osr) + + self._crs = crs + + crs = property(get_crs, write_crs) + + def write_transform(self, transform): + if self._hds is NULL: + raise ValueError("Can't read closed raster file") + cdef double gt[6] + for i in range(6): + gt[i] = transform[i] + retval = _gdal.GDALSetGeoTransform(self._hds, gt) + self._transform = transform + + def get_transform(self): + if not self._transform: + self._transform = self.read_transform() + return self._transform + + transform = property(get_transform, write_transform) + + def write_band(self, i, src): + """Write the src array into the ith band.""" + if not self._hds: + raise ValueError("Can't read closed raster file") + if src is not None and src.dtype != self.dtypes[i]: + raise ValueError("Band and srcput array dtypes do not match") + if src is not None and src.shape != self.shape: + raise ValueError("Band and srcput shapes do not match") + dtype = self.dtypes[i] + cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1) + if dtype == dtypes.ubyte: + retval = io_ubyte(hband, 1, self.width, self.height, src) + elif dtype == dtypes.uint16: + retval = io_uint16(hband, 1, self.width, self.height, src) + elif dtype == dtypes.int16: + retval = io_int16(hband, 1, self.width, self.height, src) + elif dtype == dtypes.uint32: + retval = io_uint32(hband, 1, self.width, self.height, src) + elif dtype == dtypes.int32: + retval = io_int32(hband, 1, self.width, self.height, src) + elif dtype == dtypes.float32: + retval = io_float32(hband, 1, self.width, self.height, src) + elif dtype == dtypes.float64: + retval = io_float64(hband, 1, self.width, self.height, src) + else: + raise ValueError("Invalid dtype") + # TODO: handle errors (by retval). diff --git a/rasterio/dtypes.py b/rasterio/dtypes.py new file mode 100644 index 00000000..8a0fa475 --- /dev/null +++ b/rasterio/dtypes.py @@ -0,0 +1,27 @@ + +import numpy + +ubyte = uint8 = numpy.uint8 +uint16 = numpy.uint16 +int16 = numpy.int16 +uint32 = numpy.uint32 +int32 = numpy.int32 +float32 = numpy.float32 +float64 = numpy.float64 + +# Not supported: +# GDT_CInt16 = 8, GDT_CInt32 = 9, GDT_CFloat32 = 10, GDT_CFloat64 = 11 + +dtype_fwd = { + 0: None, # GDT_Unknown + 1: ubyte, # GDT_Byte + 2: uint16, # GDT_UInt16 + 3: int16, # GDT_Int16 + 4: uint32, # GDT_UInt32 + 5: int32, # GDT_Int32 + 6: float32, # GDT_Float32 + 7: float64 } # GDT_Float64 + +dtype_rev = {v: k for k, v in dtype_fwd.items()} +dtype_rev[uint8] = 1 + diff --git a/rasterio/tests/__init__.py b/rasterio/tests/__init__.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/rasterio/tests/__init__.py @@ -0,0 +1 @@ +# diff --git a/rasterio/tests/data/RGB.byte.tif.aux.xml b/rasterio/tests/data/RGB.byte.tif.aux.xml new file mode 100644 index 00000000..30692ee1 --- /dev/null +++ b/rasterio/tests/data/RGB.byte.tif.aux.xml @@ -0,0 +1,26 @@ + + + + 255 + 29.947726688477 + 0 + 52.340921626611 + + + + + 255 + 44.516147889382 + 0 + 56.934318291876 + + + + + 255 + 48.113056354743 + 0 + 60.112778509941 + + + diff --git a/rasterio/tests/test_dtypes.py b/rasterio/tests/test_dtypes.py new file mode 100644 index 00000000..87b94ee0 --- /dev/null +++ b/rasterio/tests/test_dtypes.py @@ -0,0 +1,9 @@ +import numpy + +import rasterio + +def test_np_dt_uint8(): + assert rasterio.check_dtype(numpy.dtype(numpy.uint8)) +def test_dt_ubyte(): + assert rasterio.check_dtype(numpy.dtype(rasterio.ubyte)) + diff --git a/rasterio/tests/test_read.py b/rasterio/tests/test_read.py index ad073a41..1e28b6af 100644 --- a/rasterio/tests/test_read.py +++ b/rasterio/tests/test_read.py @@ -1,16 +1,26 @@ import unittest +import numpy + import rasterio class ReaderContextTest(unittest.TestCase): def test_context(self): with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s: self.assertEqual(s.name, 'rasterio/tests/data/RGB.byte.tif') + self.assertEqual(s.driver, 'GTiff') self.assertEqual(s.closed, False) self.assertEqual(s.count, 3) self.assertEqual(s.width, 791) self.assertEqual(s.height, 718) self.assertEqual(s.shape, (718, 791)) + self.assertEqual(s.dtypes, [rasterio.ubyte]*3) + self.assertEqual(s.crs['proj'], 'utm') + self.assertEqual(s.crs['zone'], 18) + self.assertEqual( + s.transform, + [101985.0, 300.0379266750948, 0.0, + 2826915.0, 0.0, -300.041782729805]) self.assertEqual( repr(s), "" % hex(id(s))) + def test_read_ubyte(self): + with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s: + a = s.read_band(0) + self.assertEqual(a.dtype, rasterio.ubyte) + def test_read_ubyte_out(self): + with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s: + a = numpy.zeros((718, 791), dtype=rasterio.ubyte) + a = s.read_band(0, a) + self.assertEqual(a.dtype, rasterio.ubyte) + def test_read_out_dtype_fail(self): + with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s: + a = numpy.zeros((718, 791), dtype=rasterio.float32) + self.assertRaises(ValueError, s.read_band, 0, a) + def test_read_out_shape_fail(self): + with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s: + a = numpy.zeros((42, 42), dtype=rasterio.ubyte) + self.assertRaises(ValueError, s.read_band, 0, a) diff --git a/rasterio/tests/test_write.py b/rasterio/tests/test_write.py new file mode 100644 index 00000000..935333ec --- /dev/null +++ b/rasterio/tests/test_write.py @@ -0,0 +1,89 @@ +import os.path +import unittest +import shutil +import subprocess +import tempfile + +import numpy + +import rasterio + +class WriterContextTest(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + def tearDown(self): + shutil.rmtree(self.tempdir) + def test_context(self): + name = os.path.join(self.tempdir, "test_context.tif") + with rasterio.open( + name, 'w', + driver='GTiff', width=100, height=100, count=1, + dtype=rasterio.ubyte) as s: + self.assertEqual(s.name, name) + self.assertEqual(s.driver, 'GTiff') + self.assertEqual(s.closed, False) + self.assertEqual(s.count, 1) + self.assertEqual(s.width, 100) + self.assertEqual(s.height, 100) + self.assertEqual(s.shape, (100, 100)) + self.assertEqual( + repr(s), + "" % (name, hex(id(s)))) + self.assertEqual(s.closed, True) + self.assertEqual(s.count, 1) + self.assertEqual(s.width, 100) + self.assertEqual(s.height, 100) + self.assertEqual(s.shape, (100, 100)) + self.assertEqual( + repr(s), + "" % (name, hex(id(s)))) + info = subprocess.check_output(["gdalinfo", name]) + self.assert_("GTiff" in info) + self.assert_( + "Size is 100, 100" in info) + self.assert_( + "Band 1 Block=100x81 Type=Byte, ColorInterp=Gray" in info) + def test_write_ubyte(self): + name = os.path.join(self.tempdir, "test_write_ubyte.tif") + a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127 + with rasterio.open( + name, 'w', + driver='GTiff', width=100, height=100, count=1, + dtype=a.dtype) as s: + s.write_band(0, a) + info = subprocess.check_output(["gdalinfo", "-stats", name]) + self.assert_( + "Minimum=127.000, Maximum=127.000, " + "Mean=127.000, StdDev=0.000" in info, + info) + def test_write_float(self): + name = os.path.join(self.tempdir, "test_write_float.tif") + a = numpy.ones((100, 100), dtype=rasterio.float32) * 42.0 + with rasterio.open( + name, 'w', + driver='GTiff', width=100, height=100, count=2, + dtype=rasterio.float32) as s: + self.assertEqual(s.dtypes, [rasterio.float32]*2) + s.write_band(0, a) + s.write_band(1, a) + info = subprocess.check_output(["gdalinfo", "-stats", name]) + self.assert_( + "Minimum=42.000, Maximum=42.000, " + "Mean=42.000, StdDev=0.000" in info, + info) + def test_write_crs_transform(self): + name = os.path.join(self.tempdir, "test_write_crs_transform.tif") + a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127 + with rasterio.open( + name, 'w', + driver='GTiff', width=100, height=100, count=1, + crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', + 'proj': 'utm', 'zone': 18}, + transform=[101985.0, 300.0379266750948, 0.0, + 2826915.0, 0.0, -300.041782729805], + dtype=rasterio.ubyte) as s: + s.write_band(0, a) + info = subprocess.check_output(["gdalinfo", name]) + self.assert_('PROJCS["UTM Zone 18, Northern Hemisphere",' in info) + self.assert_("(300.037926675094809,-300.041782729804993)" in info) +