diff --git a/sky/__init__.py b/sky/__init__.py index 5a47c6a51bb..d25c8297ea5 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -90,8 +90,6 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.core import download_logs from sky.core import job_status from sky.core import queue -from sky.core import spot_cancel -from sky.core import spot_queue from sky.core import start from sky.core import status from sky.core import stop @@ -104,11 +102,16 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.data import StoreType from sky.execution import exec # pylint: disable=redefined-builtin from sky.execution import launch -from sky.execution import spot_launch from sky.optimizer import Optimizer from sky.optimizer import OptimizeTarget from sky.resources import Resources from sky.skylet.job_lib import JobStatus +# TODO (zhwu): These imports are for backward compatibility, and spot APIs +# should be called with `sky.spot.xxx` instead. Remove in release 0.7.0 +from sky.spot.core import spot_cancel +from sky.spot.core import spot_launch +from sky.spot.core import spot_queue +from sky.spot.core import spot_tail_logs from sky.status_lib import ClusterStatus from sky.task import Task @@ -173,6 +176,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'queue', 'cancel', 'tail_logs', + 'spot_tail_logs', 'download_logs', 'job_status', # core APIs Spot Job Management diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 7a6a2ab33cc..0da2bd9ef0b 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -2292,7 +2292,7 @@ def is_controller_accessible( # will not start the controller manually from the cloud console. # # The acquire_lock_timeout is set to 0 to avoid hanging the command when - # multiple spot_launch commands are running at the same time. Our later + # multiple spot.launch commands are running at the same time. Our later # code will check if the controller is accessible by directly checking # the ssh connection to the controller, if it fails to get accurate # status of the controller. 
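Note for reviewers: the `sky/__init__.py` hunk above re-exports the spot entry points from the new `sky.spot.core` module, so existing `sky.spot_*` call sites keep working until release 0.7.0. A minimal sketch of the migration implied by this change (the task and job name below are illustrative):

```python
import sky

task = sky.Task(run='echo hi')

# Deprecated module-level aliases, kept only for backward compatibility;
# calling them logs a warning and forwards to sky.spot.core (see below).
sky.spot_launch(task, name='demo-job')

# Preferred namespaced API introduced by this change.
sky.spot.launch(task, name='demo-job')
sky.spot.queue(refresh=False, skip_finished=True)
sky.spot.cancel(name='demo-job')
sky.spot.tail_logs(name='demo-job', job_id=None, follow=True)
```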
diff --git a/sky/cli.py b/sky/cli.py index a26d529ce10..72667cffc97 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -1250,8 +1250,8 @@ def _get_spot_jobs( usage_lib.messages.usage.set_internal() with sky_logging.silent(): # Make the call silent - spot_jobs = core.spot_queue(refresh=refresh, - skip_finished=skip_finished) + spot_jobs = spot_lib.queue(refresh=refresh, + skip_finished=skip_finished) num_in_progress_jobs = len(spot_jobs) except exceptions.ClusterNotUpError as e: controller_status = e.cluster_status @@ -2508,7 +2508,7 @@ def _hint_or_raise_for_down_spot_controller(controller_name: str): with rich_utils.safe_status( '[bold cyan]Checking for in-progress spot jobs[/]'): try: - spot_jobs = core.spot_queue(refresh=False, skip_finished=True) + spot_jobs = spot_lib.queue(refresh=False, skip_finished=True) except exceptions.ClusterNotUpError as e: if controller.value.connection_error_hint in str(e): with ux_utils.print_exception_no_traceback(): @@ -3289,7 +3289,7 @@ def spot_launch( common_utils.check_cluster_name_is_valid(name) - sky.spot_launch(dag, + spot_lib.launch(dag, name, detach_run=detach_run, retry_until_up=retry_until_up) @@ -3448,7 +3448,7 @@ def spot_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool): abort=True, show_default=True) - core.spot_cancel(job_ids=job_ids, name=name, all=all) + spot_lib.cancel(job_ids=job_ids, name=name, all=all) @spot.command('logs', cls=_DocumentedCodeCommand) @@ -3480,7 +3480,7 @@ def spot_logs(name: Optional[str], job_id: Optional[int], follow: bool, job_id=job_id, follow=follow) else: - core.spot_tail_logs(name=name, job_id=job_id, follow=follow) + spot_lib.tail_logs(name=name, job_id=job_id, follow=follow) except exceptions.ClusterNotUpError as e: click.echo(e) sys.exit(1) diff --git a/sky/core.py b/sky/core.py index 2736c3d7c5f..c93a50f0b7d 100644 --- a/sky/core.py +++ b/sky/core.py @@ -12,7 +12,6 @@ from sky import exceptions from sky import global_user_state from sky import sky_logging -from sky import spot from sky import status_lib from sky import task from sky.backends import backend_utils @@ -20,9 +19,7 @@ from sky.skylet import job_lib from sky.usage import usage_lib from sky.utils import controller_utils -from sky.utils import rich_utils from sky.utils import subprocess_utils -from sky.utils import ux_utils if typing.TYPE_CHECKING: from sky import resources as resources_lib @@ -230,7 +227,7 @@ def start( retry_until_up: bool = False, down: bool = False, # pylint: disable=redefined-outer-name force: bool = False, -) -> None: +) -> backends.CloudVmRayResourceHandle: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Restart a cluster. @@ -279,11 +276,11 @@ def start( if down and idle_minutes_to_autostop is None: raise ValueError( '`idle_minutes_to_autostop` must be set if `down` is True.') - _start(cluster_name, - idle_minutes_to_autostop, - retry_until_up, - down, - force=force) + return _start(cluster_name, + idle_minutes_to_autostop, + retry_until_up, + down, + force=force) def _stop_not_supported_message(resources: 'resources_lib.Resources') -> str: @@ -762,178 +759,6 @@ def job_status(cluster_name: str, return statuses -# ======================= -# = Spot Job Management = -# ======================= - - -@usage_lib.entrypoint -def spot_queue(refresh: bool, - skip_finished: bool = False) -> List[Dict[str, Any]]: - # NOTE(dev): Keep the docstring consistent between the Python API and CLI. - """Get statuses of managed spot jobs. 
- - Please refer to the sky.cli.spot_queue for the documentation. - - Returns: - [ - { - 'job_id': int, - 'job_name': str, - 'resources': str, - 'submitted_at': (float) timestamp of submission, - 'end_at': (float) timestamp of end, - 'duration': (float) duration in seconds, - 'recovery_count': (int) Number of retries, - 'status': (sky.spot.SpotStatus) of the job, - 'cluster_resources': (str) resources of the cluster, - 'region': (str) region of the cluster, - } - ] - Raises: - sky.exceptions.ClusterNotUpError: the spot controller is not up or - does not exist. - RuntimeError: if failed to get the spot jobs with ssh. - """ - stopped_message = '' - if not refresh: - stopped_message = ('No in-progress spot jobs.') - try: - handle = backend_utils.is_controller_accessible( - controller_type=controller_utils.Controllers.SPOT_CONTROLLER, - stopped_message=stopped_message) - except exceptions.ClusterNotUpError as e: - if not refresh: - raise - handle = None - controller_status = e.cluster_status - - if refresh and handle is None: - sky_logging.print(f'{colorama.Fore.YELLOW}' - 'Restarting controller for latest status...' - f'{colorama.Style.RESET_ALL}') - - rich_utils.force_update_status('[cyan] Checking spot jobs - restarting ' - 'controller[/]') - handle = _start(spot.SPOT_CONTROLLER_NAME) - controller_status = status_lib.ClusterStatus.UP - rich_utils.force_update_status('[cyan] Checking spot jobs[/]') - - assert handle is not None, (controller_status, refresh) - - backend = backend_utils.get_backend_from_handle(handle) - assert isinstance(backend, backends.CloudVmRayBackend) - - code = spot.SpotCodeGen.get_job_table() - returncode, job_table_payload, stderr = backend.run_on_head( - handle, - code, - require_outputs=True, - stream_logs=False, - separate_stderr=True) - - try: - subprocess_utils.handle_returncode(returncode, - code, - 'Failed to fetch managed spot jobs', - job_table_payload + stderr, - stream_logs=False) - except exceptions.CommandError as e: - raise RuntimeError(str(e)) from e - - jobs = spot.load_spot_job_queue(job_table_payload) - if skip_finished: - # Filter out the finished jobs. If a multi-task job is partially - # finished, we will include all its tasks. - non_finished_tasks = list( - filter(lambda job: not job['status'].is_terminal(), jobs)) - non_finished_job_ids = {job['job_id'] for job in non_finished_tasks} - jobs = list( - filter(lambda job: job['job_id'] in non_finished_job_ids, jobs)) - return jobs - - -@usage_lib.entrypoint -# pylint: disable=redefined-builtin -def spot_cancel(name: Optional[str] = None, - job_ids: Optional[List[int]] = None, - all: bool = False) -> None: - # NOTE(dev): Keep the docstring consistent between the Python API and CLI. - """Cancel managed spot jobs. - - Please refer to the sky.cli.spot_cancel for the document. - - Raises: - sky.exceptions.ClusterNotUpError: the spot controller is not up. - RuntimeError: failed to cancel the job. - """ - job_ids = [] if job_ids is None else job_ids - handle = backend_utils.is_controller_accessible( - controller_type=controller_utils.Controllers.SPOT_CONTROLLER, - stopped_message='All managed spot jobs should have finished.') - - job_id_str = ','.join(map(str, job_ids)) - if sum([len(job_ids) > 0, name is not None, all]) != 1: - argument_str = f'job_ids={job_id_str}' if len(job_ids) > 0 else '' - argument_str += f' name={name}' if name is not None else '' - argument_str += ' all' if all else '' - raise ValueError('Can only specify one of JOB_IDS or name or all. 
' - f'Provided {argument_str!r}.') - - backend = backend_utils.get_backend_from_handle(handle) - assert isinstance(backend, backends.CloudVmRayBackend) - if all: - code = spot.SpotCodeGen.cancel_jobs_by_id(None) - elif job_ids: - code = spot.SpotCodeGen.cancel_jobs_by_id(job_ids) - else: - assert name is not None, (job_ids, name, all) - code = spot.SpotCodeGen.cancel_job_by_name(name) - # The stderr is redirected to stdout - returncode, stdout, _ = backend.run_on_head(handle, - code, - require_outputs=True, - stream_logs=False) - try: - subprocess_utils.handle_returncode(returncode, code, - 'Failed to cancel managed spot job', - stdout) - except exceptions.CommandError as e: - raise RuntimeError(e.error_msg) from e - - sky_logging.print(stdout) - if 'Multiple jobs found with name' in stdout: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError( - 'Please specify the job ID instead of the job name.') - - -@usage_lib.entrypoint -def spot_tail_logs(name: Optional[str], job_id: Optional[int], - follow: bool) -> None: - # NOTE(dev): Keep the docstring consistent between the Python API and CLI. - """Tail logs of managed spot jobs. - - Please refer to the sky.cli.spot_logs for the document. - - Raises: - ValueError: invalid arguments. - sky.exceptions.ClusterNotUpError: the spot controller is not up. - """ - # TODO(zhwu): Automatically restart the spot controller - handle = backend_utils.is_controller_accessible( - controller_type=controller_utils.Controllers.SPOT_CONTROLLER, - stopped_message=('Please restart the spot controller with ' - f'`sky start {spot.SPOT_CONTROLLER_NAME}`.')) - - if name is not None and job_id is not None: - raise ValueError('Cannot specify both name and job_id.') - backend = backend_utils.get_backend_from_handle(handle) - assert isinstance(backend, backends.CloudVmRayBackend), backend - # Stream the realtime logs - backend.tail_spot_logs(handle, job_id=job_id, job_name=name, follow=follow) - - # ====================== # = Storage Management = # ====================== diff --git a/sky/execution.py b/sky/execution.py index 9743ca1e6a6..25f0d8cc7a8 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -2,12 +2,9 @@ See `Stage` for a Task's life cycle. """ -import copy import enum import os -import tempfile -from typing import Any, List, Optional, Tuple, Union -import uuid +from typing import List, Optional, Tuple, Union import colorama @@ -17,13 +14,8 @@ from sky import global_user_state from sky import optimizer from sky import sky_logging -from sky import spot -from sky import task as task_lib from sky.backends import backend_utils -from sky.clouds.service_catalog import common as service_catalog_common -from sky.skylet import constants from sky.usage import usage_lib -from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils from sky.utils import env_options @@ -34,52 +26,6 @@ logger = sky_logging.init_logger(__name__) -# Message thrown when APIs sky.{exec,launch,spot_launch}() received a string -# instead of a Dag. CLI (cli.py) is implemented by us so should not trigger -# this. -_ENTRYPOINT_STRING_AS_DAG_MESSAGE = """\ -Expected a sky.Task or sky.Dag but received a string. - -If you meant to run a command, make it a Task's run command: - - task = sky.Task(run=command) - -The command can then be run as: - - sky.exec(task, cluster_name=..., ...) - # Or use {'V100': 1}, 'V100:0.5', etc. - task.set_resources(sky.Resources(accelerators='V100:1')) - sky.exec(task, cluster_name=..., ...) 
- - sky.launch(task, ...) - - sky.spot_launch(task, ...) -""".strip() - - -def _convert_to_dag(entrypoint: Any) -> 'sky.Dag': - """Convert the entrypoint to a sky.Dag. - - Raises TypeError if 'entrypoint' is not a 'sky.Task' or 'sky.Dag'. - """ - # Not suppressing stacktrace: when calling this via API user may want to - # see their own program in the stacktrace. Our CLI impl would not trigger - # these errors. - if isinstance(entrypoint, str): - raise TypeError(_ENTRYPOINT_STRING_AS_DAG_MESSAGE) - elif isinstance(entrypoint, sky.Dag): - return copy.deepcopy(entrypoint) - elif isinstance(entrypoint, task_lib.Task): - entrypoint = copy.deepcopy(entrypoint) - with sky.Dag() as dag: - dag.add(entrypoint) - dag.name = entrypoint.name - return dag - else: - raise TypeError( - 'Expected a sky.Task or sky.Dag but received argument of type: ' - f'{type(entrypoint)}') - class Stage(enum.Enum): """Stages for a run of a sky.Task.""" @@ -210,7 +156,7 @@ def _execute( handle: Optional[backends.ResourceHandle]; the handle to the cluster. None if dryrun. """ - dag = _convert_to_dag(entrypoint) + dag = dag_utils.convert_entrypoint_to_dag(entrypoint) assert len(dag) == 1, f'We support 1 task for now. {dag}' task = dag.tasks[0] @@ -292,7 +238,7 @@ def _execute( f'automatically recover from preemptions.{reset}\n{yellow}To ' 'get automatic recovery, use managed spot instead: ' f'{reset}{bold}sky spot launch{reset} {yellow}or{reset} ' - f'{bold}sky.spot_launch(){reset}.') + f'{bold}sky.spot.launch(){reset}.') if Stage.OPTIMIZE in stages: if task.best_resources is None: @@ -614,112 +560,3 @@ def exec( # pylint: disable=redefined-builtin cluster_name=cluster_name, detach_run=detach_run, ) - - -@usage_lib.entrypoint -def spot_launch( - task: Union['sky.Task', 'sky.Dag'], - name: Optional[str] = None, - stream_logs: bool = True, - detach_run: bool = False, - retry_until_up: bool = False, -): - # NOTE(dev): Keep the docstring consistent between the Python API and CLI. - """Launch a managed spot job. - - Please refer to the sky.cli.spot_launch for the document. - - Args: - task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a - managed spot job. - name: Name of the spot job. - detach_run: Whether to detach the run. - - Raises: - ValueError: cluster does not exist. - sky.exceptions.NotSupportedError: the feature is not supported. - """ - entrypoint = task - dag_uuid = str(uuid.uuid4().hex[:4]) - - dag = _convert_to_dag(entrypoint) - assert dag.is_chain(), ('Only single-task or chain DAG is ' - 'allowed for spot_launch.', dag) - - dag_utils.maybe_infer_and_fill_dag_and_task_names(dag) - - task_names = set() - for task_ in dag.tasks: - if task_.name in task_names: - raise ValueError( - f'Task name {task_.name!r} is duplicated in the DAG. 
Either ' - 'change task names to be unique, or specify the DAG name only ' - 'and comment out the task names (so that they will be auto-' - 'generated) .') - task_names.add(task_.name) - - dag_utils.fill_default_spot_config_in_dag_for_spot_launch(dag) - - for task_ in dag.tasks: - controller_utils.maybe_translate_local_file_mounts_and_sync_up( - task_, path='spot') - - with tempfile.NamedTemporaryFile(prefix=f'spot-dag-{dag.name}-', - mode='w') as f: - dag_utils.dump_chain_dag_to_yaml(dag, f.name) - controller_name = spot.SPOT_CONTROLLER_NAME - prefix = spot.SPOT_TASK_YAML_PREFIX - remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml' - remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml' - controller_resources = (controller_utils.get_controller_resources( - controller_type='spot', - controller_resources_config=spot.constants.CONTROLLER_RESOURCES)) - - vars_to_fill = { - 'remote_user_yaml_path': remote_user_yaml_path, - 'user_yaml_path': f.name, - 'spot_controller': controller_name, - # Note: actual spot cluster name will be <task.name>-<spot job ID> - 'dag_name': dag.name, - 'retry_until_up': retry_until_up, - 'remote_user_config_path': remote_user_config_path, - 'sky_python_cmd': constants.SKY_PYTHON_CMD, - 'modified_catalogs': - service_catalog_common.get_modified_catalog_file_mounts(), - **controller_utils.shared_controller_vars_to_fill( - 'spot', - remote_user_config_path=remote_user_config_path, - ), - } - - yaml_path = os.path.join(spot.SPOT_CONTROLLER_YAML_PREFIX, - f'{name}-{dag_uuid}.yaml') - common_utils.fill_template(spot.SPOT_CONTROLLER_TEMPLATE, - vars_to_fill, - output_path=yaml_path) - controller_task = task_lib.Task.from_yaml(yaml_path) - assert len(controller_task.resources) == 1, controller_task - # Backward compatibility: if the user changed the - # spot-controller.yaml.j2 to customize the controller resources, - # we should use it. - controller_task_resources = list(controller_task.resources)[0] - if not controller_task_resources.is_empty(): - controller_resources = controller_task_resources - controller_task.set_resources(controller_resources) - - controller_task.spot_dag = dag - assert len(controller_task.resources) == 1 - - print(f'{colorama.Fore.YELLOW}' - f'Launching managed spot job {dag.name!r} from spot controller...' - f'{colorama.Style.RESET_ALL}') - print('Launching spot controller...') - _execute( - entrypoint=controller_task, - stream_logs=stream_logs, - cluster_name=controller_name, - detach_run=detach_run, - idle_minutes_to_autostop=constants.
- CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP, - retry_until_up=True, - ) diff --git a/sky/spot/__init__.py b/sky/spot/__init__.py index f45b9ad648c..e25f25c9176 100644 --- a/sky/spot/__init__.py +++ b/sky/spot/__init__.py @@ -5,6 +5,10 @@ from sky.spot.constants import SPOT_CONTROLLER_TEMPLATE from sky.spot.constants import SPOT_CONTROLLER_YAML_PREFIX from sky.spot.constants import SPOT_TASK_YAML_PREFIX +from sky.spot.core import cancel +from sky.spot.core import launch +from sky.spot.core import queue +from sky.spot.core import tail_logs from sky.spot.recovery_strategy import SPOT_DEFAULT_STRATEGY from sky.spot.recovery_strategy import SPOT_STRATEGIES from sky.spot.spot_state import SpotStatus @@ -28,6 +32,11 @@ 'SPOT_TASK_YAML_PREFIX', # Enums 'SpotStatus', + # Core + 'cancel', + 'launch', + 'queue', + 'tail_logs', # utils 'SpotCodeGen', 'dump_job_table_cache', diff --git a/sky/spot/core.py b/sky/spot/core.py new file mode 100644 index 00000000000..459f351b7b4 --- /dev/null +++ b/sky/spot/core.py @@ -0,0 +1,317 @@ +"""SDK functions for managed spot job.""" +import os +import tempfile +from typing import Any, Dict, List, Optional, Union +import uuid + +import colorama + +import sky +from sky import backends +from sky import exceptions +from sky import sky_logging +from sky import status_lib +from sky import task as task_lib +from sky.backends import backend_utils +from sky.clouds.service_catalog import common as service_catalog_common +from sky.skylet import constants as skylet_constants +from sky.spot import constants +from sky.spot import spot_utils +from sky.usage import usage_lib +from sky.utils import common_utils +from sky.utils import controller_utils +from sky.utils import dag_utils +from sky.utils import rich_utils +from sky.utils import subprocess_utils +from sky.utils import ux_utils + + +@usage_lib.entrypoint +def launch( + task: Union['sky.Task', 'sky.Dag'], + name: Optional[str] = None, + stream_logs: bool = True, + detach_run: bool = False, + retry_until_up: bool = False, +) -> None: + # NOTE(dev): Keep the docstring consistent between the Python API and CLI. + """Launch a managed spot job. + + Please refer to the sky.cli.spot_launch for the documentation. + + Args: + task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a + managed spot job. + name: Name of the spot job. + detach_run: Whether to detach the run. + + Raises: + ValueError: cluster does not exist. + sky.exceptions.NotSupportedError: the feature is not supported. + """ + entrypoint = task + dag_uuid = str(uuid.uuid4().hex[:4]) + + dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + if not dag.is_chain(): + with ux_utils.print_exception_no_traceback(): + raise ValueError('Only single-task or chain DAG is allowed for ' + f'sky.spot.launch. Dag:\n{dag}') + + dag_utils.maybe_infer_and_fill_dag_and_task_names(dag) + + task_names = set() + for task_ in dag.tasks: + if task_.name in task_names: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Task name {task_.name!r} is duplicated in the DAG. 
' + 'Either change task names to be unique, or specify the DAG ' + 'name only and comment out the task names (so that they ' + 'will be auto-generated).') + task_names.add(task_.name) + + dag_utils.fill_default_spot_config_in_dag_for_spot_launch(dag) + + for task_ in dag.tasks: + controller_utils.maybe_translate_local_file_mounts_and_sync_up( + task_, path='spot') + + with tempfile.NamedTemporaryFile(prefix=f'spot-dag-{dag.name}-', + mode='w') as f: + dag_utils.dump_chain_dag_to_yaml(dag, f.name) + controller_name = spot_utils.SPOT_CONTROLLER_NAME + prefix = constants.SPOT_TASK_YAML_PREFIX + remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml' + remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml' + controller_resources = controller_utils.get_controller_resources( + controller_type='spot', + controller_resources_config=constants.CONTROLLER_RESOURCES) + + vars_to_fill = { + 'remote_user_yaml_path': remote_user_yaml_path, + 'user_yaml_path': f.name, + 'spot_controller': controller_name, + # Note: actual spot cluster name will be <task.name>-<spot job ID> + 'dag_name': dag.name, + 'retry_until_up': retry_until_up, + 'remote_user_config_path': remote_user_config_path, + 'sky_python_cmd': skylet_constants.SKY_PYTHON_CMD, + 'modified_catalogs': + service_catalog_common.get_modified_catalog_file_mounts(), + **controller_utils.shared_controller_vars_to_fill( + 'spot', + remote_user_config_path=remote_user_config_path, + ), + } + + yaml_path = os.path.join(constants.SPOT_CONTROLLER_YAML_PREFIX, + f'{name}-{dag_uuid}.yaml') + common_utils.fill_template(constants.SPOT_CONTROLLER_TEMPLATE, + vars_to_fill, + output_path=yaml_path) + controller_task = task_lib.Task.from_yaml(yaml_path) + controller_task.set_resources(controller_resources) + + controller_task.spot_dag = dag + assert len(controller_task.resources) == 1 + + sky_logging.print( + f'{colorama.Fore.YELLOW}' + f'Launching managed spot job {dag.name!r} from spot controller...' + f'{colorama.Style.RESET_ALL}') + sky_logging.print('Launching spot controller...') + sky.launch(task=controller_task, + stream_logs=stream_logs, + cluster_name=controller_name, + detach_run=detach_run, + idle_minutes_to_autostop=skylet_constants. + CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP, + retry_until_up=True, + _disable_controller_check=True) + + +@usage_lib.entrypoint +def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]: + # NOTE(dev): Keep the docstring consistent between the Python API and CLI. + """Get statuses of managed spot jobs. + + Please refer to the sky.cli.spot_queue for the documentation. + + Returns: + [ + { + 'job_id': int, + 'job_name': str, + 'resources': str, + 'submitted_at': (float) timestamp of submission, + 'end_at': (float) timestamp of end, + 'duration': (float) duration in seconds, + 'recovery_count': (int) Number of retries, + 'status': (sky.spot.SpotStatus) of the job, + 'cluster_resources': (str) resources of the cluster, + 'region': (str) region of the cluster, + } + ] + Raises: + sky.exceptions.ClusterNotUpError: the spot controller is not up or + does not exist. + RuntimeError: failed to get the spot jobs with ssh. + """ + stopped_message = '' + if not refresh: + stopped_message = 'No in-progress spot jobs.'
+ try: + handle = backend_utils.is_controller_accessible( + controller_type=controller_utils.Controllers.SPOT_CONTROLLER, + stopped_message=stopped_message) + except exceptions.ClusterNotUpError as e: + if not refresh: + raise + handle = None + controller_status = e.cluster_status + + if refresh and handle is None: + sky_logging.print(f'{colorama.Fore.YELLOW}' + 'Restarting controller for latest status...' + f'{colorama.Style.RESET_ALL}') + + rich_utils.force_update_status('[cyan] Checking spot jobs - restarting ' + 'controller[/]') + handle = sky.start(spot_utils.SPOT_CONTROLLER_NAME) + controller_status = status_lib.ClusterStatus.UP + rich_utils.force_update_status('[cyan] Checking spot jobs[/]') + + assert handle is not None, (controller_status, refresh) + + backend = backend_utils.get_backend_from_handle(handle) + assert isinstance(backend, backends.CloudVmRayBackend) + + code = spot_utils.SpotCodeGen.get_job_table() + returncode, job_table_payload, stderr = backend.run_on_head( + handle, + code, + require_outputs=True, + stream_logs=False, + separate_stderr=True) + + try: + subprocess_utils.handle_returncode(returncode, + code, + 'Failed to fetch managed spot jobs', + job_table_payload + stderr, + stream_logs=False) + except exceptions.CommandError as e: + raise RuntimeError(str(e)) from e + + jobs = spot_utils.load_spot_job_queue(job_table_payload) + if skip_finished: + # Filter out the finished jobs. If a multi-task job is partially + # finished, we will include all its tasks. + non_finished_tasks = list( + filter(lambda job: not job['status'].is_terminal(), jobs)) + non_finished_job_ids = {job['job_id'] for job in non_finished_tasks} + jobs = list( + filter(lambda job: job['job_id'] in non_finished_job_ids, jobs)) + return jobs + + +@usage_lib.entrypoint +# pylint: disable=redefined-builtin +def cancel(name: Optional[str] = None, + job_ids: Optional[List[int]] = None, + all: bool = False) -> None: + # NOTE(dev): Keep the docstring consistent between the Python API and CLI. + """Cancel managed spot jobs. + + Please refer to the sky.cli.spot_cancel for the documentation. + + Raises: + sky.exceptions.ClusterNotUpError: the spot controller is not up. + RuntimeError: failed to cancel the job. + """ + job_ids = [] if job_ids is None else job_ids + handle = backend_utils.is_controller_accessible( + controller_type=controller_utils.Controllers.SPOT_CONTROLLER, + stopped_message='All managed spot jobs should have finished.') + + job_id_str = ','.join(map(str, job_ids)) + if sum([len(job_ids) > 0, name is not None, all]) != 1: + argument_str = f'job_ids={job_id_str}' if len(job_ids) > 0 else '' + argument_str += f' name={name}' if name is not None else '' + argument_str += ' all' if all else '' + with ux_utils.print_exception_no_traceback(): + raise ValueError('Can only specify one of JOB_IDS or name or all. 
' + f'Provided {argument_str!r}.') + + backend = backend_utils.get_backend_from_handle(handle) + assert isinstance(backend, backends.CloudVmRayBackend) + if all: + code = spot_utils.SpotCodeGen.cancel_jobs_by_id(None) + elif job_ids: + code = spot_utils.SpotCodeGen.cancel_jobs_by_id(job_ids) + else: + assert name is not None, (job_ids, name, all) + code = spot_utils.SpotCodeGen.cancel_job_by_name(name) + # The stderr is redirected to stdout + returncode, stdout, _ = backend.run_on_head(handle, + code, + require_outputs=True, + stream_logs=False) + try: + subprocess_utils.handle_returncode(returncode, code, + 'Failed to cancel managed spot job', + stdout) + except exceptions.CommandError as e: + with ux_utils.print_exception_no_traceback(): + raise RuntimeError(e.error_msg) from e + + sky_logging.print(stdout) + if 'Multiple jobs found with name' in stdout: + with ux_utils.print_exception_no_traceback(): + raise RuntimeError( + 'Please specify the job ID instead of the job name.') + + +@usage_lib.entrypoint +def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool) -> None: + # NOTE(dev): Keep the docstring consistent between the Python API and CLI. + """Tail logs of managed spot jobs. + + Please refer to the sky.cli.spot_logs for the documentation. + + Raises: + ValueError: invalid arguments. + sky.exceptions.ClusterNotUpError: the spot controller is not up. + """ + # TODO(zhwu): Automatically restart the spot controller + handle = backend_utils.is_controller_accessible( + controller_type=controller_utils.Controllers.SPOT_CONTROLLER, + stopped_message=('Please restart the spot controller with ' + f'`sky start {spot_utils.SPOT_CONTROLLER_NAME}`.')) + + if name is not None and job_id is not None: + raise ValueError('Cannot specify both name and job_id.') + backend = backend_utils.get_backend_from_handle(handle) + assert isinstance(backend, backends.CloudVmRayBackend), backend + # Stream the realtime logs + backend.tail_spot_logs(handle, job_id=job_id, job_name=name, follow=follow) + + +spot_launch = common_utils.deprecated_function(launch, + name='sky.spot.launch', + deprecated_name='spot_launch', + removing_version='0.7.0') +spot_queue = common_utils.deprecated_function(queue, + name='sky.spot.queue', + deprecated_name='spot_queue', + removing_version='0.7.0') +spot_cancel = common_utils.deprecated_function(cancel, + name='sky.spot.cancel', + deprecated_name='spot_cancel', + removing_version='0.7.0') +spot_tail_logs = common_utils.deprecated_function( + tail_logs, + name='sky.spot.tail_logs', + deprecated_name='spot_tail_logs', + removing_version='0.7.0') diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 58226209cec..c66b5dfe032 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -437,7 +437,7 @@ def rsync( backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) while max_retry >= 0: - returncode, _, stderr = log_lib.run_with_log( + returncode, stdout, stderr = log_lib.run_with_log( command, log_path=log_path, stream_logs=stream_logs, @@ -454,7 +454,7 @@ def rsync( subprocess_utils.handle_returncode(returncode, command, error_msg, - stderr=stderr, + stderr=stdout + stderr, stream_logs=stream_logs) def check_connection(self) -> bool: diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index a43cbf307ed..14a79a4d73f 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -630,3 +630,20 @@ def fill_template(template_name: str, variables: Dict, content = j2_template.render(**variables) 
with open(output_path, 'w', encoding='utf-8') as fout: fout.write(content) + + +def deprecated_function(func: Callable, name: str, deprecated_name: str, + removing_version: str) -> Callable: + """Wrapper for creating deprecated functions, for backward compatibility. + + It will result in a warning being emitted when the function is used. + """ + + @functools.wraps(func) + def new_func(*args, **kwargs): + logger.warning( + f'Call to deprecated function {deprecated_name}, which will be ' + f'removed in {removing_version}. Please use {name}() instead.') + return func(*args, **kwargs) + + return new_func diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 09fac49d73f..b2eb58d3de3 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -32,7 +32,7 @@ logger = sky_logging.init_logger(__name__) -# Message thrown when APIs sky.spot_launch(),sky.serve.up() received an invalid +# Message thrown when APIs sky.spot.launch(), sky.serve.up() received an invalid # controller resources spec. CONTROLLER_RESOURCES_NOT_VALID_MESSAGE = ( '{controller_type} controller resources is not valid, please check ' diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 03ab3a72713..9803821d8bb 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -1,4 +1,5 @@ """Utilities for loading and dumping DAGs from/to YAML files.""" +import copy from typing import Any, Dict, List, Optional, Tuple from sky import dag as dag_lib @@ -11,6 +12,54 @@ logger = sky_logging.init_logger(__name__) + +# Message thrown when APIs sky.{exec,launch,spot.launch}() received a string +# instead of a Dag. CLI (cli.py) is implemented by us so should not trigger +# this. +_ENTRYPOINT_STRING_AS_DAG_MESSAGE = """\ +Expected a sky.Task or sky.Dag but received a string. + +If you meant to run a command, make it a Task's run command: + + task = sky.Task(run=command) + +The command can then be run as: + + sky.exec(task, cluster_name=..., ...) + # Or use {'V100': 1}, 'V100:0.5', etc. + task.set_resources(sky.Resources(accelerators='V100:1')) + sky.exec(task, cluster_name=..., ...) + + sky.launch(task, ...) + + sky.spot.launch(task, ...) +""".strip() + + +def convert_entrypoint_to_dag(entrypoint: Any) -> 'dag_lib.Dag': + """Convert the entrypoint to a sky.Dag. + + Raises TypeError if 'entrypoint' is not a 'sky.Task' or 'sky.Dag'. + """ + # Not suppressing stacktrace: when calling this via API user may want to + # see their own program in the stacktrace. Our CLI impl would not trigger + # these errors. 
+ if isinstance(entrypoint, str): + with ux_utils.print_exception_no_traceback(): + raise TypeError(_ENTRYPOINT_STRING_AS_DAG_MESSAGE) + elif isinstance(entrypoint, dag_lib.Dag): + return copy.deepcopy(entrypoint) + elif isinstance(entrypoint, task_lib.Task): + entrypoint = copy.deepcopy(entrypoint) + with dag_lib.Dag() as dag: + dag.add(entrypoint) + dag.name = entrypoint.name + return dag + else: + with ux_utils.print_exception_no_traceback(): + raise TypeError( + 'Expected a sky.Task or sky.Dag but received argument of type: ' + f'{type(entrypoint)}') + def load_chain_dag_from_yaml( path: str, diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index 073bd5cf743..424a0f5f9f7 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -170,7 +170,7 @@ rm -r ~/.sky/wheels || true s=$(sky spot logs --no-follow -n ${CLUSTER_NAME}-7-1) echo "$s" echo "$s" | grep " hi" || exit 1 -sky spot launch -d --cloud ${CLOUD} -y -n ${CLUSTER_NAME}-7-2 "echo hi; sleep 10" +sky spot launch -d --cloud ${CLOUD} -y -n ${CLUSTER_NAME}-7-2 "echo hi; sleep 60" s=$(sky spot logs --no-follow -n ${CLUSTER_NAME}-7-2) echo "$s" echo "$s" | grep " hi" || exit 1 @@ -178,7 +178,7 @@ s=$(sky spot queue | grep ${CLUSTER_NAME}-7) echo "$s" echo "$s" | grep "RUNNING" | wc -l | grep 3 || exit 1 sky spot cancel -y -n ${CLUSTER_NAME}-7-0 -sleep 200 +sky spot logs -n "${CLUSTER_NAME}-7-1" s=$(sky spot queue | grep ${CLUSTER_NAME}-7) echo "$s" echo "$s" | grep "SUCCEEDED" | wc -l | grep 2 || exit 1
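Note for reviewers: the `spot_launch`/`spot_queue`/`spot_cancel`/`spot_tail_logs` aliases at the bottom of `sky/spot/core.py` are built with the `deprecated_function` helper added to `sky/utils/common_utils.py`. A self-contained sketch of that pattern, assuming the standard `logging` module in place of SkyPilot's logger and an illustrative `launch` stub:

```python
import functools
import logging
from typing import Callable

logger = logging.getLogger(__name__)


def deprecated_function(func: Callable, name: str, deprecated_name: str,
                        removing_version: str) -> Callable:
    """Wrap `func` so calls through the old name warn, then forward."""

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        logger.warning(
            f'Call to deprecated function {deprecated_name}, which will be '
            f'removed in {removing_version}. Please use {name}() instead.')
        return func(*args, **kwargs)

    return new_func


def launch(task, name=None):
    """Stand-in for sky.spot.core.launch, for illustration only."""
    return f'launched {name}'


# Mirrors the alias definitions at the end of sky/spot/core.py.
spot_launch = deprecated_function(launch,
                                  name='sky.spot.launch',
                                  deprecated_name='spot_launch',
                                  removing_version='0.7.0')

print(spot_launch(None, name='demo'))  # emits the deprecation warning first
```

Because `functools.wraps` copies `__name__` and `__doc__` from the wrapped function, introspecting the deprecated alias still surfaces the real implementation's signature and docstring.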