Skip to content

Commit

Permalink
Add a SerialServer
Browse files Browse the repository at this point in the history
Debugging daisy can be a huge pain due to all of the multiprocessing. Using the serial server with a simple process block function makes it easy to step through with the debugger.
  • Loading branch information
pattonw committed Jul 19, 2024
1 parent 74e88ef commit cea0b86
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
1 change: 1 addition & 0 deletions daisy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .roi import Roi # noqa
from .scheduler import Scheduler # noqa
from .server import Server # noqa
from .serial_server import SerialServer # noqa
from .task import Task # noqa
from .worker import Worker # noqa
from .worker_pool import WorkerPool # noqa
70 changes: 70 additions & 0 deletions daisy/serial_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from .block import BlockStatus
from .scheduler import Scheduler
from .server_observer import ServerObservee
import logging

logger = logging.getLogger(__name__)


class SerialServer(ServerObservee):
def __init__(self):
super().__init__()

def run_blockwise(self, tasks, scheduler=None):
if scheduler is None:
scheduler = Scheduler(tasks)
else:
scheduler = scheduler

started_tasks = set()
finished_tasks = set()
all_tasks = set(task.task_id for task in tasks)
process_funcs = {task.task_id: task.process_function for task in tasks}

while True:
ready_tasks = scheduler.get_ready_tasks()
if finished_tasks == all_tasks:
break
else:
block = None
for ready_task in ready_tasks:
block = scheduler.acquire_block(ready_task.task_id)
if block is not None:
break
if block is None:
break
if block.task_id not in started_tasks:
self.notify_task_start(
block.task_id, scheduler.task_states[block.task_id]
)
started_tasks.add(block.task_id)
self.notify_acquire_block(
block.task_id, scheduler.task_states[block.task_id]
)
try:
process_funcs[block.task_id](block)
block.status = BlockStatus.SUCCESS
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise e
logger.error(f"Error processing block {block.block_id}: {e}")
block.status = BlockStatus.FAILED
self.notify_block_failure(block, e, {"worker_id": "serial"})
finally:
scheduler.release_block(block)
self.notify_release_block(
block.task_id, scheduler.task_states[block.task_id]
)

if scheduler.task_states[block.task_id].is_done():
self.notify_task_done(
block.task_id, scheduler.task_states[block.task_id]
)
finished_tasks.add(block.task_id)
started_tasks.remove(block.task_id)
del process_funcs[block.task_id]

if len(process_funcs) == 0:
return True

self.notify_server_exit()
6 changes: 3 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import daisy
import unittest
import pytest
import logging

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -9,7 +9,8 @@ def process_block(block):
print("Processing block %s" % block)


def test_basic():
@pytest.mark.parametrize("server", [daisy.Server(), daisy.SerialServer()])
def test_basic(server):

task = daisy.Task(
"test_server_task",
Expand All @@ -25,5 +26,4 @@ def test_basic():
timeout=None,
)

server = daisy.Server()
server.run_blockwise([task])

0 comments on commit cea0b86

Please sign in to comment.