def validate_flow_opt(val):
    """Validate a command line --flow option value.

    Args:
        val: The raw option value, or None if --flow was not given.

    Raises:
        InputError: If val is not None and int(val) fails.

    Note:
        The value is only checked, not converted: nothing is returned.
    """
    if val is None:
        # Option not supplied: nothing to validate.
        return
    try:
        int(val)
    except (TypeError, ValueError):
        # "from None": the int() traceback adds nothing for a CLI user.
        raise InputError(
            f"--flow={val}: value must be integer."
        ) from None
def queue_command(self, command: str, kwargs: dict) -> None:
    """Put a command on the command queue.

    The queued entry is the tuple (command, (), kwargs): the positional
    argument slot is always empty and the arguments travel as keywords.
    """
    entry = (command, (), kwargs)
    self.command_queue.put(entry)
def command_kill_tasks(self, tasks: List[str]) -> int:
    """Kill the jobs of the task proxies matching the given IDs.

    In simulation mode no jobs are killed: matched tasks in an active
    state are reset to failed instead. Otherwise the tasks returned by
    the job manager's kill call are held.

    Returns:
        The length of the "bad items" list returned by
        filter_task_proxies().
    """
    matched, _, unmatched = self.pool.filter_task_proxies(tasks)
    if not self.config.run_mode('simulation'):
        killed = self.task_job_mgr.kill_task_jobs(self.workflow, matched)
        # Hold killed tasks to prevent automatic retry.
        for itask in killed:
            self.pool.hold_mgr.hold_active_task(itask)
        return len(unmatched)

    # Simulation mode: just flip active tasks to failed.
    for itask in matched:
        if not itask.state(*TASK_STATUSES_ACTIVE):
            continue
        itask.state_reset(TASK_STATUS_FAILED)
        self.data_store_mgr.delta_task_state(itask)
    return len(unmatched)
