Skip to content

Commit

Permalink
Fix flow control for AsyncFileWriter and StreamWriter
Browse files Browse the repository at this point in the history
This commit is a follow up to f2020ed, adding proper back pressure when
output is redirected to an AsyncFileWriter or StreamWriter and data is
arriving on the SSH channel faster than these writers can consume it.
Once the queue of outstanding data begins to grow, reading from the
SSH channel will be paused to allow the queue to drain somewhat before
continuing, limiting the amount of memory needed.
  • Loading branch information
ronf committed Jul 2, 2024
1 parent 16fb0ac commit 24b60d1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 34 deletions.
32 changes: 28 additions & 4 deletions asyncssh/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@
MaybeAwait[None]]


_QUEUE_LOW_WATER = 8
_QUEUE_HIGH_WATER = 16


class _AsyncFileProtocol(Protocol[AnyStr]):
"""Protocol for an async file"""

Expand Down Expand Up @@ -304,12 +308,14 @@ class _AsyncFileWriter(_UnicodeWriter[AnyStr]):

def __init__(self, process: 'SSHProcess[AnyStr]',
file: _AsyncFileProtocol[bytes], needs_close: bool,
encoding: Optional[str], errors: str):
datatype: Optional[int], encoding: Optional[str], errors: str):
super().__init__(encoding, errors, hasattr(file, 'encoding'))

self._process: 'SSHProcess[AnyStr]' = process
self._file = file
self._needs_close = needs_close
self._datatype = datatype
self._paused = False
self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue()
self._write_task: Optional[asyncio.Task[None]] = \
process.channel.get_connection().create_task(self._writer())
Expand All @@ -327,6 +333,10 @@ async def _writer(self) -> None:
await self._file.write(self.encode(data))
self._queue.task_done()

if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER:
self._process.resume_feeding(self._datatype)
self._paused = False

if self._needs_close:
await self._file.close()

Expand All @@ -335,6 +345,10 @@ def write(self, data: AnyStr) -> None:

self._queue.put_nowait(data)

if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER:
self._paused = True
self._process.pause_feeding(self._datatype)

def write_eof(self) -> None:
"""Close output file when end of file is received"""

Expand Down Expand Up @@ -573,12 +587,14 @@ class _StreamWriter(_UnicodeWriter[AnyStr]):

def __init__(self, process: 'SSHProcess[AnyStr]',
writer: asyncio.StreamWriter, recv_eof: bool,
encoding: Optional[str], errors: str):
datatype: Optional[int], encoding: Optional[str], errors: str):
super().__init__(encoding, errors)

self._process: 'SSHProcess[AnyStr]' = process
self._writer = writer
self._recv_eof = recv_eof
self._datatype = datatype
self._paused = False
self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue()
self._write_task: Optional[asyncio.Task[None]] = \
process.channel.get_connection().create_task(self._feed())
Expand All @@ -597,6 +613,10 @@ async def _feed(self) -> None:
await self._writer.drain()
self._queue.task_done()

if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER:
self._process.resume_feeding(self._datatype)
self._paused = False

if self._recv_eof:
self._writer.write_eof()

Expand All @@ -605,6 +625,10 @@ def write(self, data: AnyStr) -> None:

self._queue.put_nowait(data)

if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER:
self._paused = True
self._process.pause_feeding(self._datatype)

def write_eof(self) -> None:
"""Write EOF to the stream"""

Expand Down Expand Up @@ -953,7 +977,7 @@ def pipe_factory() -> _PipeWriter:
writer_process.set_reader(reader, send_eof, writer_datatype)
writer = _ProcessWriter[AnyStr](writer_process, writer_datatype)
elif isinstance(target, asyncio.StreamWriter):
writer = _StreamWriter(self, target, recv_eof,
writer = _StreamWriter(self, target, recv_eof, datatype,
self._encoding, self._errors)
else:
file: _File
Expand All @@ -978,7 +1002,7 @@ def pipe_factory() -> _PipeWriter:
inspect.isgeneratorfunction(file.write)):
writer = _AsyncFileWriter(
self, cast(_AsyncFileProtocol, file), needs_close,
self._encoding, self._errors)
datatype, self._encoding, self._errors)
elif _is_regular_file(cast(IO[bytes], file)):
writer = _FileWriter(cast(IO[bytes], file), needs_close,
self._encoding, self._errors)
Expand Down
79 changes: 49 additions & 30 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,20 @@ async def test_pause_async_file_reader(self):

