From cb5e705d5495c933468c58361f224328a6a8fd41 Mon Sep 17 00:00:00 2001 From: Golf Player <> Date: Fri, 12 Jun 2020 22:33:29 -0500 Subject: [PATCH] Add basic hooks during execution This will enable tracking of execution process without subclassing the way papermill does. --- nbclient/client.py | 52 ++++++++++++++++++++++++++++++++++++++++++---- nbclient/util.py | 15 ++++++++++++- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/nbclient/client.py b/nbclient/client.py index 8d383ae0..d13411ae 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -28,7 +28,7 @@ CellExecutionComplete, CellExecutionError ) -from .util import run_sync, ensure_async +from .util import run_sync, ensure_async, run_hook from .output_widget import OutputWidget @@ -227,6 +227,45 @@ class NotebookClient(LoggingConfigurable): kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.') + on_execution_start: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + Called after the kernel manager and kernel client are setup, and cells + are about to execute. + Called with kwargs `kernel_id`. + """), + ).tag(config=True) + + on_cell_start: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + A callable which executes before a cell is executed. + Called with kwargs `cell`, and `cell_index`. + """), + ).tag(config=True) + + on_cell_complete: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + A callable which executes after a cell execution is complete. It is + called even when a cell results in a failure. + Called with kwargs `cell`, and `cell_index`. + """), + ).tag(config=True) + + on_cell_error: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + A callable which executes when a cell execution results in an error. + This is executed even if errors are suppressed with `cell_allows_errors`. + Called with kwargs `cell`, and `cell_index`. + """), + ).tag(config=True) + @default('kernel_manager_class') def _kernel_manager_class_default(self) -> KernelManager: """Use a dynamic default to avoid importing jupyter_client at startup""" @@ -412,6 +451,7 @@ async def async_start_new_kernel_client(self, **kwargs) -> t.Tuple[KernelClient, await self._async_cleanup_kernel() raise self.kc.allow_stdin = False + run_hook(self.on_execution_start, kernel_id=kernel_id) return self.kc, kernel_id start_new_kernel_client = run_sync(async_start_new_kernel_client) @@ -702,14 +742,16 @@ def _passed_deadline(self, deadline: int) -> bool: def _check_raise_for_error( self, cell: NotebookNode, + cell_index: int, exec_reply: t.Optional[t.Dict]) -> None: cell_allows_errors = self.allow_errors or "raises-exception" in cell.metadata.get( "tags", [] ) - if self.force_raise_errors or not cell_allows_errors: - if (exec_reply is not None) and exec_reply['content']['status'] == 'error': + if (exec_reply is not None) and exec_reply['content']['status'] == 'error': + run_hook(self.on_cell_error, cell=cell, cell_index=cell_index) + if self.force_raise_errors or not cell_allows_errors: raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) async def async_execute_cell( @@ -760,6 +802,7 @@ async def async_execute_cell( cell['metadata']['execution'] = {} self.log.debug("Executing cell:\n%s", cell.source) + run_hook(self.on_cell_start, cell=cell, cell_index=cell_index) parent_msg_id = await ensure_async( self.kc.execute( cell.source, @@ -767,6 +810,7 @@ async def async_execute_cell( stop_on_error=not self.allow_errors ) ) + run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index) # We launched a code cell to execute self.code_cells_executed += 1 exec_timeout = self._get_timeout(cell) @@ -792,7 +836,7 @@ async def async_execute_cell( if execution_count: cell['execution_count'] = execution_count - self._check_raise_for_error(cell, exec_reply) + self._check_raise_for_error(cell, cell_index, exec_reply) self.nb['cells'][cell_index] = cell return cell diff --git a/nbclient/util.py b/nbclient/util.py index 9ac4e219..77a0ca20 100644 --- a/nbclient/util.py +++ b/nbclient/util.py @@ -6,7 +6,8 @@ import asyncio import sys import inspect -from typing import Callable, Awaitable, Any, Union +from typing import Callable, Awaitable, Any, Union, Optional +from functools import partial def check_ipython() -> None: @@ -91,3 +92,15 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any: return result # obj doesn't need to be awaited return obj + + +def run_hook(hook: Optional[Callable], **kwargs) -> None: + if hook is None: + return + if inspect.iscoroutinefunction(hook): + future = hook(**kwargs) + else: + loop = asyncio.get_event_loop() + hook_with_kwargs = partial(hook, **kwargs) + future = loop.run_in_executor(None, hook_with_kwargs) + asyncio.ensure_future(future)