diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 53c983edfad..dce0ce1f643 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -27,6 +27,19 @@ Available fields and semantics: cpus: 4+ # number of vCPUs, max concurrent spot jobs = 2 * cpus disk_size: 100 + # Allow list for clouds to be used in `sky check` + # + # This field is used to restrict the clouds that SkyPilot will check and use + # when running `sky check`. Any cloud already enabled but not specified here + # will be disabled on the next `sky check` run. + # If this field is not set, SkyPilot will check and use all supported clouds. + # + # Default: null (use all supported clouds). + allowed_clouds: + - aws + - gcp + - kubernetes + # Advanced AWS configurations (optional). # Apply to all new instances but not existing ones. aws: diff --git a/sky/adaptors/cloudflare.py b/sky/adaptors/cloudflare.py index 2a49dc6fff0..864248614f3 100644 --- a/sky/adaptors/cloudflare.py +++ b/sky/adaptors/cloudflare.py @@ -23,6 +23,7 @@ R2_PROFILE_NAME = 'r2' _INDENT_PREFIX = ' ' NAME = 'Cloudflare' +SKY_CHECK_NAME = 'Cloudflare (for R2 object store)' @contextlib.contextmanager diff --git a/sky/check.py b/sky/check.py index d90fdffefb7..f4ecd5a8b18 100644 --- a/sky/check.py +++ b/sky/check.py @@ -1,7 +1,7 @@ """Credential checks: check cloud credentials and enable clouds.""" import traceback from types import ModuleType -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import click import colorama @@ -10,6 +10,7 @@ from sky import clouds as sky_clouds from sky import exceptions from sky import global_user_state +from sky import skypilot_config from sky.adaptors import cloudflare from sky.utils import ux_utils @@ -52,20 +53,42 @@ def check_one_cloud( disabled_clouds.append(cloud_repr) echo(f' Reason: {reason}') + def get_cloud_tuple( + cloud_name: str) -> Tuple[str, Union[sky_clouds.Cloud, ModuleType]]: + # Validates cloud_name and returns a tuple of the cloud's name and + # the cloud object. Includes special handling for Cloudflare. + if cloud_name.lower().startswith('cloudflare'): + return cloudflare.SKY_CHECK_NAME, cloudflare + else: + cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud_name) + assert cloud_obj is not None, f'Cloud {cloud_name!r} not found' + return repr(cloud_obj), cloud_obj + + def get_all_clouds(): + return tuple([repr(c) for c in sky_clouds.CLOUD_REGISTRY.values()] + + [cloudflare.SKY_CHECK_NAME]) + if clouds is not None: - clouds_to_check: List[Tuple[str, Any]] = [] - for cloud in clouds: - if cloud.lower() == 'cloudflare': - clouds_to_check.append( - ('Cloudflare, for R2 object store', cloudflare)) - else: - cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud) - assert cloud_obj is not None, f'Cloud {cloud!r} not found' - clouds_to_check.append((repr(cloud_obj), cloud_obj)) + cloud_list = clouds else: - clouds_to_check = [(repr(cloud_obj), cloud_obj) - for cloud_obj in sky_clouds.CLOUD_REGISTRY.values()] - clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare)) + cloud_list = get_all_clouds() + clouds_to_check = [get_cloud_tuple(c) for c in cloud_list] + + # Use allowed_clouds from config if it exists, otherwise check all clouds. + # Also validate names with get_cloud_tuple. + config_allowed_cloud_names = [ + get_cloud_tuple(c)[0] for c in skypilot_config.get_nested( + ['allowed_clouds'], get_all_clouds()) + ] + # Use disallowed_cloud_names for logging the clouds that will be disabled + # because they are not included in allowed_clouds in config.yaml. + disallowed_cloud_names = [ + c for c in get_all_clouds() if c not in config_allowed_cloud_names + ] + # Check only the clouds which are allowed in the config. + clouds_to_check = [ + c for c in clouds_to_check if c[0] in config_allowed_cloud_names + ] for cloud_tuple in sorted(clouds_to_check): check_one_cloud(cloud_tuple) @@ -79,16 +102,30 @@ def check_one_cloud( disabled_clouds_set = { cloud for cloud in disabled_clouds if not cloud.startswith('Cloudflare') } + config_allowed_clouds_set = { + cloud for cloud in config_allowed_cloud_names + if not cloud.startswith('Cloudflare') + } previously_enabled_clouds_set = { repr(cloud) for cloud in global_user_state.get_cached_enabled_clouds() } - # Determine the set of enabled clouds: previously enabled clouds + newly - # enabled clouds - newly disabled clouds. - all_enabled_clouds = ((previously_enabled_clouds_set | enabled_clouds_set) - - disabled_clouds_set) + # Determine the set of enabled clouds: (previously enabled clouds + newly + # enabled clouds - newly disabled clouds) intersected with + # config_allowed_clouds, if specified in config.yaml. + # This means that if a cloud is already enabled and is not included in + # allowed_clouds in config.yaml, it will be disabled. + all_enabled_clouds = (config_allowed_clouds_set & ( + (previously_enabled_clouds_set | enabled_clouds_set) - + disabled_clouds_set)) global_user_state.set_enabled_clouds(list(all_enabled_clouds)) + disallowed_clouds_hint = None + if disallowed_cloud_names: + disallowed_clouds_hint = ( + '\nNote: The following clouds were disabled because they were not ' + 'included in allowed_clouds in ~/.sky/config.yaml: ' + f'{", ".join([c for c in disallowed_cloud_names])}') if len(all_enabled_clouds) == 0: echo( click.style( @@ -96,6 +133,8 @@ def check_one_cloud( 'task. Run `sky check` for more info.', fg='red', bold=True)) + if disallowed_clouds_hint: + echo(click.style(disallowed_clouds_hint, dim=True)) raise SystemExit() else: clouds_arg = (' ' + @@ -109,6 +148,9 @@ def check_one_cloud( 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long dim=True)) + if disallowed_clouds_hint: + echo(click.style(disallowed_clouds_hint, dim=True)) + # Pretty print for UX. if not quiet: enabled_clouds_str = '\n :heavy_check_mark: '.join( diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 878fe67178e..42e0da96211 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -587,6 +587,7 @@ def get_default_remote_identity(cloud: str) -> str: def get_config_schema(): # pylint: disable=import-outside-toplevel + from sky.clouds import service_catalog from sky.utils import kubernetes_enums resources_schema = { @@ -722,6 +723,16 @@ def get_config_schema(): }, } + allowed_clouds = { + # A list of cloud names that are allowed to be used + 'type': 'array', + 'items': { + 'type': 'string', + 'case_insensitive_enum': + (list(service_catalog.ALL_CLOUDS) + ['cloudflare']) + } + } + for cloud, config in cloud_configs.items(): if cloud == 'aws': config['properties'].update(_REMOTE_IDENTITY_SCHEMA_AWS) @@ -738,6 +749,7 @@ def get_config_schema(): 'jobs': controller_resources_schema, 'spot': controller_resources_schema, 'serve': controller_resources_schema, + 'allowed_clouds': allowed_clouds, **cloud_configs, }, # Avoid spot and jobs being present at the same time.