Skip to content

Commit

Permalink
[Spot] Refactor spot APIs into spot.xxx (#3417)
Browse files Browse the repository at this point in the history
* Refactor spot core APIs to `sky.spot.core`

* Add comment

* fix

* format

* change to spot_lib instead

* Update sky/spot/core.py

Co-authored-by: Tian Xia <[email protected]>

* address comments

* Add deprecation message

* fix

* format

* minor fix for backward compat test

* longer time

* longer wait for spot backward test

* fix

---------

Co-authored-by: Tian Xia <[email protected]>
  • Loading branch information
Michaelvll and cblmemo authored Apr 11, 2024
1 parent 226c1eb commit e60eb73
Show file tree
Hide file tree
Showing 12 changed files with 420 additions and 362 deletions.
10 changes: 7 additions & 3 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.core import download_logs
from sky.core import job_status
from sky.core import queue
from sky.core import spot_cancel
from sky.core import spot_queue
from sky.core import start
from sky.core import status
from sky.core import stop
Expand All @@ -104,11 +102,16 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.data import StoreType
from sky.execution import exec # pylint: disable=redefined-builtin
from sky.execution import launch
from sky.execution import spot_launch
from sky.optimizer import Optimizer
from sky.optimizer import OptimizeTarget
from sky.resources import Resources
from sky.skylet.job_lib import JobStatus
# TODO (zhwu): These imports are for backward compatibility, and spot APIs
# should be called with `sky.spot.xxx` instead. Remove in release 0.7.0
from sky.spot.core import spot_cancel
from sky.spot.core import spot_launch
from sky.spot.core import spot_queue
from sky.spot.core import spot_tail_logs
from sky.status_lib import ClusterStatus
from sky.task import Task

Expand Down Expand Up @@ -173,6 +176,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
'queue',
'cancel',
'tail_logs',
'spot_tail_logs',
'download_logs',
'job_status',
# core APIs Spot Job Management
Expand Down
2 changes: 1 addition & 1 deletion sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,7 +2292,7 @@ def is_controller_accessible(
# will not start the controller manually from the cloud console.
#
# The acquire_lock_timeout is set to 0 to avoid hanging the command when
# multiple spot_launch commands are running at the same time. Our later
# multiple spot.launch commands are running at the same time. Our later
# code will check if the controller is accessible by directly checking
# the ssh connection to the controller, if it fails to get accurate
# status of the controller.
Expand Down
12 changes: 6 additions & 6 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,8 +1250,8 @@ def _get_spot_jobs(
usage_lib.messages.usage.set_internal()
with sky_logging.silent():
# Make the call silent
spot_jobs = core.spot_queue(refresh=refresh,
skip_finished=skip_finished)
spot_jobs = spot_lib.queue(refresh=refresh,
skip_finished=skip_finished)
num_in_progress_jobs = len(spot_jobs)
except exceptions.ClusterNotUpError as e:
controller_status = e.cluster_status
Expand Down Expand Up @@ -2508,7 +2508,7 @@ def _hint_or_raise_for_down_spot_controller(controller_name: str):
with rich_utils.safe_status(
'[bold cyan]Checking for in-progress spot jobs[/]'):
try:
spot_jobs = core.spot_queue(refresh=False, skip_finished=True)
spot_jobs = spot_lib.queue(refresh=False, skip_finished=True)
except exceptions.ClusterNotUpError as e:
if controller.value.connection_error_hint in str(e):
with ux_utils.print_exception_no_traceback():
Expand Down Expand Up @@ -3289,7 +3289,7 @@ def spot_launch(

common_utils.check_cluster_name_is_valid(name)

sky.spot_launch(dag,
spot_lib.launch(dag,
name,
detach_run=detach_run,
retry_until_up=retry_until_up)
Expand Down Expand Up @@ -3448,7 +3448,7 @@ def spot_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
abort=True,
show_default=True)

core.spot_cancel(job_ids=job_ids, name=name, all=all)
spot_lib.cancel(job_ids=job_ids, name=name, all=all)


@spot.command('logs', cls=_DocumentedCodeCommand)
Expand Down Expand Up @@ -3480,7 +3480,7 @@ def spot_logs(name: Optional[str], job_id: Optional[int], follow: bool,
job_id=job_id,
follow=follow)
else:
core.spot_tail_logs(name=name, job_id=job_id, follow=follow)
spot_lib.tail_logs(name=name, job_id=job_id, follow=follow)
except exceptions.ClusterNotUpError as e:
click.echo(e)
sys.exit(1)
Expand Down
187 changes: 6 additions & 181 deletions sky/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,14 @@
from sky import exceptions
from sky import global_user_state
from sky import sky_logging
from sky import spot
from sky import status_lib
from sky import task
from sky.backends import backend_utils
from sky.skylet import constants
from sky.skylet import job_lib
from sky.usage import usage_lib
from sky.utils import controller_utils
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky import resources as resources_lib
Expand Down Expand Up @@ -230,7 +227,7 @@ def start(
retry_until_up: bool = False,
down: bool = False, # pylint: disable=redefined-outer-name
force: bool = False,
) -> None:
) -> backends.CloudVmRayResourceHandle:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Restart a cluster.
Expand Down Expand Up @@ -279,11 +276,11 @@ def start(
if down and idle_minutes_to_autostop is None:
raise ValueError(
'`idle_minutes_to_autostop` must be set if `down` is True.')
_start(cluster_name,
idle_minutes_to_autostop,
retry_until_up,
down,
force=force)
return _start(cluster_name,
idle_minutes_to_autostop,
retry_until_up,
down,
force=force)


def _stop_not_supported_message(resources: 'resources_lib.Resources') -> str:
Expand Down Expand Up @@ -762,178 +759,6 @@ def job_status(cluster_name: str,
return statuses


# =======================
# = Spot Job Management =
# =======================


@usage_lib.entrypoint
def spot_queue(refresh: bool,
               skip_finished: bool = False) -> List[Dict[str, Any]]:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Get statuses of managed spot jobs.

    Please refer to the sky.cli.spot_queue for the documentation.

    Returns:
        [
            {
                'job_id': int,
                'job_name': str,
                'resources': str,
                'submitted_at': (float) timestamp of submission,
                'end_at': (float) timestamp of end,
                'duration': (float) duration in seconds,
                'recovery_count': (int) Number of retries,
                'status': (sky.spot.SpotStatus) of the job,
                'cluster_resources': (str) resources of the cluster,
                'region': (str) region of the cluster,
            }
        ]
    Raises:
        sky.exceptions.ClusterNotUpError: the spot controller is not up or
            does not exist.
        RuntimeError: if failed to get the spot jobs with ssh.
    """
    # With refresh=True a stopped controller is not an error: we restart it
    # below instead of surfacing ClusterNotUpError to the caller.
    controller_down_message = '' if refresh else 'No in-progress spot jobs.'
    controller_status = None
    try:
        handle = backend_utils.is_controller_accessible(
            controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
            stopped_message=controller_down_message)
    except exceptions.ClusterNotUpError as e:
        if not refresh:
            raise
        handle = None
        controller_status = e.cluster_status

    if refresh and handle is None:
        # Bring the controller back up so we can fetch the latest job table.
        sky_logging.print(f'{colorama.Fore.YELLOW}'
                          'Restarting controller for latest status...'
                          f'{colorama.Style.RESET_ALL}')
        rich_utils.force_update_status('[cyan] Checking spot jobs - restarting '
                                       'controller[/]')
        handle = _start(spot.SPOT_CONTROLLER_NAME)
        controller_status = status_lib.ClusterStatus.UP
        rich_utils.force_update_status('[cyan] Checking spot jobs[/]')

    assert handle is not None, (controller_status, refresh)

    backend = backend_utils.get_backend_from_handle(handle)
    assert isinstance(backend, backends.CloudVmRayBackend)

    # Run the job-table query on the controller head node.
    code = spot.SpotCodeGen.get_job_table()
    returncode, job_table_payload, stderr = backend.run_on_head(
        handle,
        code,
        require_outputs=True,
        stream_logs=False,
        separate_stderr=True)

    try:
        subprocess_utils.handle_returncode(returncode,
                                           code,
                                           'Failed to fetch managed spot jobs',
                                           job_table_payload + stderr,
                                           stream_logs=False)
    except exceptions.CommandError as e:
        raise RuntimeError(str(e)) from e

    jobs = spot.load_spot_job_queue(job_table_payload)
    if not skip_finished:
        return jobs
    # Filter out finished jobs. A partially-finished multi-task job counts as
    # in progress, so we keep every task belonging to such a job.
    active_job_ids = {
        job['job_id'] for job in jobs if not job['status'].is_terminal()
    }
    return [job for job in jobs if job['job_id'] in active_job_ids]


@usage_lib.entrypoint
# pylint: disable=redefined-builtin
def spot_cancel(name: Optional[str] = None,
                job_ids: Optional[List[int]] = None,
                all: bool = False) -> None:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Cancel managed spot jobs.

    Please refer to the sky.cli.spot_cancel for the document.

    Raises:
        sky.exceptions.ClusterNotUpError: the spot controller is not up.
        RuntimeError: failed to cancel the job.
    """
    job_ids = [] if job_ids is None else job_ids
    handle = backend_utils.is_controller_accessible(
        controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
        stopped_message='All managed spot jobs should have finished.')

    job_id_str = ','.join(map(str, job_ids))
    # Exactly one selector (job IDs, a job name, or --all) must be given.
    num_selectors = sum([len(job_ids) > 0, name is not None, all])
    if num_selectors != 1:
        provided = f'job_ids={job_id_str}' if job_ids else ''
        if name is not None:
            provided += f' name={name}'
        if all:
            provided += ' all'
        raise ValueError('Can only specify one of JOB_IDS or name or all. '
                         f'Provided {provided!r}.')

    backend = backend_utils.get_backend_from_handle(handle)
    assert isinstance(backend, backends.CloudVmRayBackend)
    if all:
        cancel_code = spot.SpotCodeGen.cancel_jobs_by_id(None)
    elif job_ids:
        cancel_code = spot.SpotCodeGen.cancel_jobs_by_id(job_ids)
    else:
        assert name is not None, (job_ids, name, all)
        cancel_code = spot.SpotCodeGen.cancel_job_by_name(name)
    # The stderr is redirected to stdout
    returncode, stdout, _ = backend.run_on_head(handle,
                                                cancel_code,
                                                require_outputs=True,
                                                stream_logs=False)
    try:
        subprocess_utils.handle_returncode(returncode, cancel_code,
                                           'Failed to cancel managed spot job',
                                           stdout)
    except exceptions.CommandError as e:
        raise RuntimeError(e.error_msg) from e

    sky_logging.print(stdout)
    # Cancel-by-name is ambiguous when several jobs share the name; surface
    # that to the user as an error rather than silently doing nothing.
    if 'Multiple jobs found with name' in stdout:
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError(
                'Please specify the job ID instead of the job name.')


@usage_lib.entrypoint
def spot_tail_logs(name: Optional[str], job_id: Optional[int],
                   follow: bool) -> None:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Tail logs of managed spot jobs.

    Please refer to the sky.cli.spot_logs for the document.

    Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the spot controller is not up.
    """
    # TODO(zhwu): Automatically restart the spot controller
    controller_handle = backend_utils.is_controller_accessible(
        controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
        stopped_message=('Please restart the spot controller with '
                         f'`sky start {spot.SPOT_CONTROLLER_NAME}`.'))

    if name is not None and job_id is not None:
        raise ValueError('Cannot specify both name and job_id.')
    backend = backend_utils.get_backend_from_handle(controller_handle)
    assert isinstance(backend, backends.CloudVmRayBackend), backend
    # Stream the realtime logs from the controller head node.
    backend.tail_spot_logs(controller_handle,
                           job_id=job_id,
                           job_name=name,
                           follow=follow)


# ======================
# = Storage Management =
# ======================
Expand Down
Loading

0 comments on commit e60eb73

Please sign in to comment.