diff --git a/src/pylibsshext/includes/sftp.pxd b/src/pylibsshext/includes/sftp.pxd
index 9d3db9310..4b7598d58 100644
--- a/src/pylibsshext/includes/sftp.pxd
+++ b/src/pylibsshext/includes/sftp.pxd
@@ -55,5 +55,36 @@ cdef extern from "libssh/sftp.h" nogil:
     ssize_t sftp_read(sftp_file file, const void *buf, size_t count)
     int sftp_get_error(sftp_session sftp)
 
+    struct sftp_attributes_struct:
+        char *name
+        char *longname
+        unsigned int flags
+        unsigned int type
+        # uint64_t in libssh; must be 64-bit so sizes >= 4 GiB are not truncated
+        unsigned long long size
+        # ...
+    ctypedef sftp_attributes_struct * sftp_attributes
+    sftp_attributes sftp_stat(sftp_session session, const char *path)
+    void sftp_attributes_free(sftp_attributes file)
+
+    struct sftp_aio_struct:
+        pass
+    ctypedef sftp_aio_struct * sftp_aio
+    ssize_t sftp_aio_begin_read(sftp_file file, size_t len, sftp_aio *aio)
+    ssize_t sftp_aio_wait_read(sftp_aio *aio, void *buf, size_t buf_size)
+    ssize_t sftp_aio_begin_write(sftp_file file, const void *buf, size_t len, sftp_aio *aio)
+    ssize_t sftp_aio_wait_write(sftp_aio *aio)
+    void sftp_aio_free(sftp_aio aio)
+
+    struct sftp_limits_struct:
+        # all four limits are uint64_t in libssh
+        unsigned long long max_packet_length
+        unsigned long long max_read_length
+        unsigned long long max_write_length
+        unsigned long long max_open_handles
+    ctypedef sftp_limits_struct * sftp_limits_t
+    sftp_limits_t sftp_limits(sftp_session sftp)
+    void sftp_limits_free(sftp_limits_t limits)
+
 cdef extern from "sys/stat.h" nogil:
     cdef int S_IRWXU
diff --git a/src/pylibsshext/sftp.pxd b/src/pylibsshext/sftp.pxd
index 07f1afc6b..af355ae23 100644
--- a/src/pylibsshext/sftp.pxd
+++ b/src/pylibsshext/sftp.pxd
@@ -23,3 +23,17 @@ from pylibsshext.session cimport Session
 cdef class SFTP:
     cdef Session session
     cdef sftp.sftp_session _libssh_sftp_session
+
+cdef class SFTP_AIO:
+    cdef _aio_queue
+    cdef _remote_file
+    cdef _file_size
+    cdef _total_bytes_requested
+    # owning SFTP object: kept alive by this reference and used for error strings
+    cdef SFTP _sftp_obj
+    cdef sftp.sftp_session _sftp
+    cdef sftp.sftp_limits_t _limits
+    cdef sftp.sftp_file _rf
+
+cdef class C_AIO:
+    cdef sftp.sftp_aio aio
diff --git a/src/pylibsshext/sftp.pyx b/src/pylibsshext/sftp.pyx
index 6331c529d..e452b12c5 100644
--- a/src/pylibsshext/sftp.pyx
+++ b/src/pylibsshext/sftp.pyx
@@ -15,9 +15,12 @@
 # License along with this library; if not, see file LICENSE.rst in this
 # repository.
 
+import os
+
 from posix.fcntl cimport O_CREAT, O_RDONLY, O_TRUNC, O_WRONLY
 
 from cpython.bytes cimport PyBytes_AS_STRING
+from cpython.mem cimport PyMem_Free, PyMem_Malloc
 
 from pylibsshext.errors cimport LibsshSFTPException
 from pylibsshext.session cimport get_libssh_session
@@ -54,69 +57,214 @@ cdef class SFTP:
         self._libssh_sftp_session = NULL
 
     def put(self, local_file, remote_file):
+        SFTP_AIO(self).put(local_file, remote_file)
+
+    def get(self, remote_file, local_file):
+        SFTP_AIO(self).get(remote_file, local_file)
+
+    def close(self):
+        if self._libssh_sftp_session is not NULL:
+            sftp.sftp_free(self._libssh_sftp_session)
+            self._libssh_sftp_session = NULL
+
+    def _get_sftp_error_str(self):
+        error = sftp.sftp_get_error(self._libssh_sftp_session)
+        if error in MSG_MAP and error != sftp.SSH_FX_FAILURE:
+            return MSG_MAP[error]
+        return "Generic failure: %s" % self.session._get_session_error_str()
+
+cdef sftp.sftp_session get_sftp_session(SFTP sftp_obj):
+    return sftp_obj._libssh_sftp_session
+
+cdef class SFTP_AIO:
+    """Transfer one file at a time over SFTP using libssh's AIO API.
+
+    Up to 10 requests are kept in flight; replies are consumed in the
+    order the requests were issued.
+    """
+
+    def __cinit__(self, SFTP sftp_obj):
+        # keep the owner around: every error path needs its error string
+        self._sftp_obj = sftp_obj
+        self._sftp = get_sftp_session(sftp_obj)
+
+        self._limits = sftp.sftp_limits(self._sftp)
+        if self._limits is NULL:
+            raise LibsshSFTPException("Failed to get remote SFTP limits [%s]" % (self._get_sftp_error_str()))
+
+    def __init__(self, SFTP sftp_obj):
+        self._aio_queue = []
+
+    def __dealloc__(self):
+        if self._rf is not NULL:
+            sftp.sftp_close(self._rf)
+            self._rf = NULL
+        # the limits structure returned by sftp_limits() must be released by us
+        if self._limits is not NULL:
+            sftp.sftp_limits_free(self._limits)
+            self._limits = NULL
+
+    def _get_sftp_error_str(self):
+        """Return the error string of the owning SFTP object."""
+        return self._sftp_obj._get_sftp_error_str()
+
+    def put(self, local_file, remote_file):
+        """Upload ``local_file`` to ``remote_file`` on the server."""
+        # reset
+        self._aio_queue = []
+        self._total_bytes_requested = 0
+
+        cdef C_AIO aio
         cdef sftp.sftp_file rf
-        with open(local_file, "rb") as f:
-            remote_file_b = remote_file
-            if isinstance(remote_file_b, unicode):
-                remote_file_b = remote_file.encode("utf-8")
+        self._remote_file = remote_file
 
-            rf = sftp.sftp_open(self._libssh_sftp_session, remote_file_b, O_WRONLY | O_CREAT | O_TRUNC, sftp.S_IRWXU)
-            if rf is NULL:
-                raise LibsshSFTPException("Opening remote file [%s] for write failed with error [%s]" % (remote_file, self._get_sftp_error_str()))
-            buffer = f.read(1024)
-
-            while buffer != b"":
-                length = len(buffer)
-                written = sftp.sftp_write(rf, PyBytes_AS_STRING(buffer), length)
-                if written != length:
-                    sftp.sftp_close(rf)
+        remote_file_b = remote_file
+        if isinstance(remote_file_b, unicode):
+            remote_file_b = remote_file.encode("utf-8")
+
+        rf = sftp.sftp_open(self._sftp, remote_file_b, O_WRONLY | O_CREAT | O_TRUNC, sftp.S_IRWXU)
+        if rf is NULL:
+            raise LibsshSFTPException("Opening remote file [%s] for write failed with error [%s]" % (remote_file, self._get_sftp_error_str()))
+        self._rf = rf
+
+        with open(local_file, "rb") as f:
+            f.seek(0, os.SEEK_END)
+            self._file_size = f.tell()
+            f.seek(0, os.SEEK_SET)
+
+            # open up to 10 parallel transfers
+            i = 0
+            while i < 10 and self._total_bytes_requested < self._file_size:
+                self._put_chunk(f)
+                i += 1
+
+            while len(self._aio_queue):
+                # FIFO: wait for replies in the order the requests were sent
+                aio = self._aio_queue.pop(0)
+                bytes_written = sftp.sftp_aio_wait_write(&aio.aio)
+                if bytes_written == libssh.SSH_ERROR:
                     raise LibsshSFTPException(
-                        "Writing to remote file [%s] failed with error [%s]" % (
-                            remote_file,
-                            self._get_sftp_error_str(),
-                        )
+                        "Failed to write to remote file [%s]: error [%s]" % (self._remote_file, self._get_sftp_error_str())
                     )
