mirror of
https://github.com/rasterio/rasterio.git
synced 2025-12-08 17:36:12 +00:00
Basic read and write of GeoTIFFs with plenty of tests.
And an example script, which also appears in the readme.
This commit is contained in:
parent
a0c962b6d4
commit
5bcda751d0
31
README.md
31
README.md
@ -14,19 +14,30 @@ Example
|
||||
|
||||
Here's an example of the features fasterio aims to provide.
|
||||
|
||||
import numpy
|
||||
import rasterio
|
||||
|
||||
# Read raster bands directly into provided Numpy arrays.
|
||||
with rasterio.open('reflectance.tif') as src:
|
||||
vis = src.read_band(2, numpy.zeros((src.shape), numpy.float))
|
||||
nir = src.read_band(3, numpy.zeros((src.shape), numpy.float))
|
||||
import subprocess
|
||||
|
||||
# Read raster bands directly to Numpy arrays.
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src:
|
||||
r = src.read_band(0).astype(rasterio.float32)
|
||||
g = src.read_band(1).astype(rasterio.float32)
|
||||
b = src.read_band(2).astype(rasterio.float32)
|
||||
|
||||
# Combine arrays using the 'add' ufunc and then convert back to btyes.
|
||||
total = (r + g + b)/3.0
|
||||
total = total.astype(rasterio.ubyte)
|
||||
|
||||
ndvi = (nir-vis)/(nir+vis)
|
||||
|
||||
# Write the product as a raster band to a new file.
|
||||
with rasterio.open('ndvi.tif', 'w') as dst:
|
||||
dst.append_band(ndvi)
|
||||
with rasterio.open(
|
||||
'/tmp/total.tif', 'w',
|
||||
driver='GTiff',
|
||||
width=src.width, height=src.height, count=1,
|
||||
crs=src.crs, transform=src.transform,
|
||||
dtype=total.dtype) as dst:
|
||||
dst.write_band(0, total)
|
||||
|
||||
info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif'])
|
||||
print(info)
|
||||
|
||||
Dependencies
|
||||
------------
|
||||
|
||||
26
examples/total.py
Normal file
26
examples/total.py
Normal file
@ -0,0 +1,26 @@
|
||||
import rasterio
|
||||
|
||||
# Read raster bands directly to Numpy arrays.
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src:
|
||||
r = src.read_band(0).astype(rasterio.float32)
|
||||
g = src.read_band(1).astype(rasterio.float32)
|
||||
b = src.read_band(2).astype(rasterio.float32)
|
||||
|
||||
# Combine arrays using the 'add' ufunc and then convert back to btyes.
|
||||
total = (r + g + b)/3.0
|
||||
total = total.astype(rasterio.ubyte)
|
||||
|
||||
# Write the product as a raster band to a new file.
|
||||
with rasterio.open(
|
||||
'/tmp/total.tif', 'w',
|
||||
driver='GTiff',
|
||||
width=src.width, height=src.height, count=1,
|
||||
crs=src.crs, transform=src.transform,
|
||||
dtype=total.dtype) as dst:
|
||||
dst.write_band(0, total)
|
||||
|
||||
import subprocess
|
||||
info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif'])
|
||||
print(info)
|
||||
|
||||
|
||||
@ -4,28 +4,75 @@ import os
|
||||
|
||||
from six import string_types
|
||||
|
||||
from rasterio._io import RasterReader, RasterUpdateSession
|
||||
from rasterio._io import RasterReader, RasterUpdater
|
||||
import rasterio.dtypes
|
||||
from rasterio.dtypes import (
|
||||
ubyte, uint8, uint16, int16, uint32, int32, float32, float64)
|
||||
|
||||
def open(path, mode='r', driver=None):
|
||||
"""."""
|
||||
def open(
|
||||
path, mode='r',
|
||||
driver=None,
|
||||
width=None, height=None,
|
||||
count=None,
|
||||
dtype=None,
|
||||
crs=None, transform=None):
|
||||
"""Open file at ``path`` in ``mode`` "r" (read), "r+" (read/write),
|
||||
or "w" (write) and return a ``Reader`` or ``Updater`` object.
|
||||
|
||||
In write mode, a driver name such as "GTiff" or "JPEG" (see GDAL
|
||||
docs or ``gdan_translate --help`` on the command line), ``width``
|
||||
(number of pixels per line) and ``height`` (number of lines), the
|
||||
``count`` number of bands in the new file must be specified.
|
||||
Additionally, the data type for bands such as ``rasterio.ubyte`` for
|
||||
8-bit bands or ``rasterio.uint16`` for 16-bit bands must be
|
||||
specified using the ``dtype`` argument.
|
||||
|
||||
A coordinate reference system for raster datasets in write mode can
|
||||
be defined by the ``crs`` argument. It takes Proj4 style mappings
|
||||
like
|
||||
|
||||
{'proj': 'longlat', 'ellps': 'WGS84', 'datum': 'WGS84',
|
||||
'no_defs': True}
|
||||
|
||||
A geo-transform matrix that maps pixel coordinates to coordinates in
|
||||
the specified crs should be specified using the ``transform``
|
||||
argument. This matrix is represented by a six-element sequence.
|
||||
|
||||
Item 0: the top left x value
|
||||
Item 1: W-E pixel resolution
|
||||
Item 2: rotation, 0 if the image is "north up"
|
||||
Item 3: top left y value
|
||||
Item 4: rotation, 0 if the image is "north up"
|
||||
Item 5: N-S pixel resolution (usually a negative number)
|
||||
"""
|
||||
if not isinstance(path, string_types):
|
||||
raise TypeError("invalid path: %r" % path)
|
||||
if mode and not isinstance(mode, string_types):
|
||||
raise TypeError("invalid mode: %r" % mode)
|
||||
if driver and not isinstance(driver, string_types):
|
||||
raise TypeError("invalid driver: %r" % driver)
|
||||
if mode in ('a', 'r'):
|
||||
if mode in ('r', 'r+'):
|
||||
if not os.path.exists(path):
|
||||
raise IOError("no such file or directory: %r" % path)
|
||||
if mode == 'a':
|
||||
s = RasterUpdateSession(path, mode, driver=None)
|
||||
elif mode == 'r':
|
||||
|
||||
if mode == 'r':
|
||||
s = RasterReader(path)
|
||||
elif mode == 'r+':
|
||||
raise NotImplemented("r+ mode not implemented")
|
||||
# s = RasterUpdater(path, mode, driver=None)
|
||||
elif mode == 'w':
|
||||
s = RasterUpdateSession(path, mode, driver=driver)
|
||||
s = RasterUpdater(
|
||||
path, mode, driver,
|
||||
width, height, count,
|
||||
crs, transform, dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"mode string must be one of 'r', 'w', or 'a', not %s" % mode)
|
||||
"mode string must be one of 'r', 'r+', or 'w', not %s" % mode)
|
||||
|
||||
s.start()
|
||||
return s
|
||||
|
||||
def check_dtype(dt):
|
||||
tp = getattr(dt, 'type', dt)
|
||||
return tp in rasterio.dtypes.dtype_rev
|
||||
|
||||
|
||||
@ -1,14 +1,46 @@
|
||||
# GDAL function definitions.
|
||||
#
|
||||
|
||||
cdef extern from "cpl_conv.h":
|
||||
void CPLFree (void *ptr)
|
||||
void CPLSetThreadLocalConfigOption (char *key, char *val)
|
||||
|
||||
cdef extern from "cpl_string.h":
|
||||
char ** CSLSetNameValue (char **list, char *name, char *value)
|
||||
void CSLDestroy (char **list)
|
||||
|
||||
cdef extern from "ogr_srs_api.h":
|
||||
void OSRCleanup ()
|
||||
void * OSRClone (void *srs)
|
||||
void OSRDestroySpatialReference (void *srs)
|
||||
int OSRExportToProj4 (void *srs, char **params)
|
||||
int OSRExportToWkt (void *srs, char **params)
|
||||
int OSRImportFromProj4 (void *srs, char *proj)
|
||||
void * OSRNewSpatialReference (char *wkt)
|
||||
void OSRRelease (void *srs)
|
||||
|
||||
cdef extern from "gdal.h":
|
||||
void GDALAllRegister()
|
||||
|
||||
void * GDALGetDriverByName(const char *name)
|
||||
void * GDALOpen(const char *filename, int access)
|
||||
|
||||
void GDALClose(void *ds)
|
||||
void * GDALGetDatasetDriver(void *ds)
|
||||
int GDALGetGeoTransform (void *ds, double *transform)
|
||||
const char * GDALGetProjectionRef(void *ds)
|
||||
int GDALGetRasterXSize(void *ds)
|
||||
int GDALGetRasterYSize(void *ds)
|
||||
int GDALGetRasterCount(void *ds)
|
||||
void * GDALGetRasterBand(void *ds, int num)
|
||||
int GDALSetGeoTransform (void *ds, double *transform)
|
||||
int GDALSetProjection(void *ds, const char *wkt)
|
||||
|
||||
int GDALGetRasterDataType(void *band)
|
||||
int GDALRasterIO(void *band, int access, int xoff, int yoff, int xsize, int ysize, void *buffer, int width, int height, int data_type, int poff, int loff)
|
||||
void * GDALCreateCopy(void *driver, const char *filename, void *ds, int strict, char **options, void *progress_func, void *progress_data)
|
||||
void * GDALGetDriverByName(const char *name)
|
||||
|
||||
void * GDALCreate(void *driver, const char *filename, int width, int height, int nbands, int dtype, const char **options)
|
||||
void * GDALCreateCopy(void *driver, const char *filename, void *ds, int strict, char **options, void *progress_func, void *progress_data)
|
||||
const char * GDALGetDriverShortName(void *driver)
|
||||
const char * GDALGetDriverLongName(void *driver)
|
||||
|
||||
|
||||
384
rasterio/_io.pyx
384
rasterio/_io.pyx
@ -1,10 +1,25 @@
|
||||
import logging
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
ctypedef np.uint8_t DTYPE_UBYTE_t
|
||||
ctypedef np.uint16_t DTYPE_UINT16_t
|
||||
ctypedef np.int16_t DTYPE_INT16_t
|
||||
ctypedef np.uint32_t DTYPE_UINT32_t
|
||||
ctypedef np.int32_t DTYPE_INT32_t
|
||||
ctypedef np.float32_t DTYPE_FLOAT32_t
|
||||
ctypedef np.float64_t DTYPE_FLOAT64_t
|
||||
|
||||
from rasterio cimport _gdal
|
||||
from rasterio import dtypes
|
||||
|
||||
log = logging.getLogger('rasterio')
|
||||
class NullHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
pass
|
||||
log.addHandler(NullHandler())
|
||||
|
||||
cdef int registered = 0
|
||||
|
||||
@ -12,21 +27,100 @@ cdef void register():
|
||||
_gdal.GDALAllRegister()
|
||||
registered = 1
|
||||
|
||||
cdef int io_ubyte(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_UBYTE_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0, 0], width, height, 1, 0, 0)
|
||||
|
||||
cdef int io_uint16(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_UINT16_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0,0], width, height, 2, 0, 0)
|
||||
|
||||
cdef int io_int16(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_INT16_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0,0], width, height, 3, 0, 0)
|
||||
|
||||
cdef int io_uint32(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_UINT32_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0,0], width, height, 4, 0, 0)
|
||||
|
||||
cdef int io_int32(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_INT32_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0,0], width, height, 5, 0, 0)
|
||||
|
||||
cdef int io_float32(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_FLOAT32_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0,0], width, height, 6, 0, 0)
|
||||
|
||||
cdef int io_float64(
|
||||
void *hband,
|
||||
int mode,
|
||||
int width,
|
||||
int height,
|
||||
np.ndarray[DTYPE_FLOAT64_t, ndim=2, mode='c'] buffer):
|
||||
return _gdal.GDALRasterIO(
|
||||
hband, mode, 0, 0, width, height,
|
||||
&buffer[0,0], width, height, 7, 0, 0)
|
||||
|
||||
|
||||
cdef class RasterReader:
|
||||
# Read-only access to raster data and metadata.
|
||||
|
||||
cdef void *_hds
|
||||
cdef int _count
|
||||
|
||||
cdef readonly object name
|
||||
cdef readonly object mode
|
||||
cdef readonly object width, height
|
||||
cdef readonly object shape
|
||||
cdef public object driver
|
||||
cdef public object _dtypes
|
||||
cdef public object _closed
|
||||
cdef public object _crs
|
||||
cdef public object _transform
|
||||
|
||||
def __cinit__(self, path):
|
||||
def __init__(self, path):
|
||||
self.name = path
|
||||
self.mode = 'r'
|
||||
self._hds = NULL
|
||||
self._count = 0
|
||||
self._closed = True
|
||||
self._dtypes = []
|
||||
|
||||
def __dealloc__(self):
|
||||
self.stop()
|
||||
@ -42,12 +136,75 @@ cdef class RasterReader:
|
||||
register()
|
||||
cdef const char *fname = self.name
|
||||
self._hds = _gdal.GDALOpen(fname, 0)
|
||||
if not self._hds:
|
||||
raise ValueError("Null dataset")
|
||||
|
||||
cdef void *drv
|
||||
cdef const char *drv_name
|
||||
drv = _gdal.GDALGetDatasetDriver(self._hds)
|
||||
drv_name = _gdal.GDALGetDriverShortName(drv)
|
||||
self.driver = drv_name
|
||||
|
||||
self._count = _gdal.GDALGetRasterCount(self._hds)
|
||||
self.width = _gdal.GDALGetRasterXSize(self._hds)
|
||||
self.height = _gdal.GDALGetRasterYSize(self._hds)
|
||||
self.shape = (self.height, self.width)
|
||||
|
||||
self._transform = self.read_transform()
|
||||
self._crs = self.read_crs()
|
||||
|
||||
self._closed = False
|
||||
|
||||
def read_crs(self):
|
||||
cdef char *proj_c = NULL
|
||||
if self._hds is NULL:
|
||||
raise ValueError("Null dataset")
|
||||
#cdef const char *wkt = _gdal.GDALGetProjectionRef(self._hds)
|
||||
cdef void *osr = _gdal.OSRNewSpatialReference(
|
||||
_gdal.GDALGetProjectionRef(self._hds))
|
||||
log.debug("Got coordinate system")
|
||||
crs = {}
|
||||
if osr is not NULL:
|
||||
_gdal.OSRExportToProj4(osr, &proj_c)
|
||||
if proj_c is NULL:
|
||||
raise ValueError("Null projection")
|
||||
proj_b = proj_c
|
||||
log.debug("Params: %s", proj_b)
|
||||
value = proj_b.decode()
|
||||
value = value.strip()
|
||||
for param in value.split():
|
||||
kv = param.split("=")
|
||||
if len(kv) == 2:
|
||||
k, v = kv
|
||||
try:
|
||||
v = float(v)
|
||||
if v % 1 == 0:
|
||||
v = int(v)
|
||||
except ValueError:
|
||||
# Leave v as a string
|
||||
pass
|
||||
elif len(kv) == 1:
|
||||
k, v = kv[0], True
|
||||
else:
|
||||
raise ValueError("Unexpected proj parameter %s" % param)
|
||||
k = k.lstrip("+")
|
||||
crs[k] = v
|
||||
_gdal.CPLFree(proj_c)
|
||||
_gdal.OSRDestroySpatialReference(osr)
|
||||
else:
|
||||
log.debug("Projection not found (cogr_crs was NULL)")
|
||||
return crs
|
||||
|
||||
def read_transform(self):
|
||||
if self._hds is NULL:
|
||||
raise ValueError("Null dataset")
|
||||
cdef double gt[6]
|
||||
_gdal.GDALGetGeoTransform(self._hds, gt)
|
||||
transform = [0]*6
|
||||
for i in range(6):
|
||||
transform[i] = gt[i]
|
||||
return transform
|
||||
|
||||
def stop(self):
|
||||
if self._hds is not NULL:
|
||||
_gdal.GDALClose(self._hds)
|
||||
@ -57,35 +214,224 @@ cdef class RasterReader:
|
||||
self.stop()
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
if self._hds is not NULL:
|
||||
self._count = _gdal.GDALGetRasterCount(self._hds)
|
||||
return self._count
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._closed
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._closed
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
if not self._count:
|
||||
if not self._hds:
|
||||
raise ValueError("Can't read closed raster file")
|
||||
self._count = _gdal.GDALGetRasterCount(self._hds)
|
||||
return self._count
|
||||
|
||||
@property
|
||||
def dtypes(self):
|
||||
"""Returns an ordered list of all band data types."""
|
||||
cdef void *hband = NULL
|
||||
if not self._dtypes:
|
||||
if not self._hds:
|
||||
raise ValueError("Can't read closed raster file")
|
||||
for i in range(self._count):
|
||||
hband = _gdal.GDALGetRasterBand(self._hds, i+1)
|
||||
self._dtypes.append(
|
||||
dtypes.dtype_fwd[_gdal.GDALGetRasterDataType(hband)])
|
||||
return self._dtypes
|
||||
|
||||
def get_crs(self):
|
||||
if not self._crs:
|
||||
self._crs = self.read_crs()
|
||||
return self._crs
|
||||
crs = property(get_crs)
|
||||
|
||||
def get_transform(self):
|
||||
if not self._transform:
|
||||
self._transform = self.read_transform()
|
||||
return self._transform
|
||||
transform = property(get_transform)
|
||||
|
||||
def read_band(self, i, out=None):
|
||||
"""Read the ith band into an `out` array if provided, otherwise
|
||||
return a new array."""
|
||||
cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1)
|
||||
if not self._hds:
|
||||
raise ValueError("Can't read closed raster file")
|
||||
if out is not None and out.dtype != self.dtypes[i]:
|
||||
raise ValueError("Band and output array dtypes do not match")
|
||||
if out is not None and out.shape != self.shape:
|
||||
raise ValueError("Band and output shapes do not match")
|
||||
dtype = self.dtypes[i]
|
||||
if out is None:
|
||||
out = np.zeros(self.shape, np.ubyte)
|
||||
cdef np.ndarray[DTYPE_UBYTE_t, ndim=2, mode="c"] im = out
|
||||
_gdal.GDALRasterIO(
|
||||
hband, 0, 0, 0, self.width, self.height,
|
||||
&im[0, 0], self.width, self.height, 1, 0, 0)
|
||||
out = np.zeros(self.shape, dtype)
|
||||
cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1)
|
||||
if dtype == dtypes.ubyte:
|
||||
retval = io_ubyte(hband, 0, self.width, self.height, out)
|
||||
elif dtype == dtypes.uint16:
|
||||
retval = io_uint16(hband, 0, self.width, self.height, out)
|
||||
elif dtype == dtypes.int16:
|
||||
retval = io_int16(hband, 0, self.width, self.height, out)
|
||||
elif dtype == dtypes.uint32:
|
||||
retval = io_uint32(hband, 0, self.width, self.height, out)
|
||||
elif dtype == dtypes.int32:
|
||||
retval = io_int32(hband, 0, self.width, self.height, out)
|
||||
elif dtype == dtypes.float32:
|
||||
retval = io_float32(hband, 0, self.width, self.height, out)
|
||||
elif dtype == dtypes.float64:
|
||||
retval = io_float64(hband, 0, self.width, self.height, out)
|
||||
else:
|
||||
raise ValueError("Invalid dtype")
|
||||
# TODO: handle errors (by retval).
|
||||
return out
|
||||
|
||||
cdef class RasterUpdateSession:
|
||||
pass
|
||||
|
||||
cdef class RasterUpdater(RasterReader):
|
||||
# Read-write access to raster data and metadata.
|
||||
# TODO: the r+ mode.
|
||||
cdef readonly object _init_dtype
|
||||
|
||||
def __init__(
|
||||
self, path, mode, driver=None,
|
||||
width=None, height=None, count=None,
|
||||
crs=None, transform=None, dtype=None):
|
||||
self.name = path
|
||||
self.mode = mode
|
||||
self.driver = driver
|
||||
self.width = width
|
||||
self.height = height
|
||||
self._count = count
|
||||
self._init_dtype = dtype
|
||||
self._hds = NULL
|
||||
self._count = count
|
||||
self._crs = crs
|
||||
self._transform = transform
|
||||
self._closed = True
|
||||
self._dtypes = []
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s RasterUpdater '%s' at %s>" % (
|
||||
self.closed and 'closed' or 'open',
|
||||
self.name,
|
||||
hex(id(self)))
|
||||
|
||||
def start(self):
|
||||
cdef const char *drv_name
|
||||
cdef void *drv
|
||||
if not registered:
|
||||
register()
|
||||
cdef const char *fname = self.name
|
||||
if self.mode == 'w':
|
||||
# Delete existing file, create.
|
||||
if os.path.exists(self.name):
|
||||
os.unlink(self.name)
|
||||
drv_name = self.driver
|
||||
drv = _gdal.GDALGetDriverByName(drv_name)
|
||||
|
||||
# Find the equivalent GDAL data type or raise an exception
|
||||
# We've mapped numpy scalar types to GDAL types so see
|
||||
# if we can crosswalk those.
|
||||
if hasattr(self._init_dtype, 'type'):
|
||||
tp = self._init_dtype.type
|
||||
if tp not in dtypes.dtype_rev:
|
||||
raise ValueError(
|
||||
"Unsupported dtype: %s" % self._init_dtype)
|
||||
else:
|
||||
gdal_dtype = dtypes.dtype_rev.get(tp)
|
||||
else:
|
||||
gdal_dtype = dtypes.dtype_rev.get(self._init_dtype)
|
||||
self._hds = _gdal.GDALCreate(
|
||||
drv, fname, self.width, self.height, self._count,
|
||||
gdal_dtype,
|
||||
NULL)
|
||||
if self._transform:
|
||||
self.write_transform(self._transform)
|
||||
if self._crs:
|
||||
self.write_crs(self._crs)
|
||||
elif self.mode == 'a':
|
||||
self._hds = _gdal.GDALOpen(fname, 1)
|
||||
self._count = _gdal.GDALGetRasterCount(self._hds)
|
||||
self.width = _gdal.GDALGetRasterXSize(self._hds)
|
||||
self.height = _gdal.GDALGetRasterYSize(self._hds)
|
||||
self.shape = (self.height, self.width)
|
||||
self._closed = False
|
||||
|
||||
def get_crs(self):
|
||||
if not self._crs:
|
||||
self._crs = self.read_crs()
|
||||
return self._crs
|
||||
|
||||
def write_crs(self, crs):
|
||||
if self._hds is NULL:
|
||||
raise ValueError("Can't read closed raster file")
|
||||
cdef void *osr = _gdal.OSRNewSpatialReference(NULL)
|
||||
if osr is NULL:
|
||||
raise ValueError("Null spatial reference")
|
||||
params = []
|
||||
for k, v in crs.items():
|
||||
if v is True or (k == 'no_defs' and v):
|
||||
params.append("+%s" % k)
|
||||
else:
|
||||
params.append("+%s=%s" % (k, v))
|
||||
proj = " ".join(params)
|
||||
proj_b = proj.encode()
|
||||
cdef const char *proj_c = proj_b
|
||||
_gdal.OSRImportFromProj4(osr, proj_c)
|
||||
cdef char *wkt
|
||||
_gdal.OSRExportToWkt(osr, &wkt)
|
||||
_gdal.GDALSetProjection(self._hds, wkt)
|
||||
_gdal.CPLFree(wkt)
|
||||
_gdal.OSRDestroySpatialReference(osr)
|
||||
|
||||
self._crs = crs
|
||||
|
||||
crs = property(get_crs, write_crs)
|
||||
|
||||
def write_transform(self, transform):
|
||||
if self._hds is NULL:
|
||||
raise ValueError("Can't read closed raster file")
|
||||
cdef double gt[6]
|
||||
for i in range(6):
|
||||
gt[i] = transform[i]
|
||||
retval = _gdal.GDALSetGeoTransform(self._hds, gt)
|
||||
self._transform = transform
|
||||
|
||||
def get_transform(self):
|
||||
if not self._transform:
|
||||
self._transform = self.read_transform()
|
||||
return self._transform
|
||||
|
||||
transform = property(get_transform, write_transform)
|
||||
|
||||
def write_band(self, i, src):
|
||||
"""Write the src array into the ith band."""
|
||||
if not self._hds:
|
||||
raise ValueError("Can't read closed raster file")
|
||||
if src is not None and src.dtype != self.dtypes[i]:
|
||||
raise ValueError("Band and srcput array dtypes do not match")
|
||||
if src is not None and src.shape != self.shape:
|
||||
raise ValueError("Band and srcput shapes do not match")
|
||||
dtype = self.dtypes[i]
|
||||
cdef void *hband = _gdal.GDALGetRasterBand(self._hds, i+1)
|
||||
if dtype == dtypes.ubyte:
|
||||
retval = io_ubyte(hband, 1, self.width, self.height, src)
|
||||
elif dtype == dtypes.uint16:
|
||||
retval = io_uint16(hband, 1, self.width, self.height, src)
|
||||
elif dtype == dtypes.int16:
|
||||
retval = io_int16(hband, 1, self.width, self.height, src)
|
||||
elif dtype == dtypes.uint32:
|
||||
retval = io_uint32(hband, 1, self.width, self.height, src)
|
||||
elif dtype == dtypes.int32:
|
||||
retval = io_int32(hband, 1, self.width, self.height, src)
|
||||
elif dtype == dtypes.float32:
|
||||
retval = io_float32(hband, 1, self.width, self.height, src)
|
||||
elif dtype == dtypes.float64:
|
||||
retval = io_float64(hband, 1, self.width, self.height, src)
|
||||
else:
|
||||
raise ValueError("Invalid dtype")
|
||||
# TODO: handle errors (by retval).
|
||||
|
||||
|
||||
27
rasterio/dtypes.py
Normal file
27
rasterio/dtypes.py
Normal file
@ -0,0 +1,27 @@
|
||||
|
||||
import numpy
|
||||
|
||||
ubyte = uint8 = numpy.uint8
|
||||
uint16 = numpy.uint16
|
||||
int16 = numpy.int16
|
||||
uint32 = numpy.uint32
|
||||
int32 = numpy.int32
|
||||
float32 = numpy.float32
|
||||
float64 = numpy.float64
|
||||
|
||||
# Not supported:
|
||||
# GDT_CInt16 = 8, GDT_CInt32 = 9, GDT_CFloat32 = 10, GDT_CFloat64 = 11
|
||||
|
||||
dtype_fwd = {
|
||||
0: None, # GDT_Unknown
|
||||
1: ubyte, # GDT_Byte
|
||||
2: uint16, # GDT_UInt16
|
||||
3: int16, # GDT_Int16
|
||||
4: uint32, # GDT_UInt32
|
||||
5: int32, # GDT_Int32
|
||||
6: float32, # GDT_Float32
|
||||
7: float64 } # GDT_Float64
|
||||
|
||||
dtype_rev = {v: k for k, v in dtype_fwd.items()}
|
||||
dtype_rev[uint8] = 1
|
||||
|
||||
1
rasterio/tests/__init__.py
Normal file
1
rasterio/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
#
|
||||
26
rasterio/tests/data/RGB.byte.tif.aux.xml
Normal file
26
rasterio/tests/data/RGB.byte.tif.aux.xml
Normal file
@ -0,0 +1,26 @@
|
||||
<PAMDataset>
|
||||
<PAMRasterBand band="1">
|
||||
<Metadata>
|
||||
<MDI key="STATISTICS_MAXIMUM">255</MDI>
|
||||
<MDI key="STATISTICS_MEAN">29.947726688477</MDI>
|
||||
<MDI key="STATISTICS_MINIMUM">0</MDI>
|
||||
<MDI key="STATISTICS_STDDEV">52.340921626611</MDI>
|
||||
</Metadata>
|
||||
</PAMRasterBand>
|
||||
<PAMRasterBand band="2">
|
||||
<Metadata>
|
||||
<MDI key="STATISTICS_MAXIMUM">255</MDI>
|
||||
<MDI key="STATISTICS_MEAN">44.516147889382</MDI>
|
||||
<MDI key="STATISTICS_MINIMUM">0</MDI>
|
||||
<MDI key="STATISTICS_STDDEV">56.934318291876</MDI>
|
||||
</Metadata>
|
||||
</PAMRasterBand>
|
||||
<PAMRasterBand band="3">
|
||||
<Metadata>
|
||||
<MDI key="STATISTICS_MAXIMUM">255</MDI>
|
||||
<MDI key="STATISTICS_MEAN">48.113056354743</MDI>
|
||||
<MDI key="STATISTICS_MINIMUM">0</MDI>
|
||||
<MDI key="STATISTICS_STDDEV">60.112778509941</MDI>
|
||||
</Metadata>
|
||||
</PAMRasterBand>
|
||||
</PAMDataset>
|
||||
9
rasterio/tests/test_dtypes.py
Normal file
9
rasterio/tests/test_dtypes.py
Normal file
@ -0,0 +1,9 @@
|
||||
import numpy
|
||||
|
||||
import rasterio
|
||||
|
||||
def test_np_dt_uint8():
|
||||
assert rasterio.check_dtype(numpy.dtype(numpy.uint8))
|
||||
def test_dt_ubyte():
|
||||
assert rasterio.check_dtype(numpy.dtype(rasterio.ubyte))
|
||||
|
||||
@ -1,16 +1,26 @@
|
||||
import unittest
|
||||
|
||||
import numpy
|
||||
|
||||
import rasterio
|
||||
|
||||
class ReaderContextTest(unittest.TestCase):
|
||||
def test_context(self):
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
|
||||
self.assertEqual(s.name, 'rasterio/tests/data/RGB.byte.tif')
|
||||
self.assertEqual(s.driver, 'GTiff')
|
||||
self.assertEqual(s.closed, False)
|
||||
self.assertEqual(s.count, 3)
|
||||
self.assertEqual(s.width, 791)
|
||||
self.assertEqual(s.height, 718)
|
||||
self.assertEqual(s.shape, (718, 791))
|
||||
self.assertEqual(s.dtypes, [rasterio.ubyte]*3)
|
||||
self.assertEqual(s.crs['proj'], 'utm')
|
||||
self.assertEqual(s.crs['zone'], 18)
|
||||
self.assertEqual(
|
||||
s.transform,
|
||||
[101985.0, 300.0379266750948, 0.0,
|
||||
2826915.0, 0.0, -300.041782729805])
|
||||
self.assertEqual(
|
||||
repr(s),
|
||||
"<open RasterReader 'rasterio/tests/data/RGB.byte.tif' "
|
||||
@ -20,8 +30,32 @@ class ReaderContextTest(unittest.TestCase):
|
||||
self.assertEqual(s.width, 791)
|
||||
self.assertEqual(s.height, 718)
|
||||
self.assertEqual(s.shape, (718, 791))
|
||||
self.assertEqual(s.dtypes, [rasterio.ubyte]*3)
|
||||
self.assertEqual(s.crs['proj'], 'utm')
|
||||
self.assertEqual(s.crs['zone'], 18)
|
||||
self.assertEqual(
|
||||
s.transform,
|
||||
[101985.0, 300.0379266750948, 0.0,
|
||||
2826915.0, 0.0, -300.041782729805])
|
||||
self.assertEqual(
|
||||
repr(s),
|
||||
"<closed RasterReader 'rasterio/tests/data/RGB.byte.tif' "
|
||||
"at %s>" % hex(id(s)))
|
||||
def test_read_ubyte(self):
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
|
||||
a = s.read_band(0)
|
||||
self.assertEqual(a.dtype, rasterio.ubyte)
|
||||
def test_read_ubyte_out(self):
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
|
||||
a = numpy.zeros((718, 791), dtype=rasterio.ubyte)
|
||||
a = s.read_band(0, a)
|
||||
self.assertEqual(a.dtype, rasterio.ubyte)
|
||||
def test_read_out_dtype_fail(self):
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
|
||||
a = numpy.zeros((718, 791), dtype=rasterio.float32)
|
||||
self.assertRaises(ValueError, s.read_band, 0, a)
|
||||
def test_read_out_shape_fail(self):
|
||||
with rasterio.open('rasterio/tests/data/RGB.byte.tif') as s:
|
||||
a = numpy.zeros((42, 42), dtype=rasterio.ubyte)
|
||||
self.assertRaises(ValueError, s.read_band, 0, a)
|
||||
|
||||
|
||||
89
rasterio/tests/test_write.py
Normal file
89
rasterio/tests/test_write.py
Normal file
@ -0,0 +1,89 @@
|
||||
import os.path
|
||||
import unittest
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import numpy
|
||||
|
||||
import rasterio
|
||||
|
||||
class WriterContextTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tempdir)
|
||||
def test_context(self):
|
||||
name = os.path.join(self.tempdir, "test_context.tif")
|
||||
with rasterio.open(
|
||||
name, 'w',
|
||||
driver='GTiff', width=100, height=100, count=1,
|
||||
dtype=rasterio.ubyte) as s:
|
||||
self.assertEqual(s.name, name)
|
||||
self.assertEqual(s.driver, 'GTiff')
|
||||
self.assertEqual(s.closed, False)
|
||||
self.assertEqual(s.count, 1)
|
||||
self.assertEqual(s.width, 100)
|
||||
self.assertEqual(s.height, 100)
|
||||
self.assertEqual(s.shape, (100, 100))
|
||||
self.assertEqual(
|
||||
repr(s),
|
||||
"<open RasterUpdater '%s' at %s>" % (name, hex(id(s))))
|
||||
self.assertEqual(s.closed, True)
|
||||
self.assertEqual(s.count, 1)
|
||||
self.assertEqual(s.width, 100)
|
||||
self.assertEqual(s.height, 100)
|
||||
self.assertEqual(s.shape, (100, 100))
|
||||
self.assertEqual(
|
||||
repr(s),
|
||||
"<closed RasterUpdater '%s' at %s>" % (name, hex(id(s))))
|
||||
info = subprocess.check_output(["gdalinfo", name])
|
||||
self.assert_("GTiff" in info)
|
||||
self.assert_(
|
||||
"Size is 100, 100" in info)
|
||||
self.assert_(
|
||||
"Band 1 Block=100x81 Type=Byte, ColorInterp=Gray" in info)
|
||||
def test_write_ubyte(self):
|
||||
name = os.path.join(self.tempdir, "test_write_ubyte.tif")
|
||||
a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127
|
||||
with rasterio.open(
|
||||
name, 'w',
|
||||
driver='GTiff', width=100, height=100, count=1,
|
||||
dtype=a.dtype) as s:
|
||||
s.write_band(0, a)
|
||||
info = subprocess.check_output(["gdalinfo", "-stats", name])
|
||||
self.assert_(
|
||||
"Minimum=127.000, Maximum=127.000, "
|
||||
"Mean=127.000, StdDev=0.000" in info,
|
||||
info)
|
||||
def test_write_float(self):
|
||||
name = os.path.join(self.tempdir, "test_write_float.tif")
|
||||
a = numpy.ones((100, 100), dtype=rasterio.float32) * 42.0
|
||||
with rasterio.open(
|
||||
name, 'w',
|
||||
driver='GTiff', width=100, height=100, count=2,
|
||||
dtype=rasterio.float32) as s:
|
||||
self.assertEqual(s.dtypes, [rasterio.float32]*2)
|
||||
s.write_band(0, a)
|
||||
s.write_band(1, a)
|
||||
info = subprocess.check_output(["gdalinfo", "-stats", name])
|
||||
self.assert_(
|
||||
"Minimum=42.000, Maximum=42.000, "
|
||||
"Mean=42.000, StdDev=0.000" in info,
|
||||
info)
|
||||
def test_write_crs_transform(self):
|
||||
name = os.path.join(self.tempdir, "test_write_crs_transform.tif")
|
||||
a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127
|
||||
with rasterio.open(
|
||||
name, 'w',
|
||||
driver='GTiff', width=100, height=100, count=1,
|
||||
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
|
||||
'proj': 'utm', 'zone': 18},
|
||||
transform=[101985.0, 300.0379266750948, 0.0,
|
||||
2826915.0, 0.0, -300.041782729805],
|
||||
dtype=rasterio.ubyte) as s:
|
||||
s.write_band(0, a)
|
||||
info = subprocess.check_output(["gdalinfo", name])
|
||||
self.assert_('PROJCS["UTM Zone 18, Northern Hemisphere",' in info)
|
||||
self.assert_("(300.037926675094809,-300.041782729804993)" in info)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user