diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 39045962a78..1b2c55668cb 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -465,20 +465,17 @@ def get_controller_resources( if handle is not None: controller_resources_to_use = handle.launched_resources - if controller_resources_to_use.cloud is not None: - return {controller_resources_to_use} + # If the controller and replicas are from the same cloud (and region/zone), + # it should provide better connectivity. We will let the controller choose + # from the clouds (and regions/zones) of the resources if the user does not + # specify the cloud (and region/zone) for the controller. - # If the controller and replicas are from the same cloud, it should - # provide better connectivity. We will let the controller choose from - # the clouds of the resources if the controller does not exist. - # TODO(tian): Consider respecting the regions/zones specified for the - # resources as well. - requested_clouds: Set['clouds.Cloud'] = set() + requested_clouds_with_region_zone: Dict[str, Dict[Optional[str], + Set[Optional[str]]]] = {} for resource in task_resources: - # cloud is an object and will not be able to be distinguished by set. - # Here we manually check if the cloud is in the set. - if resource.cloud is not None: - if not clouds.cloud_in_iterable(resource.cloud, requested_clouds): + cloud_name = str(resource.cloud) if resource.cloud is not None else None + if cloud_name is not None: + if cloud_name not in requested_clouds_with_region_zone: try: resource.cloud.check_features_are_supported( resources.Resources(), @@ -486,7 +483,26 @@ def get_controller_resources( except exceptions.NotSupportedError: # Skip the cloud if it does not support hosting controllers. 
continue - requested_clouds.add(resource.cloud) + requested_clouds_with_region_zone[cloud_name] = {} + if resource.region is None: + # If one of the resource.region is None, this could represent + # that the user is unsure about which region the resource is + # hosted in. In this case, we allow any region for this cloud. + requested_clouds_with_region_zone[cloud_name] = {None: {None}} + elif None not in requested_clouds_with_region_zone[cloud_name]: + if resource.region not in requested_clouds_with_region_zone[ + cloud_name]: + requested_clouds_with_region_zone[cloud_name][ + resource.region] = set() + # If one of the resource.zone is None, allow any zone in the + # region. + if resource.zone is None: + requested_clouds_with_region_zone[cloud_name][ + resource.region] = {None} + elif None not in requested_clouds_with_region_zone[cloud_name][ + resource.region]: + requested_clouds_with_region_zone[cloud_name][ + resource.region].add(resource.zone) else: # if one of the resource.cloud is None, this could represent user # does not know which cloud is best for the specified resources. @@ -496,14 +512,48 @@ def get_controller_resources( # - cloud: runpod # accelerators: A40 # In this case, we allow the controller to be launched on any cloud. 
def _check_controller_resources(controller_resources, expected_combinations,
                                default_controller_resources):
    """Check that the controller resources match the expected combinations.

    Every resource must map to exactly one expected
    ``(cloud, region, zone)`` tuple, every expected tuple must be matched,
    and whatever remains of each resource's config after removing those
    three fields must equal ``default_controller_resources``.

    Args:
        controller_resources: Iterable of objects exposing
            ``to_yaml_config()``, which returns a dict containing a
            'cloud' key and, optionally, 'region' and 'zone' keys.
        expected_combinations: Collection of ``(cloud, region, zone)``
            tuples. It is copied, so the caller's collection is left
            unmodified.
        default_controller_resources: Dict the config is compared against
            once 'cloud', 'region' and 'zone' have been popped.

    Raises:
        AssertionError: If a resource yields an unexpected (or already
            consumed) combination, if its remaining config differs from
            the defaults, or if some expected combination is never seen.
        KeyError: If a resource's config has no 'cloud' entry.
    """
    # Consume a private copy: removing matched entries from the caller's
    # collection would silently empty it as a side effect of the check.
    remaining = list(expected_combinations)
    for r in controller_resources:
        config = r.to_yaml_config()
        cloud = config.pop('cloud')
        region = config.pop('region', None)
        zone = config.pop('zone', None)
        combination = (cloud, region, zone)
        assert combination in remaining, (
            f'Unexpected or duplicate combination: {combination}')
        # Each expected combination may be matched only once.
        remaining.remove(combination)
        assert config == default_controller_resources, config
    assert not remaining, f'Missing combinations: {remaining}'
+ all_cloud_regions_zones = [ + sky.Resources(cloud=sky.AWS(), region='us-east-1', zone='us-east-1a'), + sky.Resources(cloud=sky.AWS(), region='ap-south-1', zone='ap-south-1b'), + sky.Resources(cloud=sky.GCP(), + region='us-central1', + zone='us-central1-a'), + sky.Resources(cloud=sky.GCP(), + region='europe-west1', + zone='europe-west1-b') + ] + expected_combinations = {('AWS', 'us-east-1', 'us-east-1a'), + ('AWS', 'ap-south-1', 'ap-south-1b'), + ('GCP', 'us-central1', 'us-central1-a'), + ('GCP', 'europe-west1', 'europe-west1-b')} + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=all_cloud_regions_zones) + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) + + # 5. Clouds and regions are specified, but zones are partially specified. + # Return a set containing combinations where the zone is None + # when not all zones are specified in the input for the given region. The default + # resources should be returned along with the cloud and region, and the zone (if specified). + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=[ + sky.Resources(cloud=sky.AWS(), region='us-west-2'), + sky.Resources(cloud=sky.AWS(), + region='us-west-2', + zone='us-west-2b'), + sky.Resources(cloud=sky.GCP(), + region='us-central1', + zone='us-central1-a') + ]) + expected_combinations = {('AWS', 'us-west-2', None), + ('GCP', 'us-central1', 'us-central1-a')} + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) + + # 6. Mixed case: Some resources have clouds and regions or zones, others do not. + # For clouds where regions or zones are not specified in the input, return None + # for those fields. 
The default resources should be returned along with the cloud, + # region (if specified), and zone (if specified). + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=[ + sky.Resources(cloud=sky.GCP(), region='europe-west1'), + sky.Resources(cloud=sky.GCP()), + sky.Resources(cloud=sky.AWS(), + region='eu-north-1', + zone='eu-north-1a'), + sky.Resources(cloud=sky.AWS(), region='eu-north-1'), + sky.Resources(cloud=sky.AWS(), region='ap-south-1'), + sky.Resources(cloud=sky.Azure()), + ]) + expected_combinations = { + ('AWS', 'eu-north-1', None), + ('AWS', 'ap-south-1', None), + ('GCP', None, None), + ('Azure', None, None), + } + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources)