Skip to content

Commit

Permalink
Add basic hooks during execution
Browse files Browse the repository at this point in the history
This will enable tracking of execution process without subclassing the
way papermill does.
  • Loading branch information
Golf Player committed Jun 24, 2020
1 parent 202e046 commit cb5e705
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
52 changes: 48 additions & 4 deletions nbclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
CellExecutionComplete,
CellExecutionError
)
from .util import run_sync, ensure_async
from .util import run_sync, ensure_async, run_hook
from .output_widget import OutputWidget


Expand Down Expand Up @@ -227,6 +227,45 @@ class NotebookClient(LoggingConfigurable):

kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')

on_execution_start: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
Called after the kernel manager and kernel client are setup, and cells
are about to execute.
Called with kwargs `kernel_id`.
"""),
).tag(config=True)

on_cell_start: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
A callable which executes before a cell is executed.
Called with kwargs `cell`, and `cell_index`.
"""),
).tag(config=True)

on_cell_complete: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
A callable which executes after a cell execution is complete. It is
called even when a cell results in a failure.
Called with kwargs `cell`, and `cell_index`.
"""),
).tag(config=True)

on_cell_error: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
A callable which executes when a cell execution results in an error.
This is executed even if errors are suppressed with `cell_allows_errors`.
Called with kwargs `cell`, and `cell_index`.
"""),
).tag(config=True)

@default('kernel_manager_class')
def _kernel_manager_class_default(self) -> KernelManager:
"""Use a dynamic default to avoid importing jupyter_client at startup"""
Expand Down Expand Up @@ -412,6 +451,7 @@ async def async_start_new_kernel_client(self, **kwargs) -> t.Tuple[KernelClient,
await self._async_cleanup_kernel()
raise
self.kc.allow_stdin = False
run_hook(self.on_execution_start, kernel_id=kernel_id)
return self.kc, kernel_id

start_new_kernel_client = run_sync(async_start_new_kernel_client)
Expand Down Expand Up @@ -702,14 +742,16 @@ def _passed_deadline(self, deadline: int) -> bool:
def _check_raise_for_error(
self,
cell: NotebookNode,
cell_index: int,
exec_reply: t.Optional[t.Dict]) -> None:

cell_allows_errors = self.allow_errors or "raises-exception" in cell.metadata.get(
"tags", []
)

if self.force_raise_errors or not cell_allows_errors:
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
if self.force_raise_errors or not cell_allows_errors:
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])

async def async_execute_cell(
Expand Down Expand Up @@ -760,13 +802,15 @@ async def async_execute_cell(
cell['metadata']['execution'] = {}

self.log.debug("Executing cell:\n%s", cell.source)
run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
parent_msg_id = await ensure_async(
self.kc.execute(
cell.source,
store_history=store_history,
stop_on_error=not self.allow_errors
)
)
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
# We launched a code cell to execute
self.code_cells_executed += 1
exec_timeout = self._get_timeout(cell)
Expand All @@ -792,7 +836,7 @@ async def async_execute_cell(

if execution_count:
cell['execution_count'] = execution_count
self._check_raise_for_error(cell, exec_reply)
self._check_raise_for_error(cell, cell_index, exec_reply)
self.nb['cells'][cell_index] = cell
return cell

Expand Down
15 changes: 14 additions & 1 deletion nbclient/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import asyncio
import sys
import inspect
from typing import Callable, Awaitable, Any, Union
from typing import Callable, Awaitable, Any, Union, Optional
from functools import partial


def check_ipython() -> None:
Expand Down Expand Up @@ -91,3 +92,15 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any:
return result
# obj doesn't need to be awaited
return obj


def run_hook(hook: Optional[Callable], **kwargs) -> None:
if hook is None:
return
if inspect.iscoroutinefunction(hook):
future = hook(**kwargs)
else:
loop = asyncio.get_event_loop()
hook_with_kwargs = partial(hook, **kwargs)
future = loop.run_in_executor(None, hook_with_kwargs)
asyncio.ensure_future(future)

0 comments on commit cb5e705

Please sign in to comment.