diff --git a/daisy/__init__.py b/daisy/__init__.py index d650b344..1e2a63f4 100644 --- a/daisy/__init__.py +++ b/daisy/__init__.py @@ -11,6 +11,7 @@ from .roi import Roi # noqa from .scheduler import Scheduler # noqa from .server import Server # noqa +from .serial_server import SerialServer # noqa from .task import Task # noqa from .worker import Worker # noqa from .worker_pool import WorkerPool # noqa diff --git a/daisy/serial_server.py b/daisy/serial_server.py new file mode 100644 index 00000000..98b12eda --- /dev/null +++ b/daisy/serial_server.py @@ -0,0 +1,70 @@ +from .block import BlockStatus +from .scheduler import Scheduler +from .server_observer import ServerObservee +import logging + +logger = logging.getLogger(__name__) + + +class SerialServer(ServerObservee): + def __init__(self): + super().__init__() + + def run_blockwise(self, tasks, scheduler=None): + if scheduler is None: + scheduler = Scheduler(tasks) + else: + scheduler = scheduler + + started_tasks = set() + finished_tasks = set() + all_tasks = set(task.task_id for task in tasks) + process_funcs = {task.task_id: task.process_function for task in tasks} + + while True: + ready_tasks = scheduler.get_ready_tasks() + if finished_tasks == all_tasks: + break + else: + block = None + for ready_task in ready_tasks: + block = scheduler.acquire_block(ready_task.task_id) + if block is not None: + break + if block is None: + break + if block.task_id not in started_tasks: + self.notify_task_start( + block.task_id, scheduler.task_states[block.task_id] + ) + started_tasks.add(block.task_id) + self.notify_acquire_block( + block.task_id, scheduler.task_states[block.task_id] + ) + try: + process_funcs[block.task_id](block) + block.status = BlockStatus.SUCCESS + except Exception as e: + if isinstance(e, KeyboardInterrupt): + raise e + logger.error(f"Error processing block {block.block_id}: {e}") + block.status = BlockStatus.FAILED + self.notify_block_failure(block, e, {"worker_id": "serial"}) + finally: + scheduler.release_block(block) + self.notify_release_block( + block.task_id, scheduler.task_states[block.task_id] + ) + + if scheduler.task_states[block.task_id].is_done(): + self.notify_task_done( + block.task_id, scheduler.task_states[block.task_id] + ) + finished_tasks.add(block.task_id) + started_tasks.remove(block.task_id) + del process_funcs[block.task_id] + + if len(process_funcs) == 0: + return True + + self.notify_server_exit() diff --git a/tests/test_server.py b/tests/test_server.py index 1fbc27b4..184a391a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,5 +1,5 @@ import daisy -import unittest +import pytest import logging logging.basicConfig(level=logging.DEBUG) @@ -9,7 +9,8 @@ def process_block(block): print("Processing block %s" % block) -def test_basic(): +@pytest.mark.parametrize("server", [daisy.Server(), daisy.SerialServer()]) +def test_basic(server): task = daisy.Task( "test_server_task", @@ -25,5 +26,4 @@ def test_basic(): timeout=None, ) - server = daisy.Server() server.run_blockwise([task])