diff --git a/src/azure-cli-core/azure/cli/core/tests/test_util.py b/src/azure-cli-core/azure/cli/core/tests/test_util.py index ae329ed65b7..669f6075fe4 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_util.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_util.py @@ -17,7 +17,7 @@ (get_file_json, truncate_text, shell_safe_json_parse, b64_to_hex, hash_string, random_string, open_page_in_browser, can_launch_browser, handle_exception, ConfiguredDefaultSetter, send_raw_request, should_disable_connection_verify, parse_proxy_resource_id, get_az_user_agent, get_az_rest_user_agent, - _get_parent_proc_name, is_wsl, run_cmd, run_az_cmd, roughly_parse_command) + _get_parent_proc_name, is_wsl, run_cmd, run_az_cmd, roughly_parse_command, sdk_no_wait) from azure.cli.core.mock import DummyCli @@ -235,6 +235,42 @@ def test_configured_default_setter(self): self.assertEqual(config.use_local_config, False) self.assertTrue(config.use_local_config) + def test_sdk_no_wait_sets_polling_false(self): + func = mock.MagicMock(return_value='ok') + result = sdk_no_wait(True, func, 1, test='value') + self.assertEqual(result, 'ok') + func.assert_called_once_with(1, test='value', polling=False) + + def test_sdk_no_wait_retries_without_forced_polling_on_json_decode_error(self): + calls = [] + + def _func(*_args, **kwargs): + calls.append(kwargs.copy()) + if len(calls) == 1: + raise json.decoder.JSONDecodeError("bad json", "gAS", 3) + return 'ok' + + result = sdk_no_wait(True, _func) + self.assertEqual(result, 'ok') + self.assertEqual(len(calls), 2) + self.assertFalse(calls[0]['polling']) + self.assertNotIn('polling', calls[1]) + + def test_sdk_no_wait_restores_original_polling_on_json_decode_error(self): + calls = [] + + def _func(*_args, **kwargs): + calls.append(kwargs.copy()) + if len(calls) == 1: + raise json.decoder.JSONDecodeError("bad json", "gAS", 3) + return 'ok' + + result = sdk_no_wait(True, _func, polling=True) + self.assertEqual(result, 'ok') + self.assertEqual(len(calls), 2) + self.assertFalse(calls[0]['polling']) + self.assertTrue(calls[1]['polling']) + @mock.patch('azure.cli.core.__version__', '7.8.9') def test_get_az_user_agent(self): from azure.cli.core._environment import _ENV_AZ_INSTALLER diff --git a/src/azure-cli-core/azure/cli/core/util.py b/src/azure-cli-core/azure/cli/core/util.py index b693ee00518..b697e0d6448 100644 --- a/src/azure-cli-core/azure/cli/core/util.py +++ b/src/azure-cli-core/azure/cli/core/util.py @@ -790,7 +790,19 @@ def augment_no_wait_handler_args(no_wait_enabled, handler, handler_args): def sdk_no_wait(no_wait, func, *args, **kwargs): if no_wait: + original_polling = kwargs.get('polling', None) + has_polling = 'polling' in kwargs kwargs.update({'polling': False}) + try: + return func(*args, **kwargs) + except json.JSONDecodeError: + logger.debug("Retrying no-wait operation with original polling settings after JSON decode failure.", + exc_info=True) + if has_polling: + kwargs['polling'] = original_polling + else: + kwargs.pop('polling', None) + return func(*args, **kwargs) return func(*args, **kwargs)