Skip to content

Commit

Permalink
Working allowed_clouds
Browse files Browse the repository at this point in the history
  • Loading branch information
romilbhardwaj committed May 17, 2024
1 parent 983856e commit 1f78982
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 22 deletions.
1 change: 1 addition & 0 deletions sky/adaptors/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
R2_PROFILE_NAME = 'r2'
_INDENT_PREFIX = ' '
NAME = 'Cloudflare'
SKY_CHECK_NAME = 'Cloudflare (for R2 object store)'


@contextlib.contextmanager
Expand Down
63 changes: 44 additions & 19 deletions sky/check.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Credential checks: check cloud credentials and enable clouds."""
import traceback
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union

import click
import colorama
import rich

from sky import clouds as sky_clouds
from sky import skypilot_config
from sky import exceptions
from sky import global_user_state
from sky import skypilot_config
from sky.adaptors import cloudflare
from sky.utils import ux_utils

Expand Down Expand Up @@ -53,26 +53,40 @@ def check_one_cloud(
disabled_clouds.append(cloud_repr)
echo(f' Reason: {reason}')

# Use candidate_clouds from config if it exists, otherwise check all clouds.
config_candidate_clouds = skypilot_config.get_nested(['candidate_clouds'],
None)
# Validate config_candidate_clouds
config_candidate_clouds = [repr(sky_clouds.CLOUD_REGISTRY.from_str(c) for c in config_candidate_clouds)
def get_cloud_tuple(
cloud_name: str) -> Tuple[str, Union[sky_clouds.Cloud, ModuleType]]:
# Validates cloud_name and returns a tuple of the cloud's name and
# the cloud object. Includes special handling for Cloudflare.
if cloud_name.lower() == 'cloudflare':
return cloudflare.SKY_CHECK_NAME, cloudflare
else:
cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud_name)
assert cloud_obj is not None, f'Cloud {cloud_name!r} not found'
return repr(cloud_obj), cloud_obj

def get_all_clouds():
return tuple(list(sky_clouds.CLOUD_REGISTRY.keys()) + [cloudflare.NAME])

if clouds is not None:
clouds_to_check: List[Tuple[str, Any]] = []
for cloud in clouds:
if cloud.lower() == 'cloudflare':
clouds_to_check.append(
('Cloudflare, for R2 object store', cloudflare))
else:
cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud)
assert cloud_obj is not None, f'Cloud {cloud!r} not found'
clouds_to_check.append((repr(cloud_obj), cloud_obj))
cloud_list = clouds
else:
clouds_to_check = [(repr(cloud_obj), cloud_obj)
for cloud_obj in sky_clouds.CLOUD_REGISTRY.values()]
clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare))
cloud_list = get_all_clouds()
clouds_to_check = [get_cloud_tuple(c) for c in cloud_list]

# Use allowed_clouds from config if it exists, otherwise check all clouds.
# Also validate names with get_cloud_tuple.
config_allowed_clouds = [
get_cloud_tuple(c)[0] for c in skypilot_config.get_nested(
['allowed_clouds'], get_all_clouds())
]
# Use skipped_clouds for logging the skipped clouds.
skipped_clouds = [
c for c in clouds_to_check if c[0] not in config_allowed_clouds
]
# Check only the clouds which are allowed in the config.
clouds_to_check = [
c for c in clouds_to_check if c[0] in config_allowed_clouds
]

for cloud_tuple in sorted(clouds_to_check):
check_one_cloud(cloud_tuple)
Expand All @@ -96,13 +110,21 @@ def check_one_cloud(
disabled_clouds_set)
global_user_state.set_enabled_clouds(list(all_enabled_clouds))

skipped_clouds_hint = None
if skipped_clouds:
skipped_clouds_hint = (
'\nNote: The following clouds were skipped because they were not '
'included in allowed_clouds in ~/.sky/config.yaml: '
f'{", ".join([c[0] for c in skipped_clouds])}')
if len(all_enabled_clouds) == 0:
echo(
click.style(
'No cloud is enabled. SkyPilot will not be able to run any '
'task. Run `sky check` for more info.',
fg='red',
bold=True))
if skipped_clouds_hint:
echo(click.style(skipped_clouds_hint, dim=True))
raise SystemExit()
else:
clouds_arg = (' ' +
Expand All @@ -116,6 +138,9 @@ def check_one_cloud(
'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long
dim=True))

if skipped_clouds_hint:
echo(click.style(skipped_clouds_hint, dim=True))

# Pretty print for UX.
if not quiet:
enabled_clouds_str = '\n :heavy_check_mark: '.join(
Expand Down
6 changes: 3 additions & 3 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,8 @@ def get_config_schema():
},
}

candidate_clouds = {
# A list of cloud names that should be used for execution
allowed_clouds = {
# A list of cloud names that are allowed to be used
'type': 'array',
'items': {
'type': 'string',
Expand All @@ -746,7 +746,7 @@ def get_config_schema():
'jobs': controller_resources_schema,
'spot': controller_resources_schema,
'serve': controller_resources_schema,
'candidate_clouds': candidate_clouds,
'allowed_clouds': allowed_clouds,
**cloud_configs,
},
# Avoid spot and jobs being present at the same time.
Expand Down

0 comments on commit 1f78982

Please sign in to comment.