""" LOG.info('LOADING workflow parameters') + self.options.holdcp_flow = None # (not CLI but needed on restart) + for key, value in params: if value is None: continue @@ -1370,6 +1388,12 @@ def _set_workflow_params( ): self.options.holdcp = value LOG.info(f"+ hold point = {value}") + elif ( + key == self.workflow_db_mgr.KEY_HOLD_CYCLE_POINT_FLOW + and self.options.holdcp_flow is None + ): + self.options.holdcp_flow = value + LOG.info(f"+ hold point flow = {value}") elif key == self.workflow_db_mgr.KEY_STOP_CLOCK_TIME: int_val = int(value) msg = f"stop clock time = {int_val} ({time2str(int_val)})" diff --git a/cylc/flow/scripts/hold.py b/cylc/flow/scripts/hold.py index dddbd1a61af..9c700ab050b 100755 --- a/cylc/flow/scripts/hold.py +++ b/cylc/flow/scripts/hold.py @@ -59,6 +59,7 @@ from typing import TYPE_CHECKING from cylc.flow.exceptions import InputError +from cylc.flow.flow_mgr import validate_flow_opt from cylc.flow.network.client_factory import get_client from cylc.flow.option_parsers import ( FULL_ID_MULTI_ARG_DOC, @@ -67,6 +68,7 @@ from cylc.flow.terminal import cli_function from cylc.flow.network.multi import call_multi + if TYPE_CHECKING: from optparse import Values @@ -74,11 +76,13 @@ HOLD_MUTATION = ''' mutation ( $wFlows: [WorkflowID]!, - $tasks: [NamespaceIDGlob]! + $tasks: [NamespaceIDGlob]!, + $flowNum: Int ) { hold ( workflows: $wFlows, - tasks: $tasks + tasks: $tasks, + flowNum: $flowNum ) { result } @@ -88,11 +92,13 @@ SET_HOLD_POINT_MUTATION = ''' mutation ( $wFlows: [WorkflowID]!, - $point: CyclePoint! 
+ $point: CyclePoint!, + $flowNum: Int ) { setHoldPoint ( workflows: $wFlows, - point: $point + point: $point, + flowNum: $flowNum ) { result } @@ -114,6 +120,11 @@ def get_option_parser() -> COP: help="Hold all tasks after this cycle point.", metavar="CYCLE_POINT", action="store", dest="hold_point_string") + parser.add_option( + "--flow", + help="Hold tasks that belong to a specific flow.", + metavar="INT", action="store", dest="flow_num") + return parser @@ -123,12 +134,14 @@ def _validate(options: 'Values', *task_globs: str) -> None: if task_globs: raise InputError( "Cannot combine --after with Cylc/Task IDs.\n" - "`cylc hold --after` holds all tasks after the given " - "cycle point.") + "`cylc hold --after` holds ALL tasks after the given " + "cycle point. Can be used with `--flow`.") elif not task_globs: raise InputError( "Must define Cycles/Tasks. See `cylc hold --help`.") + validate_flow_opt(options.flow_num) + async def run(options, workflow_id, *tokens_list): _validate(options, *tokens_list) @@ -137,14 +150,18 @@ async def run(options, workflow_id, *tokens_list): if options.hold_point_string: mutation = SET_HOLD_POINT_MUTATION - args = {'point': options.hold_point_string} + args = { + 'point': options.hold_point_string, + 'flowNum': options.flow_num + } else: mutation = HOLD_MUTATION args = { 'tasks': [ id_.relative_id_with_selectors for id_ in tokens_list - ] + ], + 'flowNum': options.flow_num } mutation_kwargs = { diff --git a/cylc/flow/scripts/release.py b/cylc/flow/scripts/release.py index 1fe268ac56e..2740a998e53 100755 --- a/cylc/flow/scripts/release.py +++ b/cylc/flow/scripts/release.py @@ -42,6 +42,7 @@ from typing import TYPE_CHECKING from cylc.flow.exceptions import InputError +from cylc.flow.flow_mgr import validate_flow_opt from cylc.flow.network.client_factory import get_client from cylc.flow.network.multi import call_multi from cylc.flow.option_parsers import ( @@ -57,11 +58,13 @@ RELEASE_MUTATION = ''' mutation ( $wFlows: [WorkflowID]!, - 
def _validate(options: 'Values', *task_globs: str) -> None:
    """Check the command line options are valid.

    Only the --flow value is checked here; task_globs is accepted but
    not inspected.
    """
    flow_opt = options.flow_num
    validate_flow_opt(flow_opt)
*tokens_list) -> None: + + _validate(options, *tokens_list) pclient = get_client(workflow_id, timeout=options.comms_timeout) mutation_kwargs = { diff --git a/cylc/flow/scripts/trigger.py b/cylc/flow/scripts/trigger.py index 3e4a3da96f7..1a1fe30ed9f 100755 --- a/cylc/flow/scripts/trigger.py +++ b/cylc/flow/scripts/trigger.py @@ -100,7 +100,8 @@ def get_option_parser() -> COP: "--flow", action="append", dest="flow", metavar="FLOW", help=f"Assign the triggered task to all active flows ({FLOW_ALL});" f" no flow ({FLOW_NONE}); a new flow ({FLOW_NEW});" - f" or a specific flow (e.g. 2). The default is {FLOW_ALL}." + " or a specific integer flow (e.g. 2). The default is" + f" {FLOW_ALL}." " Reuse the option to assign multiple specific flows." ) diff --git a/cylc/flow/task_job_mgr.py b/cylc/flow/task_job_mgr.py index 20ee7379d27..3dde2dcf388 100644 --- a/cylc/flow/task_job_mgr.py +++ b/cylc/flow/task_job_mgr.py @@ -176,7 +176,7 @@ def check_task_jobs(self, workflow, task_pool): self.poll_task_jobs(workflow, poll_tasks) def kill_task_jobs(self, workflow, itasks): - """Kill jobs of active tasks, and hold the tasks. + """Kill jobs of active tasks. If items is specified, kill active tasks matching given IDs. @@ -184,8 +184,6 @@ def kill_task_jobs(self, workflow, itasks): to_kill_tasks = [] for itask in itasks: if itask.state(*TASK_STATUSES_ACTIVE): - itask.state_reset(is_held=True) - self.data_store_mgr.delta_task_held(itask) to_kill_tasks.append(itask) else: LOG.warning(f"[{itask}] not killable") @@ -194,6 +192,7 @@ def kill_task_jobs(self, workflow, itasks): self._kill_task_jobs_callback, self._kill_task_jobs_callback_255 ) + return to_kill_tasks def poll_task_jobs(self, workflow, itasks, msg=None): """Poll jobs of specified tasks. 
class TaskHoldMgr:
    """Hold/release logic for active and future tasks.

    Active tasks (i.e., task proxies in the pool):
      - hold/release with --flow=n, or (by default) regardless of flow.

    Future tasks (point/name):
      - flag for future hold, or unflag, with --flow=n, or (by default)
        regardless of flow.

    Note this class doesn't yet handle workflow hold point.

    """
    def __init__(
        self,
        workflow_db_mgr: 'WorkflowDatabaseManager',
        data_store_mgr: 'DataStoreMgr',
    ):
        # (name, point): flow
        self.hold: Dict[Tuple[str, 'PointBase'], Optional[int]] = {}
        # flow may be None: hold future task regardless of its flow number.
        # NOTE: RHS could be a set of flow numbers, meaning future-hold same
        # task in multiple specific flows. But those instances can't coexist in
        # the pool so serially holding them is probably fine.
        self.data_store_mgr = data_store_mgr
        self.db_mgr = workflow_db_mgr

    def _flatten(self):
        """Return the hold map as a flat set of (name, point, flow)."""
        # possibly-temporary conversion to old-style flat set
        return {
            (name, point, flow)
            for (name, point), flow in self.hold.items()
        }

    def _update_stores(
        self,
        itask: Union['TaskProxy', Tuple[str, 'PointBase', bool]]
    ):
        """Update datastore and database.

        Args:
            itask: A task proxy, or a (name, point, is_held) tuple for a
                future task; passed through to the data store.
        """
        self.db_mgr.put_tasks_to_hold(self._flatten())
        self.data_store_mgr.delta_task_held(itask)
        LOG.debug(f"Tasks to hold {self.hold}")

    def load_from_db(self):
        """Load the store of tasks-to-hold from the run DB."""
        # Note this doesn't need to actually hold the tasks - they're
        # automatically held at creation via their is_held attribute.
        for name, cycle, flow_num in (
            self.db_mgr.pri_dao.select_tasks_to_hold()
        ):
            # The DB "flow" column is TEXT, so a stored flow number comes
            # back as a string; convert it so it can match the integer
            # flow numbers of task proxies.
            self.hold[(name, get_point(cycle))] = (
                None if flow_num is None else int(flow_num)
            )

    def hold_active_task(
        self,
        itask: 'TaskProxy',
        flow_num: Optional[int] = None,
    ) -> bool:
        """Hold itask if the specified flow_num matches or is None.

        Returns:
            True if the task was newly held; False if the flow did not
            match or the task was already held.
        """
        if flow_num is not None and flow_num not in itask.flow_nums:
            # specified flow does not match this task
            return False
        if not itask.state_reset(is_held=True):
            # already held
            return False
        self.hold[(itask.tdef.name, itask.point)] = flow_num
        self._update_stores(itask)
        return True

    def flag_future_task(
        self,
        name: str,
        point: 'PointBase',
        flow_num: Optional[int] = None
    ) -> None:
        """Flag that we should hold a future task."""
        self.hold[(name, point)] = flow_num
        self._update_stores((name, point, True))

    def hold_if_flagged(
        self,
        itask: 'TaskProxy'
    ) -> None:
        """Hold a newly-spawned task if flagged in the future hold list.

        flow None: a future-held specific task regardless of flow.
        """
        key = (itask.tdef.name, itask.point)
        if key not in self.hold:
            return

        flagged_flow = self.hold[key]
        if flagged_flow is None or flagged_flow in itask.flow_nums:
            LOG.info(f"[{itask}] holding (as requested earlier)")
            itask.state_reset(is_held=True)

    def release_future_task(
        self,
        name: str,
        point: 'PointBase',
        flow_num: Optional[int] = None
    ) -> None:
        """Un-flag point/name if flow matches or flow is None."""
        key = (name, point)
        if key not in self.hold:
            return

        flagged_flow = self.hold[key]
        if (
            flow_num is not None
            and flagged_flow is not None
            and flow_num != flagged_flow
        ):
            # Flow mismatch: the task stays flagged, so the stores must
            # not be told that it was released.
            return

        del self.hold[key]
        self._update_stores((name, point, False))

    def release_active_task(
        self,
        itask: 'TaskProxy',
        queue_func: Callable,
        flow_num: Optional[int] = None,
    ) -> None:
        """Release a held task if flow matches, and queue it if ready.

        Args:
            itask: The task proxy to release.
            queue_func: Called with itask if it is not runahead-limited
                and is ready to run after release.
            flow_num: Only release if this flow is one of the task's
                flows; None releases regardless of flow.
        """
        if (
            flow_num is not None
            and flow_num not in itask.flow_nums
        ):
            return

        if not itask.state_reset(is_held=False):
            # not held
            return

        # pop() rather than del: a held task with no entry in the hold
        # map must still be releasable instead of raising KeyError.
        self.hold.pop((itask.tdef.name, itask.point), None)
        self._update_stores(itask)
        if (
            not itask.state.is_runahead
            and all(itask.is_ready_to_run())
        ):
            queue_func(itask)

    def is_held(
        self,
        name: str,
        point: 'PointBase',
    ) -> bool:
        """Is point/name held, regardless of flow."""
        return (name, point) in self.hold