self.assertEqual(result.stdout, data)

@asynctest
async def test_pause_async_file_writer(self):
"""Test pausing and resuming writing to an aiofile"""

data = 4*1024*1024*'*'

async with aiofiles.open('stdout', 'w') as file:
async with self.connect() as conn:
await conn.run('delay', input=data, stdout=file,
stderr=asyncssh.DEVNULL)

with open('stdout', 'r') as file:
self.assertEqual(file.read(), data)


@unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows')
class _TestProcessPipes(_TestProcess):
Expand Down Expand Up @@ -1538,50 +1552,55 @@ async def test_stdout_socketpair(self):
self.assertEqual(result.stderr, data)

@asynctest
async def test_pause_socketpair_reader(self):
"""Test pausing and resuming reading from a socketpair"""
async def test_pause_socketpair_pipes(self):
"""Test pausing and resuming reading from and writing to pipes"""

data = 4*1024*1024*'*'
data = 4*1024*1024*b'*'

sock1, sock2 = socket.socketpair()
sock3, sock4 = socket.socketpair()

_, writer = await asyncio.open_unix_connection(sock=sock1)
writer.write(data.encode())
writer.close()
_, writer1 = await asyncio.open_unix_connection(sock=sock1)
writer1.write(data)
writer1.close()

async with self.connect() as conn:
result = await conn.run('delay', stdin=sock2,
stderr=asyncssh.DEVNULL)
reader2, writer2 = await asyncio.open_unix_connection(sock=sock4)

self.assertEqual(result.stdout, data)

@asynctest
async def test_pause_socketpair_writer(self):
"""Test pausing and resuming writing to a socketpair"""
async with self.connect() as conn:
process = await conn.create_process('delay', encoding=None,
stdin=sock2, stdout=sock3,
stderr=asyncssh.DEVNULL)

data = 4*1024*1024*'*'
self.assertEqual((await reader2.read()), data)
await process.wait()

rsock1, wsock1 = socket.socketpair()
rsock2, wsock2 = socket.socketpair()
writer2.close()

reader1, writer1 = await asyncio.open_unix_connection(sock=rsock1)
reader2, writer2 = await asyncio.open_unix_connection(sock=rsock2)
@asynctest
async def test_pause_socketpair_streams(self):
"""Test pausing and resuming reading from and writing to streams"""

async with self.connect() as conn:
process = await conn.create_process(input=data)
data = 4*1024*1024*b'*'

await asyncio.sleep(1)
sock1, sock2 = socket.socketpair()
sock3, sock4 = socket.socketpair()

await process.redirect_stdout(wsock1)
await process.redirect_stderr(wsock2)
_, writer1 = await asyncio.open_unix_connection(sock=sock1)
writer1.write(data)
writer1.close()

stdout_data, stderr_data = \
await asyncio.gather(reader1.read(), reader2.read())
reader2, writer2 = await asyncio.open_unix_connection(sock=sock2)
_, writer3 = await asyncio.open_unix_connection(sock=sock3)
reader4, writer4 = await asyncio.open_unix_connection(sock=sock4)

writer1.close()
writer2.close()
async with self.connect() as conn:
process = await conn.create_process('delay', encoding=None,
stdin=reader2, stdout=writer3,
stderr=asyncssh.DEVNULL)

self.assertEqual((await reader4.read()), data)
await process.wait()

self.assertEqual(stdout_data.decode(), data)
self.assertEqual(stderr_data.decode(), data)
writer2.close()
writer3.close()
writer4.close()

0 comments on commit 24b60d1

Please sign in to comment.