diff --git a/hdfs/client.py b/hdfs/client.py index a49b513..49bf672 100644 --- a/hdfs/client.py +++ b/hdfs/client.py @@ -3,7 +3,7 @@ """WebHDFS API clients.""" -from .util import AsyncWriter, HdfsError +from .util import AsyncWriter, BoundedAsyncWriter, HdfsError from collections import deque from contextlib import closing, contextmanager from getpass import getuser @@ -477,7 +477,10 @@ def consumer(_data): raise _to_error(res) if data is None: - return AsyncWriter(consumer) + if buffersize is None: + return AsyncWriter(consumer) + else: + return BoundedAsyncWriter(consumer, buffer_size=buffersize) else: consumer(data) diff --git a/hdfs/util.py b/hdfs/util.py index 1b25801..dabc386 100644 --- a/hdfs/util.py +++ b/hdfs/util.py @@ -7,7 +7,7 @@ from shutil import rmtree from six.moves.queue import Queue from tempfile import mkstemp -from threading import Thread +from threading import Thread, Lock import logging as lg import os import os.path as osp @@ -31,6 +31,18 @@ def __init__(self, message, *args, **kwargs): self.exception = kwargs.get("exception") +def wrapped_consumer(asyncWriter, data): + """Wrapped consumer that lets us get a child's exception.""" + try: + _logger.debug('Starting consumer.') + asyncWriter._consumer(data) + except Exception as err: # pylint: disable=broad-except + _logger.exception('Exception in child.') + asyncWriter._err = err + finally: + _logger.debug('Finished consumer.') + + class AsyncWriter(object): """Asynchronous publisher-consumer. 
@@ -69,17 +81,6 @@ def __enter__(self): self._queue = Queue() self._err = None - def consumer(data): - """Wrapped consumer that lets us get a child's exception.""" - try: - _logger.debug('Starting consumer.') - self._consumer(data) - except Exception as err: # pylint: disable=broad-except - _logger.exception('Exception in child.') - self._err = err - finally: - _logger.debug('Finished consumer.') - def reader(queue): """Generator read by the consumer.""" while True: @@ -88,7 +89,7 @@ def reader(queue): break yield chunk - self._reader = Thread(target=consumer, args=(reader(self._queue), )) + self._reader = Thread(target=wrapped_consumer, args=(self, reader(self._queue), )) self._reader.start() _logger.debug('Started child thread.') return self @@ -136,6 +137,105 @@ def write(self, chunk): self._queue.put(chunk) +class BoundedAsyncWriter(AsyncWriter): + + """A bounded asynchronous publisher-consumer. + + :param consumer: Function which takes a single generator as argument. + :param buffer_size: Number of entities that are buffered. When this number is exceeded, + write will block until some of the entities are consumed. + + This class extends AsyncWriter with a fixed buffer size. If the buffer size is exceeded, + writes will be blocked until some of the buffer is consumed. + + """ + + # Expected by pandas to write csv files (https://github.com/mtth/hdfs/pull/130). 
+ __iter__ = None + + def __init__(self, consumer, buffer_size=1024): + super().__init__(consumer) + self._content_length = 0 + self._content_max = buffer_size + self._content_lock = Lock() + + @property + def is_full(self): + return self._content_lock.locked() + + def __enter__(self): + + if self._queue: + raise ValueError('Cannot nest contexts.') + + self._queue = Queue() + self._err = None + + self._content_length = 0 + + def reader(queue): + """Generator read by the consumer.""" + while True: + chunk = queue.get() + if chunk is None: + break + + self._content_length -= len(chunk) + if self._content_lock.locked() and self._content_length < self._content_max: + _logger.debug("releasing lock from reader") + _logger.debug(f"Current buffer size: {self._content_length}") + try: + self._content_lock.release() + except RuntimeError: + pass + + yield chunk + + self._reader = Thread(target=wrapped_consumer, args=(self, reader(self._queue), )) + self._reader.start() + _logger.debug('Started child thread.') + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_value: + _logger.debug('Exception in parent.') + if self._reader and self._reader.is_alive(): + _logger.debug('Signaling child.') + self._queue.put(None) + self._reader.join() + if self._err: + raise self._err # pylint: disable=raising-bad-type + else: + _logger.debug('Child terminated without errors.') + self._queue = None + + def write(self, chunk): + """Stream data to the underlying consumer. + + :param chunk: Bytes to write. These will be buffered in memory until the + consumer reads them. + + """ + self._content_lock.acquire() + + _logger.debug(f"produce called with {chunk}") + + if chunk: + # We skip empty chunks, otherwise they cause request to terminate the + # response stream. Note that these chunks can be produced by valid + # upstream encoders (e.g. bzip2). 
+ self._content_length += len(chunk) + _logger.debug(f"Current buffer size: {self._content_length}") + + self._queue.put(chunk) + + if self._content_length < self._content_max and self._content_lock.locked(): + _logger.debug("releasing lock from write") + try: + self._content_lock.release() + except RuntimeError: + pass + @contextmanager def temppath(dpath=None): """Create a temporary path. diff --git a/test/test_util.py b/test/test_util.py index 6ee6b60..cd72e0c 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -5,15 +5,18 @@ from hdfs.util import * from nose.tools import eq_, ok_, raises - +import time +import threading class TestAsyncWriter(object): + AsyncWriterFactory = AsyncWriter + def test_basic(self): result = [] def consumer(gen): result.append(list(gen)) - with AsyncWriter(consumer) as writer: + with self.AsyncWriterFactory(consumer) as writer: writer.write('one') writer.write('two') eq_(result, [['one','two']]) @@ -22,7 +25,7 @@ def test_multiple_writer_uses(self): result = [] def consumer(gen): result.append(list(gen)) - writer = AsyncWriter(consumer) + writer = self.AsyncWriterFactory(consumer) with writer: writer.write('one') writer.write('two') @@ -35,10 +38,10 @@ def test_multiple_consumer_uses(self): result = [] def consumer(gen): result.append(list(gen)) - with AsyncWriter(consumer) as writer: + with self.AsyncWriterFactory(consumer) as writer: writer.write('one') writer.write('two') - with AsyncWriter(consumer) as writer: + with self.AsyncWriterFactory(consumer) as writer: writer.write('three') writer.write('four') eq_(result, [['one','two'],['three','four']]) @@ -48,7 +51,7 @@ def test_nested(self): result = [] def consumer(gen): result.append(list(gen)) - with AsyncWriter(consumer) as _writer: + with self.AsyncWriterFactory(consumer) as _writer: _writer.write('one') with _writer as writer: writer.write('two') @@ -59,7 +62,7 @@ def consumer(gen): for value in gen: if value == 'two': raise HdfsError('Yo') - with AsyncWriter(consumer) as 
writer: + with self.AsyncWriterFactory(consumer) as writer: writer.write('one') writer.write('two') @@ -71,9 +74,65 @@ def consumer(gen): def invalid(w): w.write('one') raise HdfsError('Ya') - with AsyncWriter(consumer) as writer: + with self.AsyncWriterFactory(consumer) as writer: invalid(writer) +import logging + +class WaitingConsumer(): + + do_consume = False + _logger = logging.getLogger(__name__) + def __init__(self): + self.values = [] + + def consume(self, gen): + for value in gen: + while not self.do_consume: + time.sleep(0.15) + WaitingConsumer._logger.debug(f"adding {value}") + self.values.append(value) + + self.do_consume = False + +class TestBoundedAsyncWriter(TestAsyncWriter): + + AsyncWriterFactory = BoundedAsyncWriter + + def test_boundness(self): + + def do_some_writing(writer): + with writer as active_writer: + active_writer.write('one') + active_writer.write('two') + active_writer.write('three') + active_writer.write('four') + + consumer = WaitingConsumer() + writer = self.AsyncWriterFactory(consumer.consume,4) + + a_thread = threading.Thread(target=do_some_writing, daemon=True, args=[writer]) + a_thread.start() # current buffer contains one and two + + consumer.do_consume=True # after consuming buffer contains two and three + time.sleep(0.4) + assert writer.is_full #== True + + consumer.do_consume=True # after consuming buffer contains only three + time.sleep(0.4) + assert writer.is_full #== True + + consumer.do_consume=True # after consuming buffer is empty but four will be inserted very quickly + time.sleep(0.4) + assert not writer.is_full #== False + + consumer.do_consume=True # after consuming buffer is empty + + time.sleep(0.4) + assert not writer.is_full #== False + + assert consumer.values == ['one', 'two', 'three', 'four'] + class TestTemppath(object):