diff --git a/README.md b/README.md index 8d9d2dc1..a7276f7f 100644 --- a/README.md +++ b/README.md @@ -15,27 +15,35 @@ Example Here's an example of the features rasterio aims to provide. import rasterio + import subprocess # Read raster bands directly to Numpy arrays. with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src: - r = src.read_band(0).astype(rasterio.float32) - g = src.read_band(1).astype(rasterio.float32) - b = src.read_band(2).astype(rasterio.float32) + r = src.read_band(0) + g = src.read_band(1) + b = src.read_band(2) + assert [b.dtype.type for b in (r, g, b)] == src.dtypes - # Combine arrays using the 'add' ufunc and then convert back to btyes. + # Combine arrays using the 'add' ufunc. Expecting that the sum will exceed the + # 8-bit integer range, I convert to float32. + r = r.astype(rasterio.float32) + g = g.astype(rasterio.float32) + b = b.astype(rasterio.float32) total = (r + g + b)/3.0 - total = total.astype(rasterio.ubyte) - # Write the product as a raster band to a new file. For keyword arguments, - # we use meta attributes of the source file, but change the band count to 1. + # Write the product as a raster band to a new 8-bit file. For keyword + # arguments, we start with the meta attributes of the source file, but then + # change the band count to 1, set the dtype to uint8, and specify LZW + # compression. with rasterio.open( '/tmp/total.tif', 'w', - dtype=total.dtype, - **dict(src.meta, **{'count':1}) + **dict( + src.meta, + **{'dtype': rasterio.uint8, 'count':1, 'compress': 'lzw'}) ) as dst: - dst.write_band(0, total) + dst.write_band(0, total.astype(rasterio.uint8)) - import subprocess + # Dump out gdalinfo's report card. info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif']) print(info) diff --git a/examples/total.py b/examples/total.py index 7a471dbd..eae095b6 100644 --- a/examples/total.py +++ b/examples/total.py @@ -1,25 +1,33 @@ import rasterio +import subprocess # Read raster bands directly to Numpy arrays. with rasterio.open('rasterio/tests/data/RGB.byte.tif') as src: - r = src.read_band(0).astype(rasterio.float32) - g = src.read_band(1).astype(rasterio.float32) - b = src.read_band(2).astype(rasterio.float32) + r = src.read_band(0) + g = src.read_band(1) + b = src.read_band(2) + assert [b.dtype.type for b in (r, g, b)] == src.dtypes -# Combine arrays using the 'add' ufunc and then convert back to btyes. +# Combine arrays using the 'add' ufunc. Expecting that the sum will exceed the +# 8-bit integer range, I convert to float32. +r = r.astype(rasterio.float32) +g = g.astype(rasterio.float32) +b = b.astype(rasterio.float32) total = (r + g + b)/3.0 -total = total.astype(rasterio.ubyte) -# Write the product as a raster band to a new file. For keyword arguments, -# we use meta attributes of the source file, but change the band count to 1. +# Write the product as a raster band to a new 8-bit file. For keyword +# arguments, we start with the meta attributes of the source file, but then +# change the band count to 1, set the dtype to uint8, and specify LZW +# compression. with rasterio.open( '/tmp/total.tif', 'w', - dtype=total.dtype, - **dict(src.meta, **{'count':1}) + **dict( + src.meta, + **{'dtype': rasterio.uint8, 'count':1, 'compress': 'lzw'}) ) as dst: - dst.write_band(0, total) + dst.write_band(0, total.astype(rasterio.uint8)) -import subprocess +# Dump out gdalinfo's report card. info = subprocess.check_output(['gdalinfo', '-stats', '/tmp/total.tif']) print(info) diff --git a/rasterio/__init__.py b/rasterio/__init__.py index 85f8eda2..dce38c7e 100644 --- a/rasterio/__init__.py +++ b/rasterio/__init__.py @@ -15,7 +15,8 @@ def open( width=None, height=None, count=None, dtype=None, - crs=None, transform=None): + crs=None, transform=None, + **kwargs): """Open file at ``path`` in ``mode`` "r" (read), "r+" (read/write), or "w" (write) and return a ``Reader`` or ``Updater`` object. @@ -44,6 +45,9 @@ def open( Item 3: top left y value Item 4: rotation, 0 if the image is "north up" Item 5: N-S pixel resolution (usually a negative number) + + Finally, additional kwargs are passed to GDAL as driver-specific + dataset creation parameters. """ if not isinstance(path, string_types): raise TypeError("invalid path: %r" % path) @@ -64,7 +68,8 @@ def open( s = RasterUpdater( path, mode, driver, width, height, count, - crs, transform, dtype) + crs, transform, dtype, + **kwargs) else: raise ValueError( "mode string must be one of 'r', 'r+', or 'w', not %s" % mode) diff --git a/rasterio/_gdal.pxd b/rasterio/_gdal.pxd index d2ca695d..6cd752c5 100644 --- a/rasterio/_gdal.pxd +++ b/rasterio/_gdal.pxd @@ -6,7 +6,7 @@ cdef extern from "cpl_conv.h": void CPLSetThreadLocalConfigOption (char *key, char *val) cdef extern from "cpl_string.h": - char ** CSLSetNameValue (char **list, char *name, char *value) + char ** CSLSetNameValue (char **list, char *name, char *val) void CSLDestroy (char **list) cdef extern from "ogr_srs_api.h": diff --git a/rasterio/_io.pyx b/rasterio/_io.pyx index e948c141..b686c8ea 100644 --- a/rasterio/_io.pyx +++ b/rasterio/_io.pyx @@ -303,12 +303,13 @@ cdef class RasterReader: cdef class RasterUpdater(RasterReader): # Read-write access to raster data and metadata. # TODO: the r+ mode. - cdef readonly object _init_dtype + cdef readonly object _init_dtype, _options def __init__( self, path, mode, driver=None, width=None, height=None, count=None, - crs=None, transform=None, dtype=None): + crs=None, transform=None, dtype=None, + **kwargs): self.name = path self.mode = mode self.driver = driver @@ -322,6 +323,7 @@ cdef class RasterUpdater(RasterReader): self._transform = transform self._closed = True self._dtypes = [] + self._options = kwargs.copy() def __repr__(self): return "<%s RasterUpdater '%s' at %s>" % ( @@ -330,11 +332,14 @@ cdef class RasterUpdater(RasterReader): hex(id(self))) def start(self): - cdef const char *drv_name - cdef void *drv + cdef const char *drv_name = NULL + cdef char **options = NULL + cdef char *key_c, *val_c = NULL + cdef void *drv = NULL if not registered: register() cdef const char *fname = self.name + if self.mode == 'w': # Delete existing file, create. if os.path.exists(self.name): @@ -354,22 +359,38 @@ cdef class RasterUpdater(RasterReader): gdal_dtype = dtypes.dtype_rev.get(tp) else: gdal_dtype = dtypes.dtype_rev.get(self._init_dtype) + + # Creation options + for k, v in self._options.items(): + k, v = k.upper(), v.upper() + key_b = k.encode('utf-8') + val_b = v.encode('utf-8') + key_c = key_b + val_c = val_b + options = _gdal.CSLSetNameValue(options, key_c, val_c) + log.debug("Option: %r\n", (k, v)) + self._hds = _gdal.GDALCreate( drv, fname, self.width, self.height, self._count, - gdal_dtype, - NULL) + gdal_dtype, options) + if self._transform: self.write_transform(self._transform) if self._crs: self.write_crs(self._crs) + elif self.mode == 'a': self._hds = _gdal.GDALOpen(fname, 1) + self._count = _gdal.GDALGetRasterCount(self._hds) self.width = _gdal.GDALGetRasterXSize(self._hds) self.height = _gdal.GDALGetRasterYSize(self._hds) self.shape = (self.height, self.width) self._closed = False + if options: + _gdal.CSLDestroy(options) + def get_crs(self): if not self._crs: self._crs = self.read_crs() diff --git a/rasterio/tests/test_write.py b/rasterio/tests/test_write.py index 0ca84608..63c61fb5 100644 --- a/rasterio/tests/test_write.py +++ b/rasterio/tests/test_write.py @@ -1,13 +1,17 @@ +import logging import os.path import unittest import shutil import subprocess +import sys import tempfile import numpy import rasterio +logging.basicConfig(stream=sys.stderr, level=logging.DEBUG) + class WriterContextTest(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() @@ -98,4 +102,16 @@ class WriterContextTest(unittest.TestCase): "Minimum=127.000, Maximum=127.000, " "Mean=127.000, StdDev=0.000" in info, info) + def test_write_lzw(self): + name = os.path.join(self.tempdir, "test_write_lzw.tif") + a = numpy.ones((100, 100), dtype=rasterio.ubyte) * 127 + with rasterio.open( + name, 'w', + driver='GTiff', + width=100, height=100, count=1, + dtype=a.dtype, + compress='LZW') as s: + s.write_band(0, a) + info = subprocess.check_output(["gdalinfo", name]) + self.assert_("LZW" in info, info) diff --git a/rasterio/tests/test_write.pyc b/rasterio/tests/test_write.pyc index e708af57..fca33a51 100644 Binary files a/rasterio/tests/test_write.pyc and b/rasterio/tests/test_write.pyc differ diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..567c8ee7 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,6 @@ +[nosetests] +tests=rasterio/tests +nocapture=True +verbosity=3 +logging-filter=rasterio +logging-level=DEBUG