Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Add allowed_clouds to config to check only specific clouds #3556

Merged
merged 8 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ Available fields and semantics:
cpus: 4+ # number of vCPUs, max concurrent spot jobs = 2 * cpus
disk_size: 100

# Allow list for clouds to be used in `sky check`
#
# This field is used to restrict the clouds that SkyPilot will check and use
# when running `sky check`. Any cloud already enabled but not specified here
# will be disabled on the next `sky check` run.
# If this field is not set, SkyPilot will check and use all supported clouds.
#
# Default: null (use all supported clouds).
allowed_clouds:
- aws
- gcp
- kubernetes

# Advanced AWS configurations (optional).
# Apply to all new instances but not existing ones.
aws:
Expand Down
1 change: 1 addition & 0 deletions sky/adaptors/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
R2_PROFILE_NAME = 'r2'
_INDENT_PREFIX = ' '
NAME = 'Cloudflare'
SKY_CHECK_NAME = 'Cloudflare (for R2 object store)'


@contextlib.contextmanager
Expand Down
76 changes: 59 additions & 17 deletions sky/check.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Credential checks: check cloud credentials and enable clouds."""
import traceback
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union

import click
import colorama
Expand All @@ -10,6 +10,7 @@
from sky import clouds as sky_clouds
from sky import exceptions
from sky import global_user_state
from sky import skypilot_config
from sky.adaptors import cloudflare
from sky.utils import ux_utils

Expand Down Expand Up @@ -52,20 +53,42 @@ def check_one_cloud(
disabled_clouds.append(cloud_repr)
echo(f' Reason: {reason}')

def get_cloud_tuple(
cloud_name: str) -> Tuple[str, Union[sky_clouds.Cloud, ModuleType]]:
# Validates cloud_name and returns a tuple of the cloud's name and
# the cloud object. Includes special handling for Cloudflare.
if cloud_name.lower().startswith('cloudflare'):
return cloudflare.SKY_CHECK_NAME, cloudflare
else:
cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud_name)
assert cloud_obj is not None, f'Cloud {cloud_name!r} not found'
return repr(cloud_obj), cloud_obj

def get_all_clouds():
return tuple([repr(c) for c in sky_clouds.CLOUD_REGISTRY.values()] +
[cloudflare.SKY_CHECK_NAME])

if clouds is not None:
clouds_to_check: List[Tuple[str, Any]] = []
for cloud in clouds:
if cloud.lower() == 'cloudflare':
clouds_to_check.append(
('Cloudflare, for R2 object store', cloudflare))
else:
cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud)
assert cloud_obj is not None, f'Cloud {cloud!r} not found'
clouds_to_check.append((repr(cloud_obj), cloud_obj))
cloud_list = clouds
else:
clouds_to_check = [(repr(cloud_obj), cloud_obj)
for cloud_obj in sky_clouds.CLOUD_REGISTRY.values()]
clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare))
cloud_list = get_all_clouds()
clouds_to_check = [get_cloud_tuple(c) for c in cloud_list]

# Use allowed_clouds from config if it exists, otherwise check all clouds.
# Also validate names with get_cloud_tuple.
config_allowed_cloud_names = [
get_cloud_tuple(c)[0] for c in skypilot_config.get_nested(
['allowed_clouds'], get_all_clouds())
]
# Use disallowed_cloud_names for logging the clouds that will be disabled
# because they are not included in allowed_clouds in config.yaml.
disallowed_cloud_names = [
c for c in get_all_clouds() if c not in config_allowed_cloud_names
]
# Check only the clouds which are allowed in the config.
clouds_to_check = [
c for c in clouds_to_check if c[0] in config_allowed_cloud_names
]

for cloud_tuple in sorted(clouds_to_check):
check_one_cloud(cloud_tuple)
Expand All @@ -79,23 +102,39 @@ def check_one_cloud(
disabled_clouds_set = {
cloud for cloud in disabled_clouds if not cloud.startswith('Cloudflare')
}
config_allowed_clouds_set = {
cloud for cloud in config_allowed_cloud_names
if not cloud.startswith('Cloudflare')
}
previously_enabled_clouds_set = {
repr(cloud) for cloud in global_user_state.get_cached_enabled_clouds()
}

# Determine the set of enabled clouds: previously enabled clouds + newly
# enabled clouds - newly disabled clouds.
all_enabled_clouds = ((previously_enabled_clouds_set | enabled_clouds_set) -
disabled_clouds_set)
# Determine the set of enabled clouds: (previously enabled clouds + newly
# enabled clouds - newly disabled clouds) intersected with
# config_allowed_clouds, if specified in config.yaml.
# This means that if a cloud is already enabled and is not included in
# allowed_clouds in config.yaml, it will be disabled.
all_enabled_clouds = (config_allowed_clouds_set & (
(previously_enabled_clouds_set | enabled_clouds_set) -
disabled_clouds_set))
global_user_state.set_enabled_clouds(list(all_enabled_clouds))

disallowed_clouds_hint = None
if disallowed_cloud_names:
disallowed_clouds_hint = (
'\nNote: The following clouds were disabled because they were not '
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we now changed the wording to disabled, we may want to always show the hint even if sky check gcp is used, as we will show the enabled clouds with the old enabled clouds included as well. : )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm do you mean even when allowed_clouds is not specified in config.yaml, sky check <cloud> should always show the hint above?

I think in that case hint should not be shown because those clouds were not disabled because they were not in allowed_list, but rather because of other errors while checking (which are surfaced as <cloud>: disabled along with reason). wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I mean when the allowed_clouds is specified and sky check <cloud> is used. It seems there is no hint showing in that case?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a ~/.sky/config.yaml, and when I do sky check gcp the following shows up, without the hint:

allowed_clouds: ['aws', 'gcp', 'cloudflare']
  GCP: enabled                              

To enable a cloud, follow the hints above and rerun: sky check 
If any problems remain, refer to detailed docs at: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html

🎉 Enabled clouds 🎉
  ✔ AWS
  ✔ GCP

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nit which should not block us merging the PR :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh got it! It should be fixed now :)

'included in allowed_clouds in ~/.sky/config.yaml: '
f'{", ".join([c for c in disallowed_cloud_names])}')
if len(all_enabled_clouds) == 0:
echo(
click.style(
'No cloud is enabled. SkyPilot will not be able to run any '
'task. Run `sky check` for more info.',
fg='red',
bold=True))
if disallowed_clouds_hint:
echo(click.style(disallowed_clouds_hint, dim=True))
raise SystemExit()
else:
clouds_arg = (' ' +
Expand All @@ -109,6 +148,9 @@ def check_one_cloud(
'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long
dim=True))

if disallowed_clouds_hint:
echo(click.style(disallowed_clouds_hint, dim=True))

# Pretty print for UX.
if not quiet:
enabled_clouds_str = '\n :heavy_check_mark: '.join(
Expand Down
12 changes: 12 additions & 0 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def get_default_remote_identity(cloud: str) -> str:

def get_config_schema():
# pylint: disable=import-outside-toplevel
from sky.clouds import service_catalog
from sky.utils import kubernetes_enums

resources_schema = {
Expand Down Expand Up @@ -722,6 +723,16 @@ def get_config_schema():
},
}

allowed_clouds = {
# A list of cloud names that are allowed to be used
'type': 'array',
'items': {
'type': 'string',
romilbhardwaj marked this conversation as resolved.
Show resolved Hide resolved
'case_insensitive_enum':
(list(service_catalog.ALL_CLOUDS) + ['cloudflare'])
}
}

for cloud, config in cloud_configs.items():
if cloud == 'aws':
config['properties'].update(_REMOTE_IDENTITY_SCHEMA_AWS)
Expand All @@ -738,6 +749,7 @@ def get_config_schema():
'jobs': controller_resources_schema,
'spot': controller_resources_schema,
'serve': controller_resources_schema,
'allowed_clouds': allowed_clouds,
**cloud_configs,
},
# Avoid spot and jobs being present at the same time.
Expand Down
Loading