Skip to content

Commit

Permalink
[core] Add allowed_clouds to config to check only specific clouds (#…
Browse files Browse the repository at this point in the history
…3556)

* candidate_clouds

* Working allowed_clouds

* Working allowed_clouds

* comments

* lint

* change skipped clouds to disabled clouds

* lint
  • Loading branch information
romilbhardwaj authored May 17, 2024
1 parent d09b6dc commit 6968be5
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 17 deletions.
13 changes: 13 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ Available fields and semantics:
cpus: 4+ # number of vCPUs, max concurrent spot jobs = 2 * cpus
disk_size: 100
# Allow list for clouds to be used in `sky check`
#
# This field is used to restrict the clouds that SkyPilot will check and use
# when running `sky check`. Any cloud already enabled but not specified here
# will be disabled on the next `sky check` run.
# If this field is not set, SkyPilot will check and use all supported clouds.
#
# Default: null (use all supported clouds).
allowed_clouds:
- aws
- gcp
- kubernetes
# Advanced AWS configurations (optional).
# Apply to all new instances but not existing ones.
aws:
Expand Down
1 change: 1 addition & 0 deletions sky/adaptors/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
R2_PROFILE_NAME = 'r2'
_INDENT_PREFIX = ' '
NAME = 'Cloudflare'
SKY_CHECK_NAME = 'Cloudflare (for R2 object store)'


@contextlib.contextmanager
Expand Down
76 changes: 59 additions & 17 deletions sky/check.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Credential checks: check cloud credentials and enable clouds."""
import traceback
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union

import click
import colorama
Expand All @@ -10,6 +10,7 @@
from sky import clouds as sky_clouds
from sky import exceptions
from sky import global_user_state
from sky import skypilot_config
from sky.adaptors import cloudflare
from sky.utils import ux_utils

Expand Down Expand Up @@ -52,20 +53,42 @@ def check_one_cloud(
disabled_clouds.append(cloud_repr)
echo(f' Reason: {reason}')

def get_cloud_tuple(
cloud_name: str) -> Tuple[str, Union[sky_clouds.Cloud, ModuleType]]:
# Validates cloud_name and returns a tuple of the cloud's name and
# the cloud object. Includes special handling for Cloudflare.
if cloud_name.lower().startswith('cloudflare'):
return cloudflare.SKY_CHECK_NAME, cloudflare
else:
cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud_name)
assert cloud_obj is not None, f'Cloud {cloud_name!r} not found'
return repr(cloud_obj), cloud_obj

def get_all_clouds():
return tuple([repr(c) for c in sky_clouds.CLOUD_REGISTRY.values()] +
[cloudflare.SKY_CHECK_NAME])

if clouds is not None:
clouds_to_check: List[Tuple[str, Any]] = []
for cloud in clouds:
if cloud.lower() == 'cloudflare':
clouds_to_check.append(
('Cloudflare, for R2 object store', cloudflare))
else:
cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud)
assert cloud_obj is not None, f'Cloud {cloud!r} not found'
clouds_to_check.append((repr(cloud_obj), cloud_obj))
cloud_list = clouds
else:
clouds_to_check = [(repr(cloud_obj), cloud_obj)
for cloud_obj in sky_clouds.CLOUD_REGISTRY.values()]
clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare))
cloud_list = get_all_clouds()
clouds_to_check = [get_cloud_tuple(c) for c in cloud_list]

# Use allowed_clouds from config if it exists, otherwise check all clouds.
# Also validate names with get_cloud_tuple.
config_allowed_cloud_names = [
get_cloud_tuple(c)[0] for c in skypilot_config.get_nested(
['allowed_clouds'], get_all_clouds())
]
# Use disallowed_cloud_names for logging the clouds that will be disabled
# because they are not included in allowed_clouds in config.yaml.
disallowed_cloud_names = [
c for c in get_all_clouds() if c not in config_allowed_cloud_names
]
# Check only the clouds which are allowed in the config.
clouds_to_check = [
c for c in clouds_to_check if c[0] in config_allowed_cloud_names
]

for cloud_tuple in sorted(clouds_to_check):
check_one_cloud(cloud_tuple)
Expand All @@ -79,23 +102,39 @@ def check_one_cloud(
disabled_clouds_set = {
cloud for cloud in disabled_clouds if not cloud.startswith('Cloudflare')
}
config_allowed_clouds_set = {
cloud for cloud in config_allowed_cloud_names
if not cloud.startswith('Cloudflare')
}
previously_enabled_clouds_set = {
repr(cloud) for cloud in global_user_state.get_cached_enabled_clouds()
}

# Determine the set of enabled clouds: previously enabled clouds + newly
# enabled clouds - newly disabled clouds.
all_enabled_clouds = ((previously_enabled_clouds_set | enabled_clouds_set) -
disabled_clouds_set)
# Determine the set of enabled clouds: (previously enabled clouds + newly
# enabled clouds - newly disabled clouds) intersected with
# config_allowed_clouds, if specified in config.yaml.
# This means that if a cloud is already enabled and is not included in
# allowed_clouds in config.yaml, it will be disabled.
all_enabled_clouds = (config_allowed_clouds_set & (
(previously_enabled_clouds_set | enabled_clouds_set) -
disabled_clouds_set))
global_user_state.set_enabled_clouds(list(all_enabled_clouds))

disallowed_clouds_hint = None
if disallowed_cloud_names:
disallowed_clouds_hint = (
'\nNote: The following clouds were disabled because they were not '
'included in allowed_clouds in ~/.sky/config.yaml: '
f'{", ".join([c for c in disallowed_cloud_names])}')
if len(all_enabled_clouds) == 0:
echo(
click.style(
'No cloud is enabled. SkyPilot will not be able to run any '
'task. Run `sky check` for more info.',
fg='red',
bold=True))
if disallowed_clouds_hint:
echo(click.style(disallowed_clouds_hint, dim=True))
raise SystemExit()
else:
clouds_arg = (' ' +
Expand All @@ -109,6 +148,9 @@ def check_one_cloud(
'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long
dim=True))

if disallowed_clouds_hint:
echo(click.style(disallowed_clouds_hint, dim=True))

# Pretty print for UX.
if not quiet:
enabled_clouds_str = '\n :heavy_check_mark: '.join(
Expand Down
12 changes: 12 additions & 0 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def get_default_remote_identity(cloud: str) -> str:

def get_config_schema():
# pylint: disable=import-outside-toplevel
from sky.clouds import service_catalog
from sky.utils import kubernetes_enums

resources_schema = {
Expand Down Expand Up @@ -722,6 +723,16 @@ def get_config_schema():
},
}

allowed_clouds = {
# A list of cloud names that are allowed to be used
'type': 'array',
'items': {
'type': 'string',
'case_insensitive_enum':
(list(service_catalog.ALL_CLOUDS) + ['cloudflare'])
}
}

for cloud, config in cloud_configs.items():
if cloud == 'aws':
config['properties'].update(_REMOTE_IDENTITY_SCHEMA_AWS)
Expand All @@ -738,6 +749,7 @@ def get_config_schema():
'jobs': controller_resources_schema,
'spot': controller_resources_schema,
'serve': controller_resources_schema,
'allowed_clouds': allowed_clouds,
**cloud_configs,
},
# Avoid spot and jobs being present at the same time.
Expand Down

0 comments on commit 6968be5

Please sign in to comment.