* Resolve #1520 and #1521

* Remove commented code, restore credential fetch in __init__
This commit is contained in:
Sean Gillies 2018-10-25 16:20:29 -06:00 committed by GitHub
parent 98126f2abc
commit 1551cabd77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 119 additions and 53 deletions

View File

@ -6,6 +6,11 @@ Changes
Bug fixes:
- Delegate test of the environment for existing session credentials to the
session class to generalize credentialization of GDAL to cloud providers
other than AWS (#1520). The env.hascreds function is no longer used in
Rasterio and has been marked as deprecated.
- Switch to use of botocore Credentials.get_frozen_credentials (#1521).
- Numpy masked arrays with the normal Numpy mask sense (True == invalid) are
now supported as input for feature.shapes(). The mask keyword argument of the
function keeps to the GDAL sense of masks (nonzero == invalid) and the

View File

@ -8,16 +8,11 @@ import re
import threading
import warnings
import rasterio
from rasterio._env import (
GDALEnv, del_gdal_config, get_gdal_config, set_gdal_config)
from rasterio._env import GDALEnv, get_gdal_config, set_gdal_config
from rasterio.compat import string_types, getargspec
from rasterio.dtypes import check_dtype
from rasterio.errors import (
EnvError, GDALVersionError, RasterioDeprecationWarning)
from rasterio.path import parse_path, UnparsedPath, ParsedPath
from rasterio.session import Session, AWSSession, DummySession
from rasterio.transform import guard_transform
class ThreadEnv(threading.local):
@ -226,16 +221,6 @@ class Env(object):
options.update(**kwargs)
return Env(*args, **options)
@property
def is_credentialized(self):
"""Test for existence of cloud credentials
Returns
-------
bool
"""
return hascreds()
def credentialize(self):
"""Get credentials and configure GDAL
@ -247,7 +232,7 @@ class Env(object):
None
"""
if hascreds():
if self.session.hascreds(getenv()):
pass
else:
cred_opts = self.session.get_credential_options()
@ -338,6 +323,7 @@ def setenv(**options):
def hascreds():
warnings.warn("Please use Env.session.hascreds() instead", RasterioDeprecationWarning)
return local._env is not None and all(key in local._env.get_config_options() for key in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'])
@ -397,12 +383,16 @@ def ensure_env_with_credentials(f):
else:
env_ctor = Env.from_defaults
if hascreds():
session = DummySession()
elif isinstance(args[0], str):
session = Session.from_path(args[0])
if isinstance(args[0], str):
session_cls = Session.cls_from_path(args[0])
if local._env and session_cls.hascreds(getenv()):
session_cls = DummySession
session = session_cls()
else:
session = Session.from_path(None)
session = DummySession()
with env_ctor(session=session):
return f(*args, **kwds)

View File

@ -1,7 +1,7 @@
"""Abstraction for sessions in various clouds."""
from rasterio.path import parse_path, UnparsedPath, ParsedPath
from rasterio.path import parse_path, UnparsedPath
class Session(object):
@ -18,6 +18,24 @@ class Session(object):
"""
@classmethod
def hascreds(cls, config):
"""Determine if the given configuration has proper credentials
Parameters
----------
cls : class
A Session class.
config : dict
GDAL configuration as a dict.
Returns
-------
bool
"""
return NotImplementedError
def get_credential_options(self):
"""Get credentials as GDAL configuration options
@ -49,6 +67,38 @@ class Session(object):
else:
return cls(session)
@staticmethod
def cls_from_path(path):
"""Find the session class suited to the data at `path`.
Parameters
----------
path : str
A dataset path or identifier.
Returns
-------
class
"""
if not path:
return DummySession
path = parse_path(path)
if isinstance(path, UnparsedPath) or path.is_local:
return DummySession
elif path.scheme == "s3" or "amazonaws.com" in path.path:
return AWSSession
# This factory can be extended to other cloud providers here.
# elif path.scheme == "cumulonimbus": # for example.
# return CumulonimbusSession(*args, **kwargs)
else:
return DummySession
@staticmethod
def from_path(path, *args, **kwargs):
"""Create a session object suited to the data at `path`.
@ -67,23 +117,7 @@ class Session(object):
Session
"""
if not path:
return DummySession()
path = parse_path(path)
if isinstance(path, UnparsedPath) or path.is_local:
return DummySession()
elif path.scheme == "s3" or "amazonaws.com" in path.path:
return AWSSession(*args, **kwargs)
# This factory can be extended to other cloud providers here.
# elif path.scheme == "cumulonimbus": # for example.
# return CumulonimbusSession(*args, **kwargs)
else:
return DummySession()
return Session.cls_from_path(path)(*args, **kwargs)
class DummySession(Session):
@ -100,6 +134,24 @@ class DummySession(Session):
self._session = None
self.credentials = {}
@classmethod
def hascreds(cls, config):
"""Determine if the given configuration has proper credentials
Parameters
----------
cls : class
A Session class.
config : dict
GDAL configuration as a dict.
Returns
-------
bool
"""
return True
def get_credential_options(self):
"""Get credentials as GDAL configuration options
@ -157,22 +209,41 @@ class AWSSession(Session):
self.unsigned = aws_unsigned
self._creds = self._session._session.get_credentials()
@classmethod
def hascreds(cls, config):
"""Determine if the given configuration has proper credentials
Parameters
----------
cls : class
A Session class.
config : dict
GDAL configuration as a dict.
Returns
-------
bool
"""
return 'AWS_ACCESS_KEY_ID' in config and 'AWS_SECRET_ACCESS_KEY' in config
@property
def credentials(self):
"""The session credentials as a dict"""
creds = {}
res = {}
if self._creds:
if self._creds.access_key: # pragma: no branch
creds['aws_access_key_id'] = self._creds.access_key
if self._creds.secret_key: # pragma: no branch
creds['aws_secret_access_key'] = self._creds.secret_key
if self._creds.token:
creds['aws_session_token'] = self._creds.token
frozen_creds = self._creds.get_frozen_credentials()
if frozen_creds.access_key: # pragma: no branch
res['aws_access_key_id'] = frozen_creds.access_key
if frozen_creds.secret_key: # pragma: no branch
res['aws_secret_access_key'] = frozen_creds.secret_key
if frozen_creds.token:
res['aws_session_token'] = frozen_creds.token
if self._session.region_name:
creds['aws_region'] = self._session.region_name
res['aws_region'] = self._session.region_name
if self.requester_pays:
creds['aws_request_payer'] = 'requester'
return creds
res['aws_request_payer'] = 'requester'
return res
def get_credential_options(self):
"""Get credentials as GDAL configuration options

View File

@ -166,9 +166,9 @@ def test_aws_session(gdalenv):
aws_access_key_id='id', aws_secret_access_key='key',
aws_session_token='token', region_name='null-island-1')
with rasterio.env.Env(session=aws_session) as s:
assert s.session._creds.access_key == 'id'
assert s.session._creds.secret_key == 'key'
assert s.session._creds.token == 'token'
assert s.session._session.get_credentials().get_frozen_credentials().access_key == 'id'
assert s.session._session.get_credentials().get_frozen_credentials().secret_key == 'key'
assert s.session._session.get_credentials().get_frozen_credentials().token == 'token'
assert s.session._session.region_name == 'null-island-1'