-                buffer = f.read(1024)
+                # was freed in the wait if it did not fail
+                aio.aio = NULL
+
+                # whole file written
+                if self._total_bytes_requested == self._file_size:
+                    continue
+
+                # else issue more write requests
+                self._put_chunk(f)
+
         sftp.sftp_close(rf)
+        self._rf = NULL
+
+    def _put_chunk(self, f):
+        """Read the next chunk of ``f`` and queue an async write request for it."""
+        to_write = min(self._file_size - self._total_bytes_requested, self._limits.max_write_length)
+        buffer = f.read(to_write)
+        if len(buffer) != to_write:
+            raise LibsshSFTPException("Read only [%d] but requested [%d] when reading from local file [%s] " % (len(buffer), to_write, self._remote_file))
+
+        cdef sftp.sftp_aio aio = NULL
+        bytes_requested = sftp.sftp_aio_begin_write(self._rf, PyBytes_AS_STRING(buffer), to_write, &aio)
+        if bytes_requested != to_write:
+            raise LibsshSFTPException("Failed to write chunk of size [%d] of file [%s] with error [%s]"
+                                      % (to_write, self._remote_file, self._get_sftp_error_str()))
+        self._total_bytes_requested += bytes_requested
+        c_aio = C_AIO()
+        c_aio.aio = aio
+        self._aio_queue.append(c_aio)
 
     def get(self, remote_file, local_file):
-        cdef sftp.sftp_file rf
-        cdef char read_buffer[1024]
+        """Download ``remote_file`` from the server into ``local_file``."""
+        # reset
+        self._aio_queue = []
+        self._total_bytes_requested = 0
+
+        cdef C_AIO aio
+        cdef sftp.sftp_file rf = NULL
+        cdef sftp.sftp_attributes attrs
+        cdef char *buffer = NULL
+        self._remote_file = remote_file
 
         remote_file_b = remote_file
         if isinstance(remote_file_b, unicode):
             remote_file_b = remote_file.encode("utf-8")
 
-        rf = sftp.sftp_open(self._libssh_sftp_session, remote_file_b, O_RDONLY, sftp.S_IRWXU)
-        if rf is NULL:
-            raise LibsshSFTPException("Opening remote file [%s] for read failed with error [%s]" % (remote_file, self._get_sftp_error_str()))
-
-        while True:
-            file_data = sftp.sftp_read(rf, read_buffer, sizeof(char) * 1024)
-            if file_data == 0:
-                break
-            elif file_data < 0:
-                sftp.sftp_close(rf)
-                raise LibsshSFTPException("Reading data from remote file [%s] failed with error [%s]"
-                                          % (remote_file, self._get_sftp_error_str()))
-
-            with open(local_file, 'wb+') as f:
-                bytes_written = f.write(read_buffer[:file_data])
-                if bytes_written and file_data != bytes_written:
-                    sftp.sftp_close(rf)
-                    raise LibsshSFTPException("Number of bytes [%s] read from remote file [%s]"
-                                              " does not match number of bytes [%s] written to local file [%s]"
-                                              " due to error [%s]"
-                                              % (file_data, remote_file, bytes_written, local_file, self._get_sftp_error_str()))
-        sftp.sftp_close(rf)
+        attrs = sftp.sftp_stat(self._sftp, remote_file_b)
+        if attrs is NULL:
+            raise LibsshSFTPException("Failed to stat the remote file [%s] with error [%s]"
+                                      % (remote_file, self._get_sftp_error_str()))
+        self._file_size = attrs.size
+        # the attributes returned by sftp_stat() must be released by the caller
+        sftp.sftp_attributes_free(attrs)
 