def set_hold_point(
    self,
    point: 'PointBase',
    flow_num: Optional[int] = None
) -> None:
    """Set the point after which all tasks must be held.

    The hold point may be restricted to one flow (hold flow n after
    point). There is only a single hold point, so no flow-specific
    hold-point release is needed.

    TODO: extend to allow multiple flow-specific hold points.

    """
    self.hold_point, self.hold_point_flow = point, flow_num
    beyond_point = (
        itask for itask in self.get_all_tasks() if itask.point > point
    )
    for itask in beyond_point:
        self.hold_mgr.hold_active_task(itask, flow_num)
    self.workflow_db_mgr.put_workflow_hold_cycle_point(point, flow_num)
def release_hold_point(self) -> None:
    """Release ALL held active tasks and unset the hold-after point.

    Note the CLI does not currently have an option to just release tasks
    after the hold point (there could be held tasks before the hold point).

    NOTE: the hold-point can be flow-specific, but there is currently only
    one hold-point so hold-point release need not be flow-specific yet.

    """
    self.hold_point = None
    # The DB call below clears the stored hold-point flow as well (its
    # flow_num defaults to None); keep the in-memory value in step.
    self.hold_point_flow = None
    for itask in self.get_all_tasks():
        self.hold_mgr.release_active_task(itask, self.queue_task)
    self.workflow_db_mgr.put_workflow_hold_cycle_point(None)
