From b9a91e9c3d4b82a24ee739898fc1c92e4256aff6 Mon Sep 17 00:00:00 2001 From: Xee authors Date: Wed, 8 Jan 2025 19:31:42 -0800 Subject: [PATCH] Add argument to generate ee credentials on-demand. PiperOrigin-RevId: 713500577 --- xee/ext.py | 37 +++++++++++++++++++++++++++++++++++-- xee/ext_test.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/xee/ext.py b/xee/ext.py index 78663ab..c6890ad 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -19,6 +19,7 @@ from __future__ import annotations import concurrent.futures +import copy import functools import importlib import itertools @@ -26,7 +27,7 @@ import math import os import sys -from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union from urllib import parse import warnings @@ -805,7 +806,8 @@ def _ee_init_check(self): 'Attempting to initialize using application default credentials.' ) - ee.Initialize(**(self.store.ee_init_kwargs or {})) + ee_init_kwargs = _parse_ee_init_kwargs(self.store.ee_init_kwargs) + ee.Initialize(**ee_init_kwargs) def __getitem__(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: return indexing.explicit_indexing_adapter( @@ -1165,3 +1167,34 @@ def open_dataset( ) return ds + + +def _parse_ee_init_kwargs( + ee_init_kwargs: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + """Parses Earth Engine Initialize kwargs. + + Generate credentials if credentials_function is specified. + + Args: + ee_init_kwargs: A dictionary of keyword arguments to pass to Earth Engine + Initialize, or None. + + Returns: + A dictionary of keyword arguments to pass to Earth Engine Initialize. + """ + ee_init_kwargs = copy.copy(ee_init_kwargs) or {} + if ( + 'credentials' in ee_init_kwargs + and 'credentials_function' in ee_init_kwargs + ): + raise ValueError( + 'Cannot specify both credentials and credentials_function.' + ) + if 'credentials_function' in ee_init_kwargs: + credentials_function: Callable[[], Any] = ee_init_kwargs.pop( + 'credentials_function' + ) + ee_init_kwargs['credentials'] = credentials_function() + + return ee_init_kwargs diff --git a/xee/ext_test.py b/xee/ext_test.py index 6ebd852..a873d2f 100644 --- a/xee/ext_test.py +++ b/xee/ext_test.py @@ -97,5 +97,49 @@ def test_exceeding_byte_limit__raises_error(self): ext._check_request_limit(chunks, dtype_size, xee.REQUEST_BYTE_LIMIT) +class ParseEEInitKwargsTest(absltest.TestCase): + + def test_parse_ee_init_kwargs__empty(self): + self.assertDictEqual(ext._parse_ee_init_kwargs(None), {}) + + def test_parse_ee_init_kwargs__credentials(self): + self.assertDictEqual( + ext._parse_ee_init_kwargs( + { + 'credentials': 'foo', + 'other': 'bar', + } + ), + { + 'credentials': 'foo', + 'other': 'bar', + }, + ) + + def test_parse_ee_init_kwargs__credentials_function(self): + self.assertDictEqual( + ext._parse_ee_init_kwargs( + { + 'credentials_function': lambda: 'foo', + 'other': 'bar', + } + ), + { + 'credentials': 'foo', + 'other': 'bar', + }, + ) + + def test_parse_ee_init_kwargs__credentials_and_credentials_function(self): + with self.assertRaises(ValueError): + ext._parse_ee_init_kwargs( + { + 'credentials': 'foo', + 'credentials_function': lambda: 'foo', + 'other': 'bar', + } + ) + + if __name__ == '__main__': absltest.main()