2024-07-30 13:15:39 -07:00

101 lines
2.8 KiB
Python

"""A wrapper for Models."""
from typing import Any, Union
from ee import apifunction
from ee import computedobject
from ee import featurecollection
from ee import image
# Argument type accepted for the collection given to Model.predictProperties.
# `Any` is one of the union members, so static type checkers accept any value;
# the concrete Earth Engine types are listed to document the intended inputs.
_FeatureCollectionType = Union[
    Any,
    featurecollection.FeatureCollection,
    computedobject.ComputedObject,
]
# Argument type accepted for the image given to Model.predictImage.
_ImageType = Union[
    Any,
    image.Image,
    computedobject.ComputedObject,
]
class Model(computedobject.ComputedObject):
  """Client-side wrapper around an Earth Engine Model.

  Example:
    model = ee.Model.fromVertexAi(
        endpoint='endpoint-name',
        inputTileSize=[8, 8],
        outputBands={
            'probability': {'type': ee.PixelType.float(), 'dimensions': 1}
        },
    )

  For more information, see:
    - https://developers.google.com/earth-engine/guides/machine-learning
    - https://developers.google.com/earth-engine/guides/tensorflow-vertex
  """

  # Becomes True once the server-side API functions have been imported onto
  # the class by initialize(); reset() sets it back to False.
  _initialized: bool = False

  def __init__(self, model: computedobject.ComputedObject):
    """Casts an existing computed object to a Model.

    Args:
      model: The computed object to treat as a Model.

    Raises:
      TypeError: If `model` is not a ComputedObject.
    """
    self.initialize()

    if not isinstance(model, computedobject.ComputedObject):
      raise TypeError('Model constructor can only cast to Model.')

    # ee.Model has no server-side constructor, so the incoming object's
    # function, arguments and variable name are reused unchanged, in case the
    # object is meant to be cast to a Model on the server.
    super().__init__(model.func, model.args, model.varName)

  @classmethod
  def initialize(cls) -> None:
    """Imports the server-side Model API functions onto this class, once."""
    if cls._initialized:
      return
    apifunction.ApiFunction.importApi(cls, cls.name(), cls.name())
    cls._initialized = True

  @classmethod
  def reset(cls) -> None:
    """Clears the imported API functions and marks the class uninitialized."""
    apifunction.ApiFunction.clearApi(cls)
    cls._initialized = False

  @staticmethod
  def name() -> str:
    """Returns the type name used as the prefix of Model API function names."""
    return 'Model'

  # TODO: Add fromAiPlatformPredictor
  # TODO: Add fromVertexAi

  def predictImage(self, image: _ImageType) -> image.Image:
    """Runs the model over pixel tiles of an image.

    The predictions are merged as bands with the input image. Masked pixels
    are fed to the model as 0s, and each predicted output band's mask is the
    minimum of the input masks.

    Args:
      image: The input image.

    Returns:
      The input image with the predicted bands merged in.
    """
    function_name = f'{self.name()}.predictImage'
    return apifunction.ApiFunction.call_(function_name, self, image)

  def predictProperties(
      self, collection: _FeatureCollectionType
  ) -> featurecollection.FeatureCollection:
    """Runs the model on every feature of a collection.

    Each feature's predicted properties are merged with its existing
    properties.

    Args:
      collection: The input collection.

    Returns:
      A feature collection holding the features with predictions added.
    """
    function_name = f'{self.name()}.predictProperties'
    return apifunction.ApiFunction.call_(function_name, self, collection)