-    def close(self):
-        if self._libssh_sftp_session is not NULL:
-            sftp.sftp_free(self._libssh_sftp_session)
-            self._libssh_sftp_session = NULL
+        buffer_size = min(self._limits.max_read_length, self._file_size)
+        try:
+            buffer = <char *>PyMem_Malloc(buffer_size)
+            if buffer is NULL:
+                raise MemoryError("Failed to allocate [%d] bytes for the read buffer" % buffer_size)
 
-    def _get_sftp_error_str(self):
-        error = sftp.sftp_get_error(self._libssh_sftp_session)
-        if error in MSG_MAP and error != sftp.SSH_FX_FAILURE:
-            return MSG_MAP[error]
-        return "Generic failure: %s" % self.session._get_session_error_str()
+            rf = sftp.sftp_open(self._sftp, remote_file_b, O_RDONLY, sftp.S_IRWXU)
+            if rf is NULL:
+                raise LibsshSFTPException("Opening remote file [%s] for reading failed with error [%s]" % (remote_file, self._get_sftp_error_str()))
+            self._rf = rf
+
+            with open(local_file, 'wb') as f:
+                # open up to 10 parallel transfers
+                i = 0
+                while i < 10 and self._total_bytes_requested < self._file_size:
+                    self._get_chunk()
+                    i += 1
+
+                while len(self._aio_queue):
+                    # FIFO: chunks are appended to the file, so keep request order
+                    aio = self._aio_queue.pop(0)
+                    bytes_read = sftp.sftp_aio_wait_read(&aio.aio, buffer, buffer_size)
+                    if bytes_read == libssh.SSH_ERROR:
+                        raise LibsshSFTPException(
+                            "Failed to read from remote file [%s]: error [%s]" % (self._remote_file, self._get_sftp_error_str())
+                        )
+                    # was freed in the wait if it did not fail -- otherwise the __dealloc__ will free it
+                    aio.aio = NULL
+
+                    # write only the bytes received; a bare char * would be cut at the first NUL
+                    f.write(buffer[:bytes_read])
+
+                    # whole file read
+                    if self._total_bytes_requested == self._file_size:
+                        continue
+
+                    # else issue more read requests
+                    self._get_chunk()
+
+        finally:
+            if buffer is not NULL:
+                PyMem_Free(buffer)
+            if rf is not NULL:
+                sftp.sftp_close(rf)
+            self._rf = NULL
+
+    def _get_chunk(self):
+        """Queue an async read request for the next chunk of the remote file."""
+        to_read = min(self._file_size - self._total_bytes_requested, self._limits.max_read_length)
+        cdef sftp.sftp_aio aio = NULL
+        bytes_requested = sftp.sftp_aio_begin_read(self._rf, to_read, &aio)
+        if bytes_requested != to_read:
+            raise LibsshSFTPException("Failed to request to read chunk of size [%d] of file [%s] with error [%s]"
+                                      % (to_read, self._remote_file, self._get_sftp_error_str()))
+        self._total_bytes_requested += bytes_requested
+        c_aio = C_AIO()
+        c_aio.aio = aio
+        self._aio_queue.append(c_aio)
+
+
+cdef class C_AIO:
+    """Owns one libssh ``sftp_aio`` handle and frees it on deallocation."""
+    def __cinit__(self):
+        self.aio = NULL
+
+    def __dealloc__(self):
+        sftp.sftp_aio_free(self.aio)
+        self.aio = NULL