Skip to content

Commit

Permalink
Refactoring of StdoutProxy to take app_session from the caller
Browse files Browse the repository at this point in the history
This changes makes StdoutProxy to take the exact app_session from
by taking it as optional argument in the constructor instead finding
it by calling get_app_session() globally.

It can avoid confliction of printing output within the patch_stdout
context in the ssh session when the multiple ssh connections are
performing concurrently.

The changed StdoutProxy now passes the correct app instance from the
given app_session in the constructor to the run_in_terminal() in it
instead of calling get_app_or_none() globally that can give wrong
app instance from the prompt session in the last ssh connection.
  • Loading branch information
jooncheol committed Feb 28, 2024
1 parent 465ab02 commit b195ddc
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
24 changes: 24 additions & 0 deletions examples/ssh/asyncssh-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.contrib.ssh import PromptToolkitSSHServer, PromptToolkitSSHSession
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.patch_stdout import StdoutProxy
from prompt_toolkit.shortcuts import ProgressBar, print_formatted_text
from prompt_toolkit.shortcuts.dialogs import input_dialog, yes_no_dialog
from prompt_toolkit.shortcuts.prompt import PromptSession
Expand Down Expand Up @@ -99,6 +100,29 @@ async def interact(ssh_session: PromptToolkitSSHSession) -> None:
await prompt_session.prompt_async("Showing input dialog... [ENTER]")
await input_dialog("Input dialog", "Running over asyncssh").run_async()

async def print_counter(output):
"""
Coroutine that prints counters.
"""
try:
i = 0
while True:
output.write(f"Counter: {i}\n")
i += 1
await asyncio.sleep(3)
except asyncio.CancelledError:
print("Background task cancelled.")

with StdoutProxy(app_session=prompt_session) as output:
background_task = asyncio.create_task(print_counter(output))
try:
text = await prompt_session.prompt_async(
"Type something with background task: "
)
output.write(f"You typed: {text}\n")
finally:
background_task.cancel()


async def main(port=8222):
# Set up logging.
Expand Down
15 changes: 11 additions & 4 deletions src/prompt_toolkit/application/run_in_terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@


def run_in_terminal(
func: Callable[[], _T], render_cli_done: bool = False, in_executor: bool = False
func: Callable[[], _T],
render_cli_done: bool = False,
in_executor: bool = False,
app=None,
) -> Awaitable[_T]:
"""
Run function on the terminal above the current application or prompt.
Expand All @@ -40,12 +43,13 @@ def run_in_terminal(
erase the interface first.
:param in_executor: When True, run in executor. (Use this for long
blocking functions, when you don't want to block the event loop.)
:param app: instance of Application. (default None)
:returns: A `Future`.
"""

async def run() -> _T:
async with in_terminal(render_cli_done=render_cli_done):
async with in_terminal(render_cli_done=render_cli_done, app=app):
if in_executor:
return await run_in_executor_with_context(func)
else:
Expand All @@ -55,7 +59,9 @@ async def run() -> _T:


@asynccontextmanager
async def in_terminal(render_cli_done: bool = False) -> AsyncGenerator[None, None]:
async def in_terminal(
render_cli_done: bool = False, app=None
) -> AsyncGenerator[None, None]:
"""
Asynchronous context manager that suspends the current application and runs
the body in the terminal.
Expand All @@ -67,7 +73,8 @@ async def f():
call_some_function()
await call_some_async_function()
"""
app = get_app_or_none()
if not app:
app = get_app_or_none()
if app is None or not app._is_running:
yield
return
Expand Down
10 changes: 8 additions & 2 deletions src/prompt_toolkit/patch_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from .application import get_app_session, run_in_terminal
from .output import Output
from .shortcuts.prompt import PromptSession

__all__ = [
"patch_stdout",
Expand Down Expand Up @@ -95,6 +96,7 @@ def __init__(
self,
sleep_between_writes: float = 0.2,
raw: bool = False,
app_session: PromptSession = None,
) -> None:
self.sleep_between_writes = sleep_between_writes
self.raw = raw
Expand All @@ -103,7 +105,9 @@ def __init__(
self._buffer: list[str] = []

# Keep track of the curret app session.
self.app_session = get_app_session()
self.app_session = app_session
if not self.app_session:
self.app_session = get_app_session()

# See what output is active *right now*. We should do it at this point,
# before this `StdoutProxy` instance is possibly assigned to `sys.stdout`.
Expand Down Expand Up @@ -220,7 +224,9 @@ def write_and_flush() -> None:
def write_and_flush_in_loop() -> None:
# If an application is running, use `run_in_terminal`, otherwise
# call it directly.
run_in_terminal(write_and_flush, in_executor=False)
run_in_terminal(
write_and_flush, in_executor=False, app=self.app_session.app
)

if loop is None:
# No loop, write immediately.
Expand Down

0 comments on commit b195ddc

Please sign in to comment.