diff --git a/asyncssh/process.py b/asyncssh/process.py index 811b4f3..d1ce484 100644 --- a/asyncssh/process.py +++ b/asyncssh/process.py @@ -65,6 +65,10 @@ MaybeAwait[None]] +_QUEUE_LOW_WATER = 8 +_QUEUE_HIGH_WATER = 16 + + class _AsyncFileProtocol(Protocol[AnyStr]): """Protocol for an async file""" @@ -304,12 +308,14 @@ class _AsyncFileWriter(_UnicodeWriter[AnyStr]): def __init__(self, process: 'SSHProcess[AnyStr]', file: _AsyncFileProtocol[bytes], needs_close: bool, - encoding: Optional[str], errors: str): + datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._process: 'SSHProcess[AnyStr]' = process self._file = file self._needs_close = needs_close + self._datatype = datatype + self._paused = False self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() self._write_task: Optional[asyncio.Task[None]] = \ process.channel.get_connection().create_task(self._writer()) @@ -327,6 +333,10 @@ async def _writer(self) -> None: await self._file.write(self.encode(data)) self._queue.task_done() + if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: + self._process.resume_feeding(self._datatype) + self._paused = False + if self._needs_close: await self._file.close() @@ -335,6 +345,10 @@ def write(self, data: AnyStr) -> None: self._queue.put_nowait(data) + if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: + self._paused = True + self._process.pause_feeding(self._datatype) + def write_eof(self) -> None: """Close output file when end of file is received""" @@ -573,12 +587,14 @@ class _StreamWriter(_UnicodeWriter[AnyStr]): def __init__(self, process: 'SSHProcess[AnyStr]', writer: asyncio.StreamWriter, recv_eof: bool, - encoding: Optional[str], errors: str): + datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors) self._process: 'SSHProcess[AnyStr]' = process self._writer = writer self._recv_eof = recv_eof + self._datatype = datatype + self._paused = False self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() self._write_task: Optional[asyncio.Task[None]] = \ process.channel.get_connection().create_task(self._feed()) @@ -597,6 +613,10 @@ async def _feed(self) -> None: await self._writer.drain() self._queue.task_done() + if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: + self._process.resume_feeding(self._datatype) + self._paused = False + if self._recv_eof: self._writer.write_eof() @@ -605,6 +625,10 @@ def write(self, data: AnyStr) -> None: self._queue.put_nowait(data) + if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: + self._paused = True + self._process.pause_feeding(self._datatype) + def write_eof(self) -> None: """Write EOF to the stream""" @@ -953,7 +977,7 @@ def pipe_factory() -> _PipeWriter: writer_process.set_reader(reader, send_eof, writer_datatype) writer = _ProcessWriter[AnyStr](writer_process, writer_datatype) elif isinstance(target, asyncio.StreamWriter): - writer = _StreamWriter(self, target, recv_eof, + writer = _StreamWriter(self, target, recv_eof, datatype, self._encoding, self._errors) else: file: _File @@ -978,7 +1002,7 @@ def pipe_factory() -> _PipeWriter: inspect.isgeneratorfunction(file.write)): writer = _AsyncFileWriter( self, cast(_AsyncFileProtocol, file), needs_close, - self._encoding, self._errors) + datatype, self._encoding, self._errors) elif _is_regular_file(cast(IO[bytes], file)): writer = _FileWriter(cast(IO[bytes], file), needs_close, self._encoding, self._errors) diff --git a/tests/test_process.py b/tests/test_process.py index addc8bd..d2c24eb 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1310,6 +1310,20 @@ async def test_pause_async_file_reader(self): self.assertEqual(result.stdout, data) + @asynctest + async def test_pause_async_file_writer(self): + """Test pausing and resuming writing to an aiofile""" + + data = 4*1024*1024*'*' + + async with aiofiles.open('stdout', 'w') as file: + async with self.connect() as conn: + await conn.run('delay', input=data, stdout=file, + stderr=asyncssh.DEVNULL) + + with open('stdout', 'r') as file: + self.assertEqual(file.read(), data) + @unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows') class _TestProcessPipes(_TestProcess): @@ -1538,50 +1552,55 @@ async def test_stdout_socketpair(self): self.assertEqual(result.stderr, data) @asynctest - async def test_pause_socketpair_reader(self): - """Test pausing and resuming reading from a socketpair""" + async def test_pause_socketpair_pipes(self): + """Test pausing and resuming reading from and writing to pipes""" - data = 4*1024*1024*'*' + data = 4*1024*1024*b'*' sock1, sock2 = socket.socketpair() + sock3, sock4 = socket.socketpair() - _, writer = await asyncio.open_unix_connection(sock=sock1) - writer.write(data.encode()) - writer.close() + _, writer1 = await asyncio.open_unix_connection(sock=sock1) + writer1.write(data) + writer1.close() - async with self.connect() as conn: - result = await conn.run('delay', stdin=sock2, - stderr=asyncssh.DEVNULL) + reader2, writer2 = await asyncio.open_unix_connection(sock=sock4) - self.assertEqual(result.stdout, data) - - @asynctest - async def test_pause_socketpair_writer(self): - """Test pausing and resuming writing to a socketpair""" + async with self.connect() as conn: + process = await conn.create_process('delay', encoding=None, + stdin=sock2, stdout=sock3, + stderr=asyncssh.DEVNULL) - data = 4*1024*1024*'*' + self.assertEqual((await reader2.read()), data) + await process.wait() - rsock1, wsock1 = socket.socketpair() - rsock2, wsock2 = socket.socketpair() + writer2.close() - reader1, writer1 = await asyncio.open_unix_connection(sock=rsock1) - reader2, writer2 = await asyncio.open_unix_connection(sock=rsock2) + @asynctest + async def test_pause_socketpair_streams(self): + """Test pausing and resuming reading from and writing to streams""" - async with self.connect() as conn: - process = await conn.create_process(input=data) + data = 4*1024*1024*b'*' - await asyncio.sleep(1) + sock1, sock2 = socket.socketpair() + sock3, sock4 = socket.socketpair() - await process.redirect_stdout(wsock1) - await process.redirect_stderr(wsock2) + _, writer1 = await asyncio.open_unix_connection(sock=sock1) + writer1.write(data) + writer1.close() - stdout_data, stderr_data = \ - await asyncio.gather(reader1.read(), reader2.read()) + reader2, writer2 = await asyncio.open_unix_connection(sock=sock2) + _, writer3 = await asyncio.open_unix_connection(sock=sock3) + reader4, writer4 = await asyncio.open_unix_connection(sock=sock4) - writer1.close() - writer2.close() + async with self.connect() as conn: + process = await conn.create_process('delay', encoding=None, + stdin=reader2, stdout=writer3, + stderr=asyncssh.DEVNULL) + self.assertEqual((await reader4.read()), data) await process.wait() - self.assertEqual(stdout_data.decode(), data) - self.assertEqual(stderr_data.decode(), data) + writer2.close() + writer3.close() + writer4.close()