def put_workflow_hold_cycle_point(
    self,
    value: Optional['PointBase'],
    flow_num: Optional[int] = None
) -> None:
    """Store the workflow hold cycle point and its flow number.

    Both are written as workflow parameters: the point as a string (or
    None if unset), and the flow number as given.
    """
    point_str = None if value is None else str(value)
    self.put_workflow_params_1(self.KEY_HOLD_CYCLE_POINT, point_str)
    self.put_workflow_params_1(self.KEY_HOLD_CYCLE_POINT_FLOW, flow_num)
self.db_deletes_map[self.TABLE_TASKS_TO_HOLD] = [{}] self.db_inserts_map[self.TABLE_TASKS_TO_HOLD] = [ - {"name": name, "cycle": str(point)} - for name, point in tasks + {"name": name, "cycle": str(point), "flow": flow_num} + for name, point, flow_num in tasks ] def put_insert_task_events(self, itask, args): diff --git a/tests/flakyfunctional/database/00-simple/schema.out b/tests/flakyfunctional/database/00-simple/schema.out index 8ed66a1f2db..ca8361fe205 100644 --- a/tests/flakyfunctional/database/00-simple/schema.out +++ b/tests/flakyfunctional/database/00-simple/schema.out @@ -12,7 +12,7 @@ CREATE TABLE task_pool(cycle TEXT, name TEXT, flow_nums TEXT, status TEXT, is_he CREATE TABLE task_prerequisites(cycle TEXT, name TEXT, flow_nums TEXT, prereq_name TEXT, prereq_cycle TEXT, prereq_output TEXT, satisfied TEXT, PRIMARY KEY(cycle, name, flow_nums, prereq_name, prereq_cycle, prereq_output)); CREATE TABLE task_states(name TEXT, cycle TEXT, flow_nums TEXT, time_created TEXT, time_updated TEXT, submit_num INTEGER, status TEXT, flow_wait INTEGER, is_manual_submit INTEGER, PRIMARY KEY(name, cycle, flow_nums)); CREATE TABLE task_timeout_timers(cycle TEXT, name TEXT, timeout REAL, PRIMARY KEY(cycle, name)); -CREATE TABLE tasks_to_hold(name TEXT, cycle TEXT); +CREATE TABLE tasks_to_hold(name TEXT, cycle TEXT, flow TEXT); CREATE TABLE workflow_flows(flow_num INTEGER, start_time TEXT, description TEXT, PRIMARY KEY(flow_num)); CREATE TABLE xtriggers(signature TEXT, results TEXT, PRIMARY KEY(signature)); CREATE TABLE absolute_outputs(cycle TEXT, name TEXT, output TEXT); diff --git a/tests/functional/database/08-broadcast-upgrade/db.sqlite3 b/tests/functional/database/08-broadcast-upgrade/db.sqlite3 index 3b54d9ca731..2ac97063ca2 100644 --- a/tests/functional/database/08-broadcast-upgrade/db.sqlite3 +++ b/tests/functional/database/08-broadcast-upgrade/db.sqlite3 @@ -20,7 +20,7 @@ CREATE TABLE task_prerequisites(cycle TEXT, name TEXT, flow_nums TEXT, prereq_na CREATE TABLE 
task_states(name TEXT, cycle TEXT, flow_nums TEXT, time_created TEXT, time_updated TEXT, submit_num INTEGER, status TEXT, flow_wait INTEGER, is_manual_submit INTEGER, PRIMARY KEY(name, cycle, flow_nums)); INSERT INTO task_states VALUES('foo','1','[1]','2022-11-11T11:17:54Z','2022-11-11T11:17:54Z',0,'waiting',0,0); CREATE TABLE task_timeout_timers(cycle TEXT, name TEXT, timeout REAL, PRIMARY KEY(cycle, name)); -CREATE TABLE tasks_to_hold(name TEXT, cycle TEXT); +CREATE TABLE tasks_to_hold(name TEXT, cycle TEXT, flow TEXT); CREATE TABLE workflow_flows(flow_num INTEGER, start_time TEXT, description TEXT, PRIMARY KEY(flow_num)); INSERT INTO workflow_flows VALUES(1,'2022-11-11 11:17:54','original flow from 1'); CREATE TABLE workflow_params(key TEXT, value TEXT, PRIMARY KEY(key)); diff --git a/tests/functional/restart/57-ghost-job/db.sqlite3 b/tests/functional/restart/57-ghost-job/db.sqlite3 index 963130437fb..19661eeca5a 100644 --- a/tests/functional/restart/57-ghost-job/db.sqlite3 +++ b/tests/functional/restart/57-ghost-job/db.sqlite3 @@ -22,7 +22,7 @@ CREATE TABLE task_prerequisites(cycle TEXT, name TEXT, flow_nums TEXT, prereq_na CREATE TABLE task_states(name TEXT, cycle TEXT, flow_nums TEXT, time_created TEXT, time_updated TEXT, submit_num INTEGER, status TEXT, flow_wait INTEGER, is_manual_submit INTEGER, PRIMARY KEY(name, cycle, flow_nums)); INSERT INTO task_states VALUES('foo','1','[1]','2022-07-25T16:18:23+01:00','2022-07-25T16:18:23+01:00',1,'preparing',NULL, '0'); CREATE TABLE task_timeout_timers(cycle TEXT, name TEXT, timeout REAL, PRIMARY KEY(cycle, name)); -CREATE TABLE tasks_to_hold(name TEXT, cycle TEXT); +CREATE TABLE tasks_to_hold(name TEXT, cycle TEXT, flow TEXT); CREATE TABLE workflow_flows(flow_num INTEGER, start_time TEXT, description TEXT, PRIMARY KEY(flow_num)); INSERT INTO workflow_flows VALUES(1,'2022-07-25 16:18:23','original flow from 1'); CREATE TABLE workflow_params(key TEXT, value TEXT, PRIMARY KEY(key)); diff --git 
def get_task_ids(
    name_point_list: Iterable[Tuple[str, Union[PointBase, str, int], int]]
) -> List[str]:
    """Build sorted "point/name" task IDs from (name, point, flow) tuples.

    The flow element of each tuple is ignored.
    """
    task_ids = [
        f'{point}/{name}' for name, point, _flow in name_point_list
    ]
    task_ids.sort()
    return task_ids
['1/*', '2/*', '3/*', '6/*'], 1, + ['1/foo', '1/bar', '2/foo', '2/bar', '2/pub', '3/foo', '3/bar'], + ["No active tasks matching: 6/*"], + id="Flow match" + ), + param( + ['1/*', '2/*', '3/*', '6/*'], 2, [], + ["No active tasks matching: 6/*"], + id="No flow match" + ), + param( + ['1/FAM', '2/FAM', '6/FAM'], None, ['1/bar', '2/bar'], ["No active tasks in the family FAM matching: 6/FAM"], id="Family names hold active tasks only" ), param( - ['1/grogu', 'H/foo', '20/foo', '1/pub'], [], + ['1/grogu', 'H/foo', '20/foo', '1/pub'], None, [], ["No matching tasks found: grogu", "H/foo - invalid cycle point: H", "Invalid cycle point for task: foo, 20", @@ -358,7 +381,7 @@ async def test_match_taskdefs( id="Non-existent task name or invalid cycle point" ), param( - ['1/foo:waiting', '1/foo:failed', '6/bar:waiting'], ['1/foo'], + ['1/foo:waiting', '1/foo:failed', '6/bar:waiting'], None, ['1/foo'], ["No active tasks matching: 1/foo:failed", "No active tasks matching: 6/bar:waiting"], id="Specifying task state works for active tasks, not future tasks" @@ -367,6 +390,7 @@ async def test_match_taskdefs( ) async def test_hold_tasks( items: List[str], + flow_num: Optional[int], expected_tasks_to_hold_ids: List[str], expected_warnings: List[str], example_flow: Scheduler, caplog: pytest.LogCaptureFixture, @@ -380,19 +404,20 @@ async def test_hold_tasks( Params: items: Arg passed to hold_tasks(). expected_tasks_to_hold_ids: Expected IDs of the tasks that get put in - the TaskPool.tasks_to_hold set, of the form "{point}/{name}"/ + the TaskPool.hold_mgr._flatten() set, of the form "{point}/{name}"/ expected_warnings: Expected to be logged. 
""" expected_tasks_to_hold_ids = sorted(expected_tasks_to_hold_ids) caplog.set_level(logging.WARNING, CYLC_LOG) task_pool = example_flow.pool - n_warnings = task_pool.hold_tasks(items) + n_warnings = task_pool.hold_tasks(items, flow_num) for itask in task_pool.get_all_tasks(): hold_expected = itask.identity in expected_tasks_to_hold_ids assert itask.state.is_held is hold_expected - assert get_task_ids(task_pool.tasks_to_hold) == expected_tasks_to_hold_ids + assert get_task_ids( + task_pool.hold_mgr._flatten()) == expected_tasks_to_hold_ids logged_warnings = assert_expected_log(caplog, expected_warnings) assert n_warnings == len(logged_warnings) @@ -420,21 +445,29 @@ async def test_release_held_tasks( for itask in task_pool.get_all_tasks(): hold_expected = itask.identity in expected_tasks_to_hold_ids assert itask.state.is_held is hold_expected - assert get_task_ids(task_pool.tasks_to_hold) == expected_tasks_to_hold_ids + assert get_task_ids(task_pool.hold_mgr._flatten()) == expected_tasks_to_hold_ids db_tasks_to_hold = db_select(example_flow, True, 'tasks_to_hold') assert get_task_ids(db_tasks_to_hold) == expected_tasks_to_hold_ids # Test - task_pool.release_held_tasks(['1/foo', '3/asd']) + task_pool.release_held_tasks(['1/foo', '3/asd'], 1) # right flow, do release + for itask in task_pool.get_all_tasks(): + assert itask.state.is_held is (itask.identity == '1/bar') + + task_pool.release_held_tasks(['1/bar'], 2) # wrong flow, don't release for itask in task_pool.get_all_tasks(): assert itask.state.is_held is (itask.identity == '1/bar') - expected_tasks_to_hold_ids = sorted(['1/bar']) - assert get_task_ids(task_pool.tasks_to_hold) == expected_tasks_to_hold_ids + expected_tasks_to_hold_ids = ['1/bar'] + assert get_task_ids(task_pool.hold_mgr._flatten()) == expected_tasks_to_hold_ids db_tasks_to_hold = db_select(example_flow, True, 'tasks_to_hold') assert get_task_ids(db_tasks_to_hold) == expected_tasks_to_hold_ids + task_pool.release_held_tasks(['1/bar']) # any flow, 
do release + for itask in task_pool.get_all_tasks(): + assert itask.state.is_held is False + @pytest.mark.parametrize( 'hold_after_point, expected_held_task_ids', @@ -464,7 +497,7 @@ async def test_hold_point( hold_expected = itask.identity in expected_held_task_ids assert itask.state.is_held is hold_expected - assert get_task_ids(task_pool.tasks_to_hold) == expected_held_task_ids + assert get_task_ids(task_pool.hold_mgr._flatten()) == expected_held_task_ids db_tasks_to_hold = db_select(example_flow, True, 'tasks_to_hold') assert get_task_ids(db_tasks_to_hold) == expected_held_task_ids @@ -478,7 +511,7 @@ async def test_hold_point( for itask in task_pool.get_all_tasks(): assert itask.state.is_held is False - assert task_pool.tasks_to_hold == set() + assert task_pool.hold_mgr._flatten() == set() assert db_select(example_flow, True, 'tasks_to_hold') == [] diff --git a/tests/unit/scripts/test_hold.py b/tests/unit/scripts/test_hold.py index dab053cc3ea..ad60678a9a5 100644 --- a/tests/unit/scripts/test_hold.py +++ b/tests/unit/scripts/test_hold.py @@ -32,6 +32,13 @@ 'opts, task_globs, expected_err', [ (Opts(), ['*'], None), + (Opts(flow_num=None), ['*'], None), + (Opts(flow_num=2), ['*'], None), + ( + Opts(flow_num='cat'), + ['*'], + (InputError, "--flow=cat: value must be integer") + ), (Opts(hold_point_string='2'), [], None), ( Opts(hold_point_string='2'), diff --git a/tests/unit/scripts/test_release.py b/tests/unit/scripts/test_release.py index c94e0387510..95aba9ab49a 100644 --- a/tests/unit/scripts/test_release.py +++ b/tests/unit/scripts/test_release.py @@ -33,6 +33,13 @@ [ (Opts(), ['*'], None), (Opts(release_all=True), [], None), + (Opts(flow_num=None), ['*'], None), + (Opts(flow_num=2), ['*'], None), + ( + Opts(flow_num='cat'), + ['*'], + (InputError, "--flow=cat: value must be integer") + ), ( Opts(release_all=True), ['*'],