mirror of
https://github.com/google/earthengine-api.git
synced 2025-12-08 19:26:12 +00:00
1919 lines
65 KiB
Python
#!/usr/bin/env python3
|
|
"""Commands supported by the Earth Engine command line interface.
|
|
|
|
Each command is implemented by extending the Command class. Each class
|
|
defines the supported positional and optional arguments, as well as
|
|
the actions to be taken when the command is executed.
|
|
"""
|
|
|
|
import argparse
|
|
import calendar
|
|
import collections
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
from typing import Any, Dict, List, Sequence, Tuple, Union
|
|
import urllib.parse
|
|
|
|
# Prevent TensorFlow from logging anything at the native level.
# pylint: disable=g-import-not-at-top
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Suppress non-error logs while TF initializes
old_level = logging.getLogger().level
logging.getLogger().setLevel(logging.ERROR)

# Flipped to True below only if the TF import (and v1-compat setup) succeeds;
# commands that need TF can check this instead of importing it themselves.
TENSORFLOW_INSTALLED = False

# pylint: disable=g-import-not-at-top
try:
  import tensorflow.compat.v1 as tf
  from tensorflow.compat.v1.saved_model import utils as saved_model_utils
  from tensorflow.compat.v1.saved_model import signature_constants
  from tensorflow.compat.v1.saved_model import signature_def_utils
  # This triggers a warning about disable_resource_variables
  tf.disable_v2_behavior()
  # Prevent TensorFlow from logging anything at the python level.
  tf.logging.set_verbosity(tf.logging.ERROR)

  TENSORFLOW_INSTALLED = True
except ImportError:
  # TF is an optional dependency; leave TENSORFLOW_INSTALLED False.
  pass
except TypeError:
  # The installed version of the protobuf package is incompatible with
  # Tensorflow. A type error is thrown when trying to generate proto
  # descriptors. Reinstalling Tensorflow should fix any dep versioning issues.
  pass
finally:
  # Restore the log level captured before the import dance above.
  logging.getLogger().setLevel(old_level)

TENSORFLOW_ADDONS_INSTALLED = False
# pylint: disable=g-import-not-at-top
if TENSORFLOW_INSTALLED:
  try:
    # This import is enough to register TFA ops though isn't directly used
    # (for now).
    # pylint: disable=unused-import
    import tensorflow_addons as tfa
    tfa.register_all(custom_kernels=False)  # pytype: disable=module-attr
    TENSORFLOW_ADDONS_INSTALLED = True
  except ImportError:
    pass
  except AttributeError:
    # This can be thrown by "tfa.register_all()" which means the
    # tensorflow_addons version is registering ops the old way, i.e.
    # automatically at import time. If this is the case, we've actually
    # successfully registered TFA.
    TENSORFLOW_ADDONS_INSTALLED = True
|
|
|
|
# pylint: disable=g-import-not-at-top, g-bad-import-order
|
|
import ee
|
|
from ee.cli import utils
|
|
|
|
# Constants used in ACLs.
ALL_USERS = 'allUsers'
ALL_USERS_CAN_READ = 'all_users_can_read'
READERS = 'readers'
WRITERS = 'writers'

# Constants used in setting metadata properties.
TYPE_DATE = 'date'
TYPE_NUMBER = 'number'
TYPE_STRING = 'string'
SYSTEM_TIME_START = 'system:time_start'
SYSTEM_TIME_END = 'system:time_end'

# A regex that parses properties of the form "[(type)]name=value". The
# second, third, and fourth group are type, name, and value, respectively.
PROPERTY_RE = re.compile(r'(\(([^\)]*)\))?([^=]+)=(.*)')

# Translate internal task type identifiers to user-friendly strings that
# are consistent with the language in the API and docs.
TASK_TYPES = {
    'EXPORT_FEATURES': 'Export.table',
    'EXPORT_IMAGE': 'Export.image',
    'EXPORT_TILES': 'Export.map',
    'EXPORT_VIDEO': 'Export.video',
    'INGEST': 'Upload',
    'INGEST_IMAGE': 'Upload',
    'INGEST_TABLE': 'Upload',
}

# NOTE(review): the second entry lacks the leading dot ('.tfrecord' vs
# 'tfrecord.gz') - confirm whether consumers match via endswith(), where
# both forms behave the same, or expect true extensions.
TF_RECORD_EXTENSIONS = ['.tfrecord', 'tfrecord.gz']

# Maximum size of objects in a SavedModel directory that we're willing to
# download from GCS.
SAVED_MODEL_MAX_SIZE = 400 * 1024 * 1024

# Default path to SavedModel variables.
DEFAULT_VARIABLES_PREFIX = '/variables/variables'
|
|
|
|
|
|
def _add_wait_arg(parser: argparse.ArgumentParser) -> None:
|
|
parser.add_argument(
|
|
'--wait', '-w', nargs='?', default=-1, type=int, const=sys.maxsize,
|
|
help=('Wait for the task to finish,'
|
|
' or timeout after the specified number of seconds.'
|
|
' Without this flag, the command just starts an export'
|
|
' task in the background, and returns immediately.'))
|
|
|
|
|
|
def _add_overwrite_arg(parser: argparse.ArgumentParser) -> None:
|
|
parser.add_argument(
|
|
'--force', '-f', action='store_true',
|
|
help='Overwrite any existing version of the asset.')
|
|
|
|
|
|
def _upload(
    args: argparse.Namespace, request: Dict[str, Any], ingestion_function: Any
) -> None:
  """Starts an ingestion task and optionally waits for it to finish.

  Args:
    args: Parsed command-line arguments; `wait` and `force` are read.
    request: The ingestion request payload to submit.
    ingestion_function: Callable taking (request_id, request, force) and
      returning a dict containing the new task's 'id'.

  Raises:
    ee.EEException: if a wait time below 10 seconds was requested.
  """
  if 0 <= args.wait < 10:
    raise ee.EEException('Wait time should be at least 10 seconds.')
  request_id = ee.data.newTaskId()[0]
  started = ingestion_function(request_id, request, args.force)
  task_id = started['id']
  print('Started upload task with ID: %s' % task_id)
  # A negative wait (the default) means fire-and-forget.
  if args.wait < 0:
    return
  print('Waiting for the upload task to complete...')
  utils.wait_for_task(task_id, args.wait)
|
|
|
|
|
|
# Argument types
|
|
def _comma_separated_strings(string: str) -> List[str]:
|
|
"""Parses an input consisting of comma-separated strings."""
|
|
error_msg = 'Argument should be a comma-separated list of strings: {}'
|
|
values = string.split(',')
|
|
if not values:
|
|
raise argparse.ArgumentTypeError(error_msg.format(string))
|
|
return values
|
|
|
|
|
|
def _comma_separated_numbers(string: str) -> List[float]:
|
|
"""Parses an input consisting of comma-separated numbers."""
|
|
error_msg = 'Argument should be a comma-separated list of numbers: {}'
|
|
values = string.split(',')
|
|
if not values:
|
|
raise argparse.ArgumentTypeError(error_msg.format(string))
|
|
numbervalues = []
|
|
for value in values:
|
|
try:
|
|
numbervalues.append(int(value))
|
|
except ValueError:
|
|
try:
|
|
numbervalues.append(float(value))
|
|
except ValueError:
|
|
# pylint: disable-next=raise-missing-from
|
|
raise argparse.ArgumentTypeError(error_msg.format(string))
|
|
return numbervalues
|
|
|
|
|
|
def _comma_separated_pyramiding_policies(string: str) -> List[str]:
|
|
"""Parses an input consisting of comma-separated pyramiding policies."""
|
|
error_msg = ('Argument should be a comma-separated list of: '
|
|
'{{"mean", "sample", "min", "max", "mode"}}: {}')
|
|
values = string.split(',')
|
|
if not values:
|
|
raise argparse.ArgumentTypeError(error_msg.format(string))
|
|
redvalues = []
|
|
for value in values:
|
|
value = value.upper()
|
|
if value not in {'MEAN', 'SAMPLE', 'MIN', 'MAX', 'MODE', 'MEDIAN'}:
|
|
raise argparse.ArgumentTypeError(error_msg.format(string))
|
|
redvalues.append(value)
|
|
return redvalues
|
|
|
|
|
|
def _decode_number(string: str) -> float:
|
|
"""Decodes a number from a command line argument."""
|
|
try:
|
|
return float(string)
|
|
except ValueError:
|
|
raise argparse.ArgumentTypeError( # pylint: disable=raise-missing-from
|
|
'Invalid value for property of type "number": "%s".' % string)
|
|
|
|
|
|
def _timestamp_ms_for_datetime(datetime_obj: datetime.datetime) -> int:
|
|
"""Returns time since the epoch in ms for the given UTC datetime object."""
|
|
return (
|
|
int(calendar.timegm(datetime_obj.timetuple()) * 1000) +
|
|
datetime_obj.microsecond // 1000)
|
|
|
|
|
|
def _cloud_timestamp_for_timestamp_ms(timestamp_ms: float) -> str:
|
|
"""Returns a Cloud-formatted date for the given millisecond timestamp."""
|
|
# Desired format is like '2003-09-07T19:30:12.345Z'
|
|
return datetime.datetime.utcfromtimestamp(
|
|
timestamp_ms / 1000.0).isoformat() + 'Z'
|
|
|
|
|
|
def _parse_millis(millis: float) -> datetime.datetime:
|
|
return datetime.datetime.fromtimestamp(millis / 1000)
|
|
|
|
|
|
def _decode_date(string: str) -> Union[float, str]:
  """Decodes a date from a command line argument, returning msec since epoch".

  Args:
    string: See AssetSetCommand class comment for the allowable
      date formats.

  Returns:
    long, ms since epoch, or '' if the input is empty.

  Raises:
    argparse.ArgumentTypeError: if string does not conform to a legal
      date format.
  """
  if not string:
    return ''

  # A bare integer is already msec-since-epoch.
  try:
    return int(string)
  except ValueError:
    pass

  # Otherwise accept progressively more precise ISO-style formats.
  for date_format in ('%Y-%m-%d',
                      '%Y-%m-%dT%H:%M:%S',
                      '%Y-%m-%dT%H:%M:%S.%f'):
    try:
      parsed = datetime.datetime.strptime(string, date_format)
    except ValueError:
      continue
    return _timestamp_ms_for_datetime(parsed)

  raise argparse.ArgumentTypeError(
      'Invalid value for property of type "date": "%s".' % string)
|
|
|
|
|
|
def _decode_property(string: str) -> Tuple[str, Any]:
  """Decodes a general key-value property from a command-line argument.

  Args:
    string: The string must have the form name=value or (type)name=value, where
      type is one of 'number', 'string', or 'date'. The value format for dates
      is YYYY-MM-DD[THH:MM:SS[.MS]]. The value 'null' is special: it evaluates
      to None unless it is cast to a string of 'null'.

  Returns:
    a tuple representing the property in the format (name, value)

  Raises:
    argparse.ArgumentTypeError: if the flag value could not be decoded or if
      the type is not recognized
  """
  match = PROPERTY_RE.match(string)
  if match is None:
    raise argparse.ArgumentTypeError(
        'Invalid property: "%s". Must have the form "name=value" or '
        '"(type)name=value".' % string)
  _, type_str, name, value_str = match.groups()

  # An untyped or non-string 'null' maps to None (property deletion).
  if value_str == 'null' and type_str != TYPE_STRING:
    return (name, None)

  if type_str is None:
    # No explicit type: treat numeric-looking values as numbers.
    try:
      return (name, _decode_number(value_str))
    except argparse.ArgumentTypeError:
      return (name, value_str)
  if type_str == TYPE_DATE:
    return (name, _decode_date(value_str))
  if type_str == TYPE_NUMBER:
    return (name, _decode_number(value_str))
  if type_str == TYPE_STRING:
    return (name, value_str)

  raise argparse.ArgumentTypeError(
      'Unrecognized property type name: "%s". Expected one of "string", '
      '"number", or "date".' % type_str)
|
|
|
|
|
|
def _add_property_flags(parser):
  """Adds command line flags related to metadata properties to a parser."""
  property_help = (
      'A property to set, in the form [(type)]name=value. If no type '
      'is specified the type will be "number" if the value is numeric and '
      '"string" otherwise. May be provided multiple times.')
  parser.add_argument(
      '--property',
      '-p',
      help=property_help,
      action='append',
      type=_decode_property)
  parser.add_argument(
      '--time_start',
      '-ts',
      help='Sets the start time property to a number or date.',
      type=_decode_date)
  parser.add_argument(
      '--time_end',
      '-te',
      help='Sets the end time property to a number or date.',
      type=_decode_date)
|
|
|
|
|
|
def _decode_property_flags(args):
|
|
"""Decodes metadata properties from args as a name->value dict."""
|
|
property_list = list(args.property or [])
|
|
names = [name for name, _ in property_list]
|
|
duplicates = [
|
|
name for name, count in collections.Counter(names).items() if count > 1]
|
|
if duplicates:
|
|
raise ee.EEException('Duplicate property name(s): %s.' % duplicates)
|
|
return dict(property_list)
|
|
|
|
|
|
def _check_valid_files(filenames: Sequence[str]) -> None:
  """Validates that every filename is a Cloud Storage URI.

  (The old docstring claimed a boolean return; this function returns None
  and signals failure by raising.)

  Args:
    filenames: Candidate upload file URIs.

  Raises:
    ee.EEException: if any filename does not start with 'gs://'.
  """
  for filename in filenames:
    if not filename.startswith('gs://'):
      raise ee.EEException('Invalid Cloud Storage URL: ' + filename)
|
|
|
|
|
|
def _pretty_print_json(json_obj):
|
|
"""Pretty-prints a JSON object to stdandard output."""
|
|
print(json.dumps(json_obj, sort_keys=True, indent=2, separators=(',', ': ')))
|
|
|
|
|
|
class Dispatcher:
  """Dispatches to a set of commands implemented as command classes."""

  # Subclasses set `name` and populate COMMANDS with command classes; each
  # command class provides a `name` attribute, a docstring, and run().
  COMMANDS: List[Any]
  command_dict: Dict[str, Any]
  dest: str
  name: str

  def __init__(self, parser: argparse.ArgumentParser):
    """Registers one argparse subparser per command class in COMMANDS."""
    self.command_dict = {}
    # Unique argparse destination per dispatcher level, e.g. 'acl_cmd'.
    self.dest = self.name + '_cmd'
    subparsers = parser.add_subparsers(title='Commands', dest=self.dest)
    subparsers.required = True  # Needed for proper missing arg handling in 3.x
    for command in self.COMMANDS:
      # Use the first line of the command's docstring as its short help.
      command_help = None
      if command.__doc__ and command.__doc__.splitlines():
        command_help = command.__doc__.splitlines()[0]
      subparser = subparsers.add_parser(
          command.name,
          description=command.__doc__,
          help=command_help)
      self.command_dict[command.name] = command(subparser)

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Delegates to the command instance selected on the command line."""
    self.command_dict[vars(args)[self.dest]].run(args, config)
|
|
|
|
|
|
class AuthenticateCommand:
  """Prompts the user to authorize access to Earth Engine via OAuth2.

  Note that running this command in the default interactive mode within
  JupyterLab with a bash magic command (i.e. "!earthengine authenticate") is
  problematic (see https://github.com/ipython/ipython/issues/10499). To avoid
  this issue, use the non-interactive mode
  (i.e. "!earthengine authenticate --quiet").
  """

  name = 'authenticate'

  def __init__(self, parser: argparse.ArgumentParser):
    """Registers the OAuth-related command-line flags."""
    parser.add_argument(
        '--authorization-code',
        help='Use this specified authorization code.')
    parser.add_argument(
        '--quiet',
        action='store_true',
        help='Do not prompt for input, and run gcloud in no-browser mode.')
    parser.add_argument(
        '--code-verifier',
        help='PKCE verifier to prevent auth code stealing.')
    parser.add_argument(
        '--auth_mode',
        help='One of: notebook - use notebook authenticator; gcloud - use'
        ' gcloud; appdefault - read GOOGLE_APPLICATION_CREDENTIALS;'
        ' localhost[:PORT] - use local browser')
    parser.add_argument(
        '--scopes', help='Optional comma-separated list of scopes.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Prompts for an auth code, requests a token and saves it."""
    del config  # Unused

    # Filter for arguments relevant for ee.Authenticate()
    args_auth = {x: vars(args)[x] for x in (
        'authorization_code', 'quiet', 'code_verifier', 'auth_mode')}
    if args.scopes:
      # '--scopes' arrives as a comma-separated string; split it into the
      # list form that ee.Authenticate() takes.
      args_auth['scopes'] = args.scopes.split(',')
    ee.Authenticate(**args_auth)
|
|
|
|
|
|
class SetProjectCommand:
  """Sets the default user project to be used for all API calls."""

  name = 'set_project'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('project', help='project id or number to use.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Saves the project to the config file.

    Args:
      args: Parsed arguments holding the project id.
      config: Command line configuration pointing at the config file.
    """
    config_path = config.config_file
    # Read the existing config; a local name avoids shadowing the `config`
    # parameter (the original rebound it).
    with open(config_path) as config_file_json:
      config_json = json.load(config_file_json)

    config_json['project'] = args.project
    # Write via a context manager so the handle is always closed (the
    # previous json.dump(..., open(...)) leaked the file handle).
    with open(config_path, 'w') as config_file_json:
      json.dump(config_json, config_file_json)
    print('Successfully saved project id')
|
|
|
|
|
|
class UnSetProjectCommand:
  """UnSets the default user project to be used for all API calls."""

  name = 'unset_project'

  def __init__(self, parser: argparse.ArgumentParser):
    del parser  # Unused.

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Removes the saved project from the config file, if present.

    Args:
      args: Unused.
      config: Command line configuration pointing at the config file.
    """
    del args  # Unused.

    config_path = config.config_file
    # A local name avoids shadowing the `config` parameter.
    with open(config_path) as config_file_json:
      config_json = json.load(config_file_json)

    if 'project' in config_json:
      del config_json['project']
      # Write via a context manager so the handle is always closed (the
      # previous json.dump(..., open(...)) leaked the file handle).
      with open(config_path, 'w') as config_file_json:
        json.dump(config_json, config_file_json)
      print('Successfully unset project id')
|
|
|
|
|
|
class AclChCommand:
  """Changes the access control list for an asset.

  Each change specifies the email address of a user or group and,
  for additions, one of R or W corresponding to the read or write
  permissions to be granted, as in "user@domain.com:R". Use the
  special name "allUsers" to change whether all users can read the
  asset.
  """

  name = 'ch'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('-u', action='append', metavar='user permission',
                        help='Add or modify a user\'s permission.')
    parser.add_argument('-d', action='append', metavar='remove user',
                        help='Remove all permissions for a user.')
    parser.add_argument('-g', action='append', metavar='group permission',
                        help='Add or modify a group\'s permission.')
    # Fixed copy-pasted help text: this flag removes a group, not a user.
    parser.add_argument('-dg', action='append', metavar='remove group',
                        help='Remove all permissions for a group.')
    parser.add_argument('asset_id', help='ID of the asset.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Performs an ACL update."""
    config.ee_init()
    permissions = self._parse_permissions(args)
    acl = ee.data.getAssetAcl(args.asset_id)
    self._apply_permissions(acl, permissions)
    ee.data.setAssetAcl(args.asset_id, json.dumps(acl))

  def _set_permission(self, permissions, grant, prefix):
    """Sets the permission for a given user/group.

    Args:
      permissions: Dict mapping prefixed principal id to 'R'/'W'/'D';
        mutated in place.
      grant: A string of the form 'principal:R' or 'principal:W'.
      prefix: Principal type prefix, e.g. 'user:' or 'group:'.

    Raises:
      ee.EEException: on a malformed grant, a duplicate principal, or an
        attempt to grant write access to all users.
    """
    parts = grant.rsplit(':', 1)
    if len(parts) != 2 or parts[1] not in ['R', 'W']:
      raise ee.EEException('Invalid permission "%s".' % grant)
    user, role = parts
    prefixed_user = user
    if not self._is_all_users(user):
      prefixed_user = prefix + user
    if prefixed_user in permissions:
      raise ee.EEException('Multiple permission settings for "%s".' % user)
    if self._is_all_users(user) and role == 'W':
      raise ee.EEException('Cannot grant write permissions to all users.')
    permissions[prefixed_user] = role

  def _remove_permission(self, permissions, user, prefix):
    """Removes permissions for a given user/group.

    Marks the principal with the sentinel role 'D' (delete) in the
    permissions dict; raises ee.EEException on duplicate settings.
    """
    prefixed_user = user
    if not self._is_all_users(user):
      prefixed_user = prefix + user
    if prefixed_user in permissions:
      raise ee.EEException('Multiple permission settings for "%s".' % user)
    permissions[prefixed_user] = 'D'

  def _user_account_type(self, user):
    """Returns the appropriate account type for a user email."""

    # Here 'user' ends with ':R', ':W', or ':D', so we extract
    # just the username.
    if user.split(':')[0].endswith('.gserviceaccount.com'):
      return 'serviceAccount:'
    else:
      return 'user:'

  def _parse_permissions(self, args):
    """Decodes and sanity-checks the permissions in the arguments."""
    # A dictionary mapping from user ids to one of 'R', 'W', or 'D'.
    permissions = {}
    if args.u:
      for user in args.u:
        self._set_permission(permissions, user, self._user_account_type(user))
    if args.d:
      for user in args.d:
        self._remove_permission(
            permissions, user, self._user_account_type(user))
    if args.g:
      for group in args.g:
        self._set_permission(permissions, group, 'group:')
    if args.dg:
      for group in args.dg:
        self._remove_permission(permissions, group, 'group:')
    return permissions

  def _apply_permissions(self, acl, permissions) -> None:
    """Applies the given permission edits to the given acl.

    'R' moves the principal to readers, 'W' to writers, and 'D' removes it
    from both lists; allUsers only toggles the all_users_can_read flag.
    """
    for user, role in permissions.items():
      if self._is_all_users(user):
        acl[ALL_USERS_CAN_READ] = (role == 'R')
      elif role == 'R':
        if user not in acl[READERS]:
          acl[READERS].append(user)
        if user in acl[WRITERS]:
          acl[WRITERS].remove(user)
      elif role == 'W':
        if user in acl[READERS]:
          acl[READERS].remove(user)
        if user not in acl[WRITERS]:
          acl[WRITERS].append(user)
      elif role == 'D':
        if user in acl[READERS]:
          acl[READERS].remove(user)
        if user in acl[WRITERS]:
          acl[WRITERS].remove(user)

  def _is_all_users(self, user: str) -> bool:
    """Determines if a user name represents the special "all users" entity."""
    # We previously used "AllUsers" as the magic string to denote that we wanted
    # to apply some permission to everyone. However, Google Cloud convention for
    # this concept is "allUsers". Because some people might be using one and
    # some the other, we do a case-insensitive comparison.
    return user.lower() == ALL_USERS.lower()
|
|
|
|
|
|
class AclGetCommand:
  """Prints the access control list for an asset."""

  name = 'get'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('asset_id', help='ID of the asset.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Fetches the asset's ACL and pretty-prints it to stdout."""
    config.ee_init()
    _pretty_print_json(ee.data.getAssetAcl(args.asset_id))
|
|
|
|
|
|
class AclSetCommand:
  """Sets the access control list for an asset.

  The ACL may be the name of a canned ACL, or it may be the path to a
  file containing the output from "acl get". The recognized canned ACL
  names are "private", indicating that no users other than the owner
  have access, and "public", indicating that all users have read
  access. It is currently not possible to modify the owner ACL using
  this tool.
  """

  name = 'set'

  # Built-in ACL presets, selectable by name on the command line.
  CANNED_ACLS = {
      'private': {
          READERS: [],
          WRITERS: [],
          ALL_USERS_CAN_READ: False,
      },
      'public': {
          READERS: [],
          WRITERS: [],
          ALL_USERS_CAN_READ: True,
      },
  }

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('file_or_acl_name',
                        help='File path or canned ACL name.')
    parser.add_argument('asset_id', help='ID of the asset.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Sets asset ACL to a canned ACL or one provided in a JSON file.

    Args:
      args: Parsed arguments with the ACL source and asset id.
      config: Command line configuration.
    """
    config.ee_init()
    # Membership test directly on the dict; no need to materialize keys().
    if args.file_or_acl_name in self.CANNED_ACLS:
      acl = self.CANNED_ACLS[args.file_or_acl_name]
    else:
      # Use a context manager so the ACL file handle is closed promptly
      # (the previous bare open() leaked the handle).
      with open(args.file_or_acl_name) as acl_file:
        acl = json.load(acl_file)
    ee.data.setAssetAcl(args.asset_id, json.dumps(acl))
|
|
|
|
|
|
class AclCommand(Dispatcher):
  """Prints or updates the access control list of the specified asset."""

  name = 'acl'

  # Subcommands exposed as "acl ch", "acl get", and "acl set".
  COMMANDS = [
      AclChCommand,
      AclGetCommand,
      AclSetCommand,
  ]
|
|
|
|
|
|
class AssetInfoCommand:
  """Prints metadata and other information about an Earth Engine asset."""

  name = 'info'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('asset_id', help='ID of the asset to print.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Fetches the asset's metadata and pretty-prints it, or fails."""
    config.ee_init()
    info = ee.data.getInfo(args.asset_id)
    # getInfo() returns a falsy value when the asset cannot be fetched.
    if not info:
      raise ee.EEException(
          'Asset does not exist or is not accessible: %s' % args.asset_id)
    _pretty_print_json(info)
|
|
|
|
|
|
class AssetSetCommand:
  """Sets metadata properties of an Earth Engine asset.

  Properties may be of type "string", "number", or "date". Dates must
  be specified in the form YYYY-MM-DD[Thh:mm:ss[.ff]] in UTC and are
  stored as numbers representing the number of milliseconds since the
  Unix epoch (00:00:00 UTC on 1 January 1970).

  To delete a property, set it to null without a type:
  prop=null.
  To set a property to the string value 'null', use the assignment
  (string)prop4=null.
  """

  name = 'set'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('asset_id', help='ID of the asset to update.')
    _add_property_flags(parser)

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Runs the asset update.

    Args:
      args: Parsed arguments with the asset id, properties, and time flags.
      config: Command line configuration.

    Raises:
      ee.EEException: if no properties or time flags were supplied.
    """
    config.ee_init()
    properties = _decode_property_flags(args)

    if not properties and args.time_start is None and args.time_end is None:
      raise ee.EEException('No properties specified.')

    # Every named property goes into the update mask, including ones mapped
    # to None (deletions): masked fields absent from `asset` are cleared.
    update_mask = [
        'properties.' + property_name for property_name in properties
    ]
    asset = {}
    if properties:
      asset['properties'] = {
          k: v for k, v in properties.items() if v is not None
      }
    # args.time_start and .time_end could have any of three falsy values, with
    # different meanings:
    #   None: the --time_start flag was not provided at all
    #   '': the --time_start flag was explicitly set to the empty string
    #   0: the --time_start flag was explicitly set to midnight 1 Jan 1970.
    # pylint:disable=g-explicit-bool-comparison
    if args.time_start is not None:
      update_mask.append('start_time')
      if args.time_start != '':
        asset['start_time'] = _cloud_timestamp_for_timestamp_ms(
            args.time_start)
    if args.time_end is not None:
      update_mask.append('end_time')
      if args.time_end != '':
        asset['end_time'] = _cloud_timestamp_for_timestamp_ms(args.time_end)
    # pylint:enable=g-explicit-bool-comparison
    ee.data.updateAsset(args.asset_id, asset, update_mask)
|
|
|
|
|
|
class AssetCommand(Dispatcher):
  """Prints or updates metadata associated with an Earth Engine asset."""

  name = 'asset'

  # Subcommands exposed as "asset info" and "asset set".
  COMMANDS = [
      AssetInfoCommand,
      AssetSetCommand,
  ]
|
|
|
|
|
|
class CopyCommand:
  """Creates a new Earth Engine asset as a copy of another asset."""

  name = 'cp'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('source', help='Full path of the source asset.')
    parser.add_argument(
        'destination', help='Full path of the destination asset.')
    _add_overwrite_arg(parser)

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Copies the source asset to the destination path."""
    config.ee_init()
    ee.data.copyAsset(args.source, args.destination, args.force)
|
|
|
|
|
|
class CreateCommandBase:
  """Base class for implementing Create subcommands."""

  def __init__(self, parser, fragment, asset_type):
    """Registers shared flags and records the asset type to create.

    Args:
      parser: The argparse parser for this subcommand.
      fragment: Human-readable description used in help text.
      asset_type: The ee.data asset-type constant to create.
    """
    parser.add_argument(
        'asset_id',
        nargs='+',
        help='Full path of %s to create.' % fragment)
    parser.add_argument(
        '--parents',
        '-p',
        action='store_true',
        help='Make parent folders as needed.')
    self.asset_type = asset_type

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Creates each requested asset of the configured type."""
    config.ee_init()
    ee.data.create_assets(args.asset_id, self.asset_type, args.parents)
|
|
|
|
|
|
class CreateCollectionCommand(CreateCommandBase):
  """Creates one or more image collections."""

  name = 'collection'

  def __init__(self, parser: argparse.ArgumentParser):
    # Delegates flag setup to the base class with the collection asset type.
    super().__init__(
        parser, 'an image collection', ee.data.ASSET_TYPE_IMAGE_COLL)
|
|
|
|
|
|
class CreateFolderCommand(CreateCommandBase):
  """Creates one or more folders."""

  name = 'folder'

  def __init__(self, parser: argparse.ArgumentParser):
    # Delegates flag setup to the base class with the folder asset type.
    super().__init__(parser, 'a folder', ee.data.ASSET_TYPE_FOLDER)
|
|
|
|
|
|
class CreateCommand(Dispatcher):
  """Creates assets and folders."""

  name = 'create'

  # Subcommands exposed as "create collection" and "create folder".
  COMMANDS = [
      CreateCollectionCommand,
      CreateFolderCommand,
  ]
|
|
|
|
|
|
class ListCommand:
  """Prints the contents of a folder or collection."""

  name = 'ls'

  def __init__(self, parser: argparse.ArgumentParser):
    """Registers the ls flags (asset ids, -l, -m, -r, and a filter)."""
    parser.add_argument(
        'asset_id', nargs='*',
        help='A folder or image collection to be inspected.')
    parser.add_argument(
        '--long_format',
        '-l',
        action='store_true',
        help='Print output in long format.')
    parser.add_argument(
        '--max_items', '-m', default=-1, type=int,
        help='Maximum number of items to list for each collection.')
    parser.add_argument(
        '--recursive',
        '-r',
        action='store_true',
        help='List folders recursively.')
    parser.add_argument(
        '--filter',
        '-f',
        default='',
        type=str,
        help=(
            'Filter string to use on a collection. Accepts property names'
            ' "start_time", "end_time", "update_time", and "properties.foo"'
            ' (where "foo" is any user-defined property). Example filter'
            ' strings: properties.SCENE_ID="ABC";'
            ' start_time>"2023-02-03T00:00:00+00:00"'
        ),
    )

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Runs the list command."""
    config.ee_init()
    # With no explicit assets, list the caller's asset roots instead.
    if not args.asset_id:
      roots = ee.data.getAssetRoots()
      self._print_assets(roots, args.max_items, '', args.long_format,
                         args.recursive)
      return
    assets = args.asset_id
    count = 0
    for asset in assets:
      # A blank line separates listings of multiple assets.
      if count > 0:
        print()
      self._list_asset_content(
          asset, args.max_items, len(assets), args.long_format,
          args.recursive, args.filter)
      count += 1

  def _print_assets(self, assets, max_items, indent, long_format, recursive):
    """Prints the listing of given assets.

    In long format the type column is padded to the widest type name;
    in recursive mode folders are expanded depth-first.
    """
    if not assets:
      return

    max_type_length = max([len(asset['type']) for asset in assets])

    if recursive:
      # fallback to max to include the string 'ImageCollection'
      max_type_length = ee.data.MAX_TYPE_LENGTH

    # Two-column layout: '[Type]' padded to max_type_length + 4, then id.
    format_str = '%s{:%ds}{:s}' % (indent, max_type_length + 4)
    for asset in assets:
      if long_format:
        # Example output:
        # [Image]           user/test/my_img
        # [ImageCollection] user/test/my_coll
        print(format_str.format('['+asset['type']+']', asset['id']))

      else:
        print(asset['id'])

      if recursive and asset['type'] in (ee.data.ASSET_TYPE_FOLDER,
                                         ee.data.ASSET_TYPE_FOLDER_CLOUD):
        list_req = {'id': asset['id']}
        children = ee.data.getList(list_req)
        self._print_assets(children, max_items, indent, long_format, recursive)

  def _list_asset_content(self, asset, max_items, total_assets, long_format,
                          recursive, filter_string):
    """Prints the contents of an asset and its children.

    Errors from the backend are printed rather than propagated so listing
    of the remaining assets can continue.
    """
    try:
      list_req = {'id': asset}
      # A negative max_items means "no limit"; only pass it when set.
      if max_items >= 0:
        list_req['num'] = max_items
      if filter_string:
        list_req['filter'] = filter_string
      children = ee.data.getList(list_req)
      indent = ''
      # When listing several assets, label and indent each listing.
      if total_assets > 1:
        print('%s:' % asset)
        indent = '  '
      self._print_assets(children, max_items, indent, long_format, recursive)
    except ee.EEException as e:
      print(e)
|
|
|
|
|
|
class SizeCommand:
  """Prints the size and names of all items in a given folder or collection."""

  name = 'du'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument(
        'asset_id',
        nargs='*',
        help='A folder or image collection to be inspected.')
    parser.add_argument(
        '--summarize', '-s', action='store_true',
        help='Display only a total.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Runs the du command."""
    config.ee_init()

    # Select all available asset roots if no asset ids are given.
    if not args.asset_id:
      assets = ee.data.getAssetRoots()
    else:
      assets = [ee.data.getInfo(asset) for asset in args.asset_id]

    # If args.summarize is True, list size+name for every leaf child asset,
    # and show totals for non-leaf children.
    # If args.summarize is False, print sizes of all children.
    for index, asset in enumerate(assets):
      # getInfo() returned a falsy value: the asset could not be fetched.
      if args.asset_id and not asset:
        asset_id = args.asset_id[index]
        print('Asset does not exist or is not accessible: %s' % asset_id)
        continue
      is_parent = asset['type'] in (
          ee.data.ASSET_TYPE_FOLDER,
          ee.data.ASSET_TYPE_IMAGE_COLL,
          ee.data.ASSET_TYPE_FOLDER_CLOUD,
          ee.data.ASSET_TYPE_IMAGE_COLL_CLOUD,
      )
      if not is_parent or args.summarize:
        self._print_size(asset)
      else:
        children = ee.data.getList({'id': asset['id']})
        if not children:
          # A leaf asset
          children = [asset]
        for child in children:
          self._print_size(child)

  def _print_size(self, asset):
    # Prints the asset's size in bytes (right-aligned) followed by its id.
    size = self._get_size(asset)
    print('{:>16d} {}'.format(size, asset['id']))

  def _get_size(self, asset):
    """Returns the size of the given asset in bytes."""
    # Dispatch on asset type; both CamelCase and UPPER_CASE spellings of
    # each type are accepted.
    size_parsers = {
        'Image': self._get_size_asset,
        'Folder': self._get_size_folder,
        'ImageCollection': self._get_size_image_collection,
        'Table': self._get_size_asset,
        'IMAGE': self._get_size_asset,
        'FOLDER': self._get_size_folder,
        'IMAGE_COLLECTION': self._get_size_image_collection,
        'TABLE': self._get_size_asset,
    }

    if asset['type'] not in size_parsers:
      raise ee.EEException(
          'Cannot get size for asset type "%s"' % asset['type'])

    return size_parsers[asset['type']](asset)

  def _get_size_asset(self, asset):
    # Leaf assets report size either as a top-level 'sizeBytes' field or as
    # the 'system:asset_size' property; prefer the former when present.
    info = ee.data.getInfo(asset['id'])

    if 'sizeBytes' in info:
      return int(info['sizeBytes'])
    return info['properties']['system:asset_size']

  def _get_size_folder(self, asset):
    # A folder's size is the (recursive) sum of its children's sizes.
    children = ee.data.getList({'id': asset['id']})
    sizes = [self._get_size(child) for child in children]

    return sum(sizes)

  def _get_size_image_collection(self, asset):
    # Aggregate the per-image sizes server-side instead of listing and
    # fetching each image individually.
    images = ee.ImageCollection(asset['id'])
    return images.aggregate_sum('system:asset_size').getInfo()
|
|
|
|
|
|
class MoveCommand:
  """Moves or renames an Earth Engine asset."""

  name = 'mv'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('source', help='Full path of the source asset.')
    parser.add_argument(
        'destination', help='Full path of the destination asset.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Renames args.source to args.destination after client init."""
    config.ee_init()
    ee.data.renameAsset(args.source, args.destination)
|
|
|
|
|
|
class RmCommand:
  """Deletes the specified assets."""

  name = 'rm'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument(
        'asset_id', nargs='+', help='Full path of an asset to delete.')
    parser.add_argument(
        '--recursive', '-r', action='store_true',
        help='Recursively delete child assets.')
    parser.add_argument(
        '--dry_run', action='store_true',
        help=('Perform a dry run of the delete operation. Does not '
              'delete any assets.'))
    parser.add_argument(
        '--verbose', '-v', action='store_true',
        help='Print the progress of the operation to the console.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Deletes each requested asset, recursing into containers if asked."""
    config.ee_init()
    for target in args.asset_id:
      self._delete_asset(target, args.recursive, args.verbose, args.dry_run)

  def _delete_asset(self, asset_id, recursive, verbose, dry_run):
    """Attempts to delete the specified asset or asset collection."""
    if recursive:
      info = ee.data.getInfo(asset_id)
      if info is None:
        print('Asset does not exist or is not accessible: %s' % asset_id)
        return
      container_types = (
          ee.data.ASSET_TYPE_FOLDER,
          ee.data.ASSET_TYPE_IMAGE_COLL,
          ee.data.ASSET_TYPE_FOLDER_CLOUD,
          ee.data.ASSET_TYPE_IMAGE_COLL_CLOUD,
      )
      # Children must be removed before their parent can be deleted.
      if info['type'] in container_types:
        for child in ee.data.getList({'id': asset_id}):
          self._delete_asset(child['id'], True, verbose, dry_run)

    if dry_run:
      print('[dry-run] Deleting asset: %s' % asset_id)
      return

    if verbose:
      print('Deleting asset: %s' % asset_id)
    try:
      ee.data.deleteAsset(asset_id)
    except ee.EEException as e:
      # Report and continue: one failed delete should not abort the rest.
      print('Failed to delete %s. %s' % (asset_id, e))
|
|
|
|
|
|
class TaskCancelCommand:
  """Cancels a running task."""

  name = 'cancel'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument(
        'task_ids', nargs='+',
        help='IDs of one or more tasks to cancel,'
             ' or `all` to cancel all tasks.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Cancels a running task."""
    config.ee_init()
    cancel_all = args.task_ids == ['all']
    statuses = (
        ee.data.getTaskList()
        if cancel_all else ee.data.getTaskStatus(args.task_ids))
    for status in statuses:
      task_id = status['id']
      state = status['state']
      if state == 'UNKNOWN':
        raise ee.EEException('Unknown task id "%s"' % task_id)
      if state in ('READY', 'RUNNING'):
        print('Canceling task "%s"' % task_id)
        ee.data.cancelTask(task_id)
      elif not cancel_all:
        # Only complain about finished tasks when ids were given explicitly.
        print('Task "%s" already in state "%s".' % (task_id, state))
|
|
|
|
|
|
class TaskInfoCommand:
  """Prints information about a task."""

  name = 'info'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument('task_id', nargs='*', help='ID of a task to get.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Runs the TaskInfo command.

    Prints one human-readable status block per requested task id.

    Args:
      args: Parsed command-line arguments; `task_id` lists the ids to query.
      config: CLI configuration used to initialize the Earth Engine client.
    """
    config.ee_init()
    for i, status in enumerate(ee.data.getTaskStatus(args.task_id)):
      if i:
        # Blank line between consecutive task reports.
        print()
      print('%s:' % status['id'])
      print(' State: %s' % status['state'])
      if status['state'] == 'UNKNOWN':
        # Nothing more is known about an unrecognized task id.
        continue
      print(' Type: %s' % TASK_TYPES.get(status.get('task_type'), 'Unknown'))
      print(' Description: %s' % status.get('description'))
      print(' Created: %s' % _parse_millis(status['creation_timestamp_ms']))
      # The remaining fields are optional and printed only when present.
      if 'start_timestamp_ms' in status:
        print(' Started: %s' % _parse_millis(status['start_timestamp_ms']))
      if 'update_timestamp_ms' in status:
        print(' Updated: %s' % _parse_millis(status['update_timestamp_ms']))
      if 'error_message' in status:
        print(' Error: %s' % status['error_message'])
      if 'destination_uris' in status:
        print(' Destination URIs: %s' % ', '.join(status['destination_uris']))
|
|
|
|
|
|
class TaskListCommand:
  """Lists the tasks submitted recently."""

  name = 'list'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument(
        '--status', '-s', required=False, nargs='*',
        choices=['READY', 'RUNNING', 'COMPLETED', 'FAILED',
                 'CANCELLED', 'UNKNOWN'],
        help=('List tasks only with a given status'))
    parser.add_argument(
        '--long_format',
        '-l',
        action='store_true',
        help=('Print output in long format. Extra columns are: creation time, '
              'start time, update time, EECU-seconds, output URLs.')
    )

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Lists tasks present for a user, maybe filtering by state."""
    config.ee_init()
    status = args.status
    tasks = ee.data.getTaskList()
    # Size the description column to the longest (truncated) description.
    # NOTE(review): the width is computed over *all* tasks, including ones
    # later filtered out by --status — presumably intentional; confirm.
    descs = [utils.truncate(task.get('description', ''), 40) for task in tasks]
    desc_length = max((len(word) for word in descs), default=0)
    format_str = '{:25s} {:13s} {:%ds} {:10s} {:s}' % (desc_length + 1)
    for task in tasks:
      if status and task['state'] not in status:
        continue
      truncated_desc = utils.truncate(task.get('description', ''), 40)
      task_type = TASK_TYPES.get(task['task_type'], 'Unknown')
      extra = ''
      if args.long_format:
        # Millisecond timestamps rendered as 'YYYY-mm-dd HH:MM:SS'.
        show_date = lambda ms: _parse_millis(ms).strftime('%Y-%m-%d %H:%M:%S')
        # EECU-seconds are only present on some tasks; '-' marks absence.
        eecu = '{:.4f}'.format(
            task['batch_eecu_usage_seconds']
        ) if 'batch_eecu_usage_seconds' in task else '-'
        extra = ' {:20s} {:20s} {:20s} {:11s} {}'.format(
            show_date(task['creation_timestamp_ms']),
            show_date(task['start_timestamp_ms']),
            show_date(task['update_timestamp_ms']),
            eecu,
            ' '.join(task.get('destination_uris', [])))
      print(format_str.format(
          task['id'], task_type, truncated_desc,
          task['state'], task.get('error_message', '---')) + extra)
|
|
|
|
|
|
class TaskWaitCommand:
  """Waits for the specified task or tasks to complete."""

  name = 'wait'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument(
        '--timeout', '-t', default=sys.maxsize, type=int,
        help=('Stop waiting for the task(s) to finish after the specified,'
              ' number of seconds. Without this flag, the command will wait'
              ' indefinitely.'))
    parser.add_argument('--verbose', '-v', action='store_true',
                        help=('Print periodic status messages for each'
                              ' incomplete task.'))
    parser.add_argument('task_ids', nargs='+',
                        help=('Either a list of one or more currently-running'
                              ' task ids to wait on; or \'all\' to wait on all'
                              ' running tasks.'))

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Waits on the given tasks to complete or for a timeout to pass."""
    config.ee_init()
    if args.task_ids == ['all']:
      # Wait on every task that has not yet reached a terminal state.
      pending = [
          task['id'] for task in ee.data.getTaskList()
          if task['state'] not in utils.TASK_FINISHED_STATES
      ]
    else:
      pending = []
      for status in ee.data.getTaskStatus(args.task_ids):
        if status['state'] == 'UNKNOWN':
          raise ee.EEException('Unknown task id "%s"' % status['id'])
        pending.append(status['id'])

    utils.wait_for_tasks(pending, args.timeout, log_progress=args.verbose)
|
|
|
|
|
|
class TaskCommand(Dispatcher):
  """Prints information about or manages long-running tasks."""

  name = 'task'

  # Subcommands dispatched under `earthengine task <subcommand>`.
  COMMANDS = [
      TaskCancelCommand,
      TaskInfoCommand,
      TaskListCommand,
      TaskWaitCommand,
  ]
|
|
|
|
|
|
# TODO(user): in both upload tasks, check if the parent namespace
# exists and is writeable first.
class UploadImageCommand:
  """Uploads an image from Cloud Storage to Earth Engine.

  See docs for "asset set" for additional details on how to specify asset
  metadata properties.
  """

  name = 'image'

  def __init__(self, parser: argparse.ArgumentParser):
    _add_wait_arg(parser)
    _add_overwrite_arg(parser)
    parser.add_argument(
        'src_files',
        help=('Cloud Storage URL(s) of the file(s) to upload. '
              'Must have the prefix \'gs://\'.'),
        nargs='*')
    parser.add_argument(
        '--asset_id',
        help='Destination asset ID for the uploaded file.')
    parser.add_argument(
        '--last_band_alpha',
        help='Use the last band as a masking channel for all bands. '
             'Mutually exclusive with nodata_value.',
        action='store_true')
    parser.add_argument(
        '--nodata_value',
        help='Value for missing data. '
             'Mutually exclusive with last_band_alpha.',
        type=_comma_separated_numbers)
    parser.add_argument(
        '--pyramiding_policy',
        help='The pyramid reduction policy to use',
        type=_comma_separated_pyramiding_policies)
    parser.add_argument(
        '--bands',
        help='Comma-separated list of names to use for the image bands.',
        type=_comma_separated_strings)
    parser.add_argument(
        '--crs',
        help='The coordinate reference system, to override the map projection '
             'of the image. May be either a well-known authority code (e.g. '
             'EPSG:4326) or a WKT string.')
    parser.add_argument(
        '--manifest',
        help='Local path to a JSON asset manifest file. No other flags are '
             'used if this flag is set.')
    _add_property_flags(parser)

  def _check_num_bands(self, bands, num_bands, flag_name):
    """Checks the number of bands, creating them if there are none yet.

    Args:
      bands: Band names given via --bands, or a falsy value if absent.
      num_bands: The band count implied by the per-band flag `flag_name`.
      flag_name: Flag name used in the error message.

    Returns:
      The validated band-name list, synthesizing 'b1'..'bN' when none given.

    Raises:
      ValueError: If the supplied band names disagree with num_bands.
    """
    if bands:
      if len(bands) != num_bands:
        raise ValueError(
            'Inconsistent number of bands in --{}: expected {} but found {}.'
            .format(flag_name, len(bands), num_bands))
    else:
      bands = ['b%d' % (i + 1) for i in range(num_bands)]
    return bands

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Starts the upload task, and waits for completion if requested."""
    config.ee_init()
    manifest = self.manifest_from_args(args)
    _upload(args, manifest, ee.data.startIngestion)

  def manifest_from_args(self, args):
    """Constructs an upload manifest from the command-line flags.

    Returns:
      An ingestion-manifest dict, either read verbatim from --manifest or
      assembled from the individual flags.

    Raises:
      ValueError: If required flags are missing, mutually exclusive flags
        are both set, or band counts are inconsistent.
    """

    def is_tf_record(path):
      # Extension-based check only (see TF_RECORD_EXTENSIONS).
      if any(path.lower().endswith(extension)
             for extension in TF_RECORD_EXTENSIONS):
        return True
      return False

    # An explicit manifest file overrides every other flag.
    if args.manifest:
      with open(args.manifest) as fh:
        return json.loads(fh.read())

    if not args.asset_id:
      raise ValueError('Flag --asset_id must be set.')

    _check_valid_files(args.src_files)
    if args.last_band_alpha and args.nodata_value:
      raise ValueError(
          'last_band_alpha and nodata_value are mutually exclusive.')

    properties = _decode_property_flags(args)
    source_files = list(utils.expand_gcs_wildcards(args.src_files))
    if not source_files:
      raise ValueError('At least one file must be specified.')

    bands = args.bands
    # A per-band flag with more than one value fixes the band count.
    if args.pyramiding_policy and len(args.pyramiding_policy) != 1:
      bands = self._check_num_bands(bands, len(args.pyramiding_policy),
                                    'pyramiding_policy')
    if args.nodata_value and len(args.nodata_value) != 1:
      bands = self._check_num_bands(bands, len(args.nodata_value),
                                    'nodata_value')

    args.asset_id = ee.data.convert_asset_id_to_asset_name(args.asset_id)
    # If we are ingesting a tfrecord, we actually treat the inputs as one
    # source and many uris.
    if any(is_tf_record(source) for source in source_files):
      tileset = {'id': 'ts', 'sources': [{'uris': list(source_files)}]}
    else:
      tileset = {
          'id': 'ts',
          'sources': [{'uris': [source]} for source in source_files]
      }
    if args.crs:
      tileset['crs'] = args.crs
    manifest = {
        'name': args.asset_id,
        'properties': properties,
        'tilesets': [tileset]
    }
    # pylint:disable=g-explicit-bool-comparison
    if args.time_start is not None and args.time_start != '':
      manifest['start_time'] = _cloud_timestamp_for_timestamp_ms(
          args.time_start)
    if args.time_end is not None and args.time_end != '':
      manifest['end_time'] = _cloud_timestamp_for_timestamp_ms(args.time_end)
    # pylint:enable=g-explicit-bool-comparison

    if bands:
      file_bands = []
      for i, band in enumerate(bands):
        file_bands.append({
            'id': band,
            'tilesetId': tileset['id'],
            'tilesetBandIndex': i
        })
      manifest['bands'] = file_bands

    if args.pyramiding_policy:
      if len(args.pyramiding_policy) == 1:
        # A single policy applies to the whole asset.
        manifest['pyramidingPolicy'] = args.pyramiding_policy[0]
      else:
        # Multiple policies apply per band; counts were validated above.
        for index, policy in enumerate(args.pyramiding_policy):
          file_bands[index]['pyramidingPolicy'] = policy

    if args.nodata_value:
      if len(args.nodata_value) == 1:
        manifest['missingData'] = {'values': [args.nodata_value[0]]}
      else:
        for index, value in enumerate(args.nodata_value):
          file_bands[index]['missingData'] = {'values': [value]}

    if args.last_band_alpha:
      # The trailing band masks every band of the tileset.
      manifest['maskBands'] = {'tilesetId': tileset['id']}

    return manifest
|
|
|
|
|
|
# TODO(user): update src_files help string when secondary files
# can be uploaded.
class UploadTableCommand:
  """Uploads a table from Cloud Storage to Earth Engine."""

  name = 'table'

  def __init__(self, parser: argparse.ArgumentParser):
    _add_wait_arg(parser)
    _add_overwrite_arg(parser)
    parser.add_argument(
        'src_file',
        help=('Cloud Storage URL of the .csv, .tfrecord, .shp, or '
              '.zip file to upload. Must have the prefix \'gs://\'. For '
              '.shp files, related .dbf, .shx, and .prj files must be '
              'present in the same location.'),
        nargs='*')
    parser.add_argument(
        '--asset_id',
        help='Destination asset ID for the uploaded file.')
    _add_property_flags(parser)
    parser.add_argument(
        '--charset',
        help=(
            'The name of the charset to use for decoding strings. If not '
            'given, the charset "UTF-8" is assumed by default.'
        ),
        type=str,
        nargs='?',
    )
    parser.add_argument(
        '--max_error',
        help='Max allowed error in meters when transforming geometry '
             'between coordinate systems.',
        type=float, nargs='?')
    parser.add_argument(
        '--max_vertices',
        help='Max number of vertices per geometry. If set, geometry will be '
             'subdivided into spatially disjoint pieces each under this limit.',
        type=int, nargs='?')
    parser.add_argument(
        '--max_failed_features',
        help='The maximum number of failed features to allow during ingestion.',
        type=int, nargs='?')
    parser.add_argument(
        '--crs',
        help='The default CRS code or WKT string specifying the coordinate '
             'reference system of any geometry without one. If unspecified, '
             'the default will be EPSG:4326 (https://epsg.io/4326). For '
             'CSV/TFRecord only.')
    parser.add_argument(
        '--geodesic',
        help='The default strategy for interpreting edges in geometries that '
             'do not have one specified. If false, edges are '
             'straight in the projection. If true, edges are curved to follow '
             'the shortest path on the surface of the Earth. When '
             'unspecified, defaults to false if \'crs\' is a projected '
             'coordinate system. For CSV/TFRecord only.',
        action='store_true')
    parser.add_argument(
        '--primary_geometry_column',
        help='The geometry column to use as a row\'s primary geometry when '
             'there is more than one geometry column. If unspecified and more '
             'than one geometry column exists, the first geometry column '
             'is used. For CSV/TFRecord only.')
    parser.add_argument(
        '--x_column',
        help='The name of the numeric x coordinate column for constructing '
             'point geometries. If the y_column is also specified, and both '
             'columns contain numerical values, then a point geometry column '
             'will be constructed with x,y values in the coordinate system '
             'given in \'--crs\'. If unspecified and \'--crs\' does _not_ '
             'specify a projected coordinate system, defaults to "longitude". '
             'If unspecified and \'--crs\' _does_ specify a projected '
             'coordinate system, defaults to "" and no point geometry is '
             'generated. A generated point geometry column will be named '
             '{x_column}_{y_column}_N where N might be appended to '
             'disambiguate the column name. For CSV/TFRecord only.')
    parser.add_argument(
        '--y_column',
        help='The name of the numeric y coordinate column for constructing '
             'point geometries. If the x_column is also specified, and both '
             'columns contain numerical values, then a point geometry column '
             'will be constructed with x,y values in the coordinate system '
             'given in \'--crs\'. If unspecified and \'--crs\' does _not_ '
             'specify a projected coordinate system, defaults to "latitude". '
             'If unspecified and \'--crs\' _does_ specify a projected '
             'coordinate system, defaults to "" and no point geometry is '
             'generated. A generated point geometry column will be named '
             '{x_column}_{y_column}_N where N might be appended to '
             'disambiguate the column name. For CSV/TFRecord only.')
    # pylint: disable=line-too-long
    parser.add_argument(
        '--date_format',
        help='A format used to parse dates. The format pattern must follow '
             'http://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html. '
             'If unspecified, dates will be imported as strings. For '
             'CSV/TFRecord only.')
    # pylint: enable=line-too-long
    parser.add_argument(
        '--csv_delimiter',
        help='A single character used as a delimiter between column values '
             'in a row. If unspecified, defaults to \',\'. For CSV only.')
    parser.add_argument(
        '--csv_qualifier',
        help='A character that surrounds column values (a.k.a. '
             '\'quote character\'). If unspecified, defaults to \'"\'. A '
             'column value may include the qualifier as a literal character by '
             'having 2 consecutive qualifier characters. For CSV only.')
    parser.add_argument(
        '--manifest',
        help='Local path to a JSON asset manifest file. No other flags are '
             'used if this flag is set.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Starts the upload task, and waits for completion if requested."""
    config.ee_init()
    manifest = self.manifest_from_args(args)
    _upload(args, manifest, ee.data.startTableIngestion)

  def manifest_from_args(self, args):
    """Constructs an upload manifest from the command-line flags.

    Returns:
      A table-ingestion manifest dict, either read verbatim from --manifest
      or assembled from the individual flags.

    Raises:
      ValueError: If required flags are missing or more than one source
        file was given.
      ee.EEException: If --max_failed_features is set (unsupported).
    """

    # An explicit manifest file overrides every other flag.
    if args.manifest:
      with open(args.manifest) as fh:
        return json.loads(fh.read())

    if not args.asset_id:
      raise ValueError('Flag --asset_id must be set.')

    _check_valid_files(args.src_file)
    source_files = list(utils.expand_gcs_wildcards(args.src_file))
    if len(source_files) != 1:
      raise ValueError('Exactly one file must be specified.')

    properties = _decode_property_flags(args)
    args.asset_id = ee.data.convert_asset_id_to_asset_name(args.asset_id)
    # Optional source fields are added only when the flag was supplied.
    source = {'uris': source_files}
    if args.charset:
      source['charset'] = args.charset
    if args.max_error:
      source['maxErrorMeters'] = args.max_error
    if args.max_vertices:
      source['maxVertices'] = args.max_vertices
    if args.max_failed_features:
      raise ee.EEException(
          '--max_failed_features is not supported with the Cloud API')
    if args.crs:
      source['crs'] = args.crs
    if args.geodesic:
      source['geodesic'] = args.geodesic
    if args.primary_geometry_column:
      source['primary_geometry_column'] = args.primary_geometry_column
    if args.x_column:
      source['x_column'] = args.x_column
    if args.y_column:
      source['y_column'] = args.y_column
    if args.date_format:
      source['date_format'] = args.date_format
    if args.csv_delimiter:
      source['csv_delimiter'] = args.csv_delimiter
    if args.csv_qualifier:
      source['csv_qualifier'] = args.csv_qualifier

    manifest = {
        'name': args.asset_id,
        'sources': [source],
        'properties': properties
    }

    # pylint:disable=g-explicit-bool-comparison
    if args.time_start is not None and args.time_start != '':
      manifest['start_time'] = _cloud_timestamp_for_timestamp_ms(
          args.time_start)
    if args.time_end is not None and args.time_end != '':
      manifest['end_time'] = _cloud_timestamp_for_timestamp_ms(args.time_end)
    # pylint:enable=g-explicit-bool-comparison
    return manifest
|
|
|
|
|
|
class UploadCommand(Dispatcher):
  """Uploads assets to Earth Engine."""

  name = 'upload'

  # Subcommands dispatched under `earthengine upload <subcommand>`.
  COMMANDS = [
      UploadImageCommand,
      UploadTableCommand,
  ]
|
|
|
|
|
|
class _UploadManifestBase:
  """Uploads an asset to Earth Engine using the given manifest file."""

  def __init__(self, parser: argparse.ArgumentParser):
    _add_wait_arg(parser)
    _add_overwrite_arg(parser)
    parser.add_argument(
        'manifest',
        help=('Local path to a JSON asset manifest file.'))

  def run(self, args: argparse.Namespace, config, ingestion_function) -> None:
    """Starts the upload task, and waits for completion if requested."""
    config.ee_init()
    # Parse the local manifest file and hand it to the ingestion endpoint.
    with open(args.manifest) as fh:
      parsed_manifest = json.load(fh)

    _upload(args, parsed_manifest, ingestion_function)
|
|
|
|
|
|
class UploadImageManifestCommand(_UploadManifestBase):
  """Uploads an image to Earth Engine using the given manifest file."""

  name = 'upload_manifest'

  # pytype: disable=signature-mismatch
  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    # pytype: enable=signature-mismatch
    """Starts the upload task, and waits for completion if requested."""
    # Deprecation notice: the `upload image --manifest` flag supersedes this.
    print('This command is deprecated. '
          'Use "earthengine upload image --manifest".')
    super().run(args, config, ee.data.startIngestion)
|
|
|
|
|
|
class UploadTableManifestCommand(_UploadManifestBase):
  """Uploads a table to Earth Engine using the given manifest file."""

  name = 'upload_table_manifest'

  # pytype: disable=signature-mismatch
  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    # pytype: enable=signature-mismatch
    """Prints a deprecation notice, then delegates to the base upload."""
    print('This command is deprecated. '
          'Use "earthengine upload table --manifest".')
    super().run(args, config, ee.data.startTableIngestion)
|
|
|
|
|
|
def _get_nodes(node_spec, source_flag_name):
|
|
"""Extract a node mapping from a list or flag-specified JSON."""
|
|
try:
|
|
spec = json.loads(node_spec)
|
|
except ValueError:
|
|
spec = [n.strip() for n in node_spec.split(',')]
|
|
return {item: item for item in spec}
|
|
|
|
if not isinstance(spec, dict):
|
|
raise ValueError(
|
|
'If flag {} is JSON it must specify a dictionary.'.format(
|
|
source_flag_name))
|
|
|
|
for k, v in spec.items():
|
|
if (not isinstance(k, str) or not isinstance(v, str)):
|
|
raise ValueError(
|
|
'All key/value pairs of the dictionary specified in ' +
|
|
f'{source_flag_name} must be strings.')
|
|
|
|
return spec
|
|
|
|
|
|
def _validate_and_extract_nodes(args):
  """Validate command line args and extract in/out node mappings.

  Args:
    args: Parsed flags carrying source_dir, dest_dir, input and output.

  Returns:
    A (input_nodes, output_nodes) tuple of name-mapping dicts.

  Raises:
    ValueError: If any of the four required flags is unset.
  """
  # All four flags are mandatory; check them in the documented order so the
  # first missing one is reported.
  required = (
      (args.source_dir, '--source_dir'),
      (args.dest_dir, '--dest_dir'),
      (args.input, '--input'),
      (args.output, '--output'),
  )
  for value, flag in required:
    if not value:
      raise ValueError('Flag {} must be set.'.format(flag))

  return (_get_nodes(args.input, '--input'),
          _get_nodes(args.output, '--output'))
|
|
|
|
|
|
def _encode_op(output_tensor, name):
  """Wraps each element of a model output as a base64 serialized tensor.

  Every element of `output_tensor` is serialized with tf.serialize_tensor,
  base64-encoded, and the resulting string tensor is given the op name
  `name` (via tf.identity) so it can be located by name later.
  """
  return tf.identity(
      tf.map_fn(lambda x: tf.io.encode_base64(tf.serialize_tensor(x)),
                output_tensor, tf.string),
      name=name)
|
|
|
|
|
|
def _decode_op(input_tensor, dtype):
  """Inverse of _encode_op: parses base64 serialized-tensor strings.

  Args:
    input_tensor: String tensor whose elements are base64-encoded
      serialized tensors.
    dtype: Element dtype each serialized tensor is parsed as.

  Returns:
    A tensor of `dtype` with one parsed tensor per input element.
  """
  mapped = tf.map_fn(lambda x: tf.parse_tensor(tf.io.decode_base64(x), dtype),
                     input_tensor, dtype)
  return mapped
|
|
|
|
|
|
def _strip_index(edge_name):
|
|
colon_pos = edge_name.rfind(':')
|
|
if colon_pos == -1:
|
|
return edge_name
|
|
else:
|
|
return edge_name[:colon_pos]
|
|
|
|
|
|
def _get_input_tensor_spec(graph_def, input_names_set):
  """Extracts the types of the given node names from the GraphDef.

  Args:
    graph_def: The GraphDef whose nodes are scanned.
    input_names_set: Iterable of input tensor names, optionally carrying a
      ':<index>' suffix.

  Returns:
    Dict mapping each original input name to its tf.dtypes.DType.

  Raises:
    ValueError: If a named node lacks shape/dtype attributes (not a valid
      graph input) or is missing from the graph entirely.
  """

  # Get the op names stripped of the input index, e.g. "op:0" becomes "op".
  input_names_missing_index = {_strip_index(i): i for i in input_names_set}

  spec = {}
  for cur_node in graph_def.node:
    if cur_node.name in input_names_missing_index:
      # A usable graph input must declare both its shape and dtype.
      if 'shape' not in cur_node.attr or 'dtype' not in cur_node.attr:
        raise ValueError(
            'Specified input op is not a valid graph input: \'{}\'.'.format(
                cur_node.name))

      # Key the result by the original (index-carrying) name.
      spec[input_names_missing_index[cur_node.name]] = tf.dtypes.DType(
          cur_node.attr['dtype'].type)

  # Any requested input not seen above was absent from the graph.
  if len(spec) != len(input_names_set):
    raise ValueError(
        'Specified input ops were missing from graph: {}.'.format(
            list(set(input_names_set).difference(list(spec.keys())))))
  return spec
|
|
|
|
|
|
def _make_rpc_friendly(model_dir, tag, in_map, out_map, vars_path):
  """Wraps a SavedModel in EE RPC-friendly ops and saves a temporary copy.

  Args:
    model_dir: Directory containing the source SavedModel.
    tag: MetaGraph tag to load from the SavedModel.
    in_map: Dict mapping existing input tensor names to new EE-facing names.
    out_map: Dict mapping existing output tensor names to new EE-facing names.
    vars_path: Relative path (appended to model_dir) to the variables prefix.

  Returns:
    Path of the temporary directory holding the rewritten SavedModel.
  """
  out_dir = tempfile.mkdtemp()
  builder = tf.saved_model.Builder(out_dir)

  # Get a GraphDef from the saved model
  with tf.Session() as sesh:
    meta_graph = tf.saved_model.load(sesh, [tag], model_dir)

  graph_def = meta_graph.graph_def

  # Purge the default graph immediately after: we want to remap parts of the
  # graph when we load it and we don't know what those parts are yet.
  tf.reset_default_graph()

  input_op_keys = list(in_map.keys())
  input_new_keys = list(in_map.values())

  # Get the shape and type of the input tensors
  in_op_types = _get_input_tensor_spec(graph_def, input_op_keys)

  # Create new input placeholders to receive RPC TensorProto payloads
  in_op_map = {
      k: tf.placeholder(
          tf.string, shape=[None], name='earthengine_in_{}'.format(i))
      for (i, k) in enumerate(input_new_keys)
  }

  # Glue on decoding ops to remap to the imported graph.
  decoded_op_map = {
      k: _decode_op(in_op_map[in_map[k]], in_op_types[k])
      for k in input_op_keys
  }

  # Okay now we're ready to import the graph again but remapped.
  saver = tf.train.import_meta_graph(
      meta_graph_or_file=meta_graph, input_map=decoded_op_map)

  # Boilerplate to build a signature def for our new graph
  sig_in = {
      _strip_index(k):
          saved_model_utils.build_tensor_info(v) for (k, v) in in_op_map.items()
  }

  # Wrap each requested output in an encoding op and record it under its
  # new (index-stripped) name.
  sig_out = {}
  for index, (k, v) in enumerate(out_map.items()):
    out_tensor = saved_model_utils.build_tensor_info(
        _encode_op(
            tf.get_default_graph().get_tensor_by_name(k),
            name='earthengine_out_{}'.format(index)))

    sig_out[_strip_index(v)] = out_tensor

  sig_def = signature_def_utils.build_signature_def(
      sig_in, sig_out, signature_constants.PREDICT_METHOD_NAME)

  # Open a new session to load the variables and add them to the builder.
  with tf.Session() as sesh:
    if saver:
      saver.restore(sesh, model_dir + vars_path)
    builder.add_meta_graph_and_variables(
        sesh,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
        },
        saver=saver)

  builder.save()
  return out_dir
|
|
|
|
|
|
class PrepareModelCommand:
  """Prepares a TensorFlow/Keras SavedModel for inference with Earth Engine.

  This is required only if a model is manually uploaded to Cloud AI Platform
  (https://cloud.google.com/ai-platform/) for predictions.
  """

  name = 'prepare'

  def __init__(self, parser: argparse.ArgumentParser):
    parser.add_argument(
        '--source_dir',
        help='The local or Cloud Storage path to directory containing the '
             'SavedModel.')
    parser.add_argument(
        '--dest_dir',
        help='The name of the directory to be created locally or in Cloud '
             'Storage that will contain the Earth Engine ready SavedModel.')
    parser.add_argument(
        '--input',
        help='A comma-delimited list of input node names that will map to '
             'Earth Engine Feature columns or Image bands for prediction, or a JSON '
             'dictionary specifying a remapping of input node names to names '
             'mapping to Feature columns or Image bands etc... (e.x: '
             '\'{"Conv2D:0":"my_landsat_band"}\'). The names of model inputs will '
             'be stripped of any trailing \'<:prefix>\'.')
    parser.add_argument(
        '--output',
        help='A comma-delimited list of output tensor names that will map to '
             'Earth Engine Feature columns or Image bands for prediction, or a JSON '
             'dictionary specifying a remapping of output node names to names '
             'mapping to Feature columns or Image bands etc... (e.x: '
             '\'{"Sigmoid:0":"my_predicted_class"}\'). The names of model outputs '
             'will be stripped of any trailing \'<:prefix>\'.')
    parser.add_argument(
        '--tag',
        help='An optional tag used to load a specific graph from the '
             'SavedModel. Defaults to \'serve\'.')
    parser.add_argument(
        '--variables',
        help='An optional relative path from within the source directory to '
             'the prefix of the model variables. (e.x: if the model variables are '
             'stored under \'model_dir/variables/x.*\', set '
             '--variables=/variables/x). Defaults to \'/variables/variables\'.')

  def run(
      self, args: argparse.Namespace, config: utils.CommandLineConfig
  ) -> None:
    """Wraps a SavedModel in EE RPC-friendly ops and saves a copy of it."""
    # Fails fast with a helpful message when TF is not importable.
    check_tensorflow_installed()

    in_spec, out_spec = _validate_and_extract_nodes(args)
    gcs_client = None

    if utils.is_gcs_path(args.source_dir):
      # If the model isn't locally available, we have to make it available...
      gcs_client = config.create_gcs_helper()
      gcs_client.check_gcs_dir_within_size(args.source_dir,
                                           SAVED_MODEL_MAX_SIZE)
      local_model_dir = gcs_client.download_dir_to_temp(args.source_dir)
    else:
      local_model_dir = args.source_dir

    tag = args.tag if args.tag else tf.saved_model.tag_constants.SERVING
    vars_path = args.variables if args.variables else DEFAULT_VARIABLES_PREFIX
    new_model_dir = _make_rpc_friendly(
        local_model_dir, tag, in_spec, out_spec, vars_path)

    # Deliver the rewritten model to GCS or to the local filesystem.
    if utils.is_gcs_path(args.dest_dir):
      if not gcs_client:
        gcs_client = config.create_gcs_helper()
      gcs_client.upload_dir_to_bucket(new_model_dir, args.dest_dir)
    else:
      shutil.move(new_model_dir, args.dest_dir)

    print(
        'Success: model at \'{}\' is ready to be hosted in AI Platform.'.format(
            args.dest_dir))
|
|
|
|
|
|
def check_tensorflow_installed():
  """Checks the status of TensorFlow installations.

  Raises:
    ImportError: If TensorFlow could not be imported at module load time.
  """
  if not TENSORFLOW_INSTALLED:
    raise ImportError(
        'By default, TensorFlow is not installed with Earth Engine client '
        'libraries. To use \'model\' commands, make sure at least TensorFlow '
        '1.14 is installed; you can do this by executing \'pip install '
        'tensorflow\' in your shell.'
    )
  # TF itself is present; only warn about the optional Addons package.
  if not TENSORFLOW_ADDONS_INSTALLED:
    print(
        'Warning: TensorFlow Addons not found. Models that use '
        'non-standard ops may not work.')
|
|
|
|
|
|
class ModelCommand(Dispatcher):
  """TensorFlow model related commands."""

  name = 'model'

  # Subcommands dispatched under `earthengine model <subcommand>`.
  COMMANDS = [PrepareModelCommand]
|
|
|
|
# Top-level commands exposed by the CLI. Classes not visible in this section
# (AuthenticateCommand, AclCommand, etc.) are defined earlier in this file.
EXTERNAL_COMMANDS = [
    AuthenticateCommand,
    AclCommand,
    AssetCommand,
    CopyCommand,
    CreateCommand,
    ListCommand,
    SizeCommand,
    MoveCommand,
    ModelCommand,
    RmCommand,
    SetProjectCommand,
    TaskCommand,
    UnSetProjectCommand,
    UploadCommand,
    UploadImageManifestCommand,
    UploadTableManifestCommand,
]
|