Nate Schmitz 3d4a6083ff Internal changes.
PiperOrigin-RevId: 812912320
2025-09-29 13:09:08 -07:00

106 lines
2.7 KiB
Python

"""Stores the current state of the EE library."""
from __future__ import annotations
import dataclasses
import re
from typing import Any
# Rename to avoid redefined-outer-name warning.
from google.oauth2 import credentials as credentials_lib
import requests
# The default project to use for Cloud API calls. EEState uses this as the
# default value of its `cloud_api_user_project` field.
DEFAULT_CLOUD_API_USER_PROJECT = 'earthengine-legacy'
@dataclasses.dataclass
class _WorkloadTag:
  """Tracks the active workload tag together with a resettable default.

  Both values start out as empty strings. Anything stored through `set` or
  `set_default` is first normalized by `validate`.
  """

  _tag: str
  _default: str

  # dataclass keeps an explicitly defined __init__, so instances are always
  # constructed without arguments and start with empty values.
  def __init__(self):
    self._default = ''
    self._tag = ''

  def get(self) -> str:
    """Returns the currently active tag."""
    return self._tag

  def set(self, tag: int | str | None) -> None:
    """Validates `tag` and stores the result as the active tag."""
    validated = self.validate(tag)
    self._tag = validated

  def set_default(self, new_default: int | str | None) -> None:
    """Validates `new_default` and stores the result as the default tag."""
    validated = self.validate(new_default)
    self._default = validated

  def reset(self) -> None:
    """Replaces the active tag with the stored default."""
    self._tag = self._default

  def validate(self, tag: int | str | None) -> str:
    """Normalizes a tag to a string, rejecting malformed values.

    Args:
      tag: The candidate tag. Falsy values that do not compare equal to 0
        (for example None or '') are treated as "no tag".

    Returns:
      An empty string when no tag was supplied, otherwise `str(tag)`.

    Raises:
      ValueError: If the string form of the tag does not match the expected
        format.
    """
    # The `!= 0` part keeps the integer 0 from being treated as "no tag".
    no_tag_given = not tag and tag != 0
    if no_tag_given:
      return ''

    text = str(tag)
    # Either one alphanumeric character, or 2-63 characters that start and end
    # with an alphanumeric character.
    pattern = r'([a-z0-9]|[a-z0-9][-_a-z0-9]{0,61}[a-z0-9])'
    if re.fullmatch(pattern, text):
      return text

    validation_message = (
        'Tags must be 1-63 characters, '
        'beginning and ending with a lowercase alphanumeric character '
        '([a-z0-9]) with dashes (-), underscores (_), '
        'and lowercase alphanumerics between.'
    )
    raise ValueError(f'Invalid tag, "{text}". {validation_message}')
@dataclasses.dataclass
class EEState:
  """Container for the settings and live resources of one EE session.

  Every field has a default, so `EEState()` can be built without arguments.
  The optional fields start as None, except `cloud_api_user_project`, which
  starts as DEFAULT_CLOUD_API_USER_PROJECT.
  """

  # Initialization flag and credentials.
  initialized: bool = False
  credentials: credentials_lib.Credentials | None = None

  # Base URLs and API key.
  api_base_url: str | None = None
  tile_base_url: str | None = None
  cloud_api_base_url: str | None = None
  cloud_api_key: str | None = None

  # Session, Cloud API resource objects and related settings.
  requests_session: requests.Session | None = None
  cloud_api_resource: Any | None = None
  cloud_api_resource_raw: Any | None = None
  cloud_api_user_project: str | None = DEFAULT_CLOUD_API_USER_PROJECT
  cloud_api_client_version: str | None = None
  http_transport: Any | None = None

  # Request tuning values.
  deadline_ms: int = 0
  max_retries: int = 5
  user_agent: str | None = None

  # default_factory gives each EEState instance its own _WorkloadTag object.
  workload_tag: _WorkloadTag = dataclasses.field(default_factory=_WorkloadTag)
# Module-level EEState instance. It is None until get_state() creates it, and
# becomes None again when reset_state() clears it.
_state: EEState | None = None


def get_state() -> EEState:
  """Returns the module-level EEState, creating it on first access.

  Returns:
    The existing EEState if one has been created since import or since the last
    reset_state() call; otherwise a newly constructed default EEState, which is
    also stored for later calls.
  """
  global _state
  # Test against None explicitly rather than relying on truthiness, so an
  # existing state object is never replaced just because it evaluates as falsy.
  if _state is None:
    _state = EEState()
  return _state
def reset_state() -> None:
  """Discards the stored EEState.

  The module-level state is set back to None, so the next get_state() call
  constructs a fresh default EEState.
  """
  global _state
  _state = None