diff --git a/cytoolz/__init__.pxd b/cytoolz/__init__.pxd index 4468950..7e08db8 100644 --- a/cytoolz/__init__.pxd +++ b/cytoolz/__init__.pxd @@ -2,7 +2,7 @@ from cytoolz.itertoolz cimport ( accumulate, c_merge_sorted, cons, count, drop, get, groupby, first, frequencies, interleave, interpose, isdistinct, isiterable, iterate, last, mapcat, nth, partition, partition_all, pluck, reduceby, remove, - rest, second, sliding_window, take, take_nth, unique) + rest, second, sliding_window, take, take_nth, unique, join) from cytoolz.functoolz cimport ( diff --git a/cytoolz/itertoolz.pxd b/cytoolz/itertoolz.pxd index 1cf6a17..87644f0 100644 --- a/cytoolz/itertoolz.pxd +++ b/cytoolz/itertoolz.pxd @@ -148,3 +148,39 @@ cdef class _pluck_list_default: cpdef object pluck(object ind, object seqs, object default=*) + +cpdef object join(object leftkey, object leftseq, + object rightkey, object rightseq, + object left_default=*, + object right_default=*) + +cdef class _join: + cdef Py_ssize_t n + cdef object iterseq + cdef object leftkey + cdef object leftseq + cdef object rightkey + cdef object rightseq + cdef object matches + cdef object right + cdef object key + cdef object d + cdef object d_items + cdef object seen_keys + cdef object is_rightseq_exhausted + cdef object left_default + cdef object right_default + cdef int i + cdef object keys + +cdef class _inner_join(_join): + pass + +cdef class _right_outer_join(_join): + pass + +cdef class _left_outer_join(_join): + pass + +cdef class _outer_join(_join): + pass diff --git a/cytoolz/itertoolz.pyx b/cytoolz/itertoolz.pyx index 0c4c4cf..dab5913 100644 --- a/cytoolz/itertoolz.pyx +++ b/cytoolz/itertoolz.pyx @@ -1062,3 +1062,264 @@ cpdef object pluck(object ind, object seqs, object default=no_default): if default is no_default: return _pluck_index(ind, seqs) return _pluck_index_default(ind, seqs, default) + + +def getter(index): + if isinstance(index, list): + if len(index) == 1: + index = index[0] + return lambda x: (x[index],) + else: + return itemgetter(*index) + else: + return itemgetter(index) + +cpdef object join(object leftkey, object leftseq, + object rightkey, object rightseq, + object left_default=no_default, + object right_default=no_default): + """ Join two sequences on common attributes + + This is a semi-streaming operation. The LEFT sequence is fully evaluated + and placed into memory. The RIGHT sequence is evaluated lazily and so can + be arbitrarily large. + + >>> friends = [('Alice', 'Edith'), + ... ('Alice', 'Zhao'), + ... ('Edith', 'Alice'), + ... ('Zhao', 'Alice'), + ... ('Zhao', 'Edith')] + + >>> cities = [('Alice', 'NYC'), + ... ('Alice', 'Chicago'), + ... ('Dan', 'Syndey'), + ... ('Edith', 'Paris'), + ... ('Edith', 'Berlin'), + ... ('Zhao', 'Shanghai')] + + >>> # Vacation opportunities + >>> # In what cities do people have friends? + >>> result = join(second, friends, + ... first, cities) + >>> for ((a, b), (c, d)) in sorted(unique(result)): + ... print((a, d)) + ('Alice', 'Berlin') + ('Alice', 'Paris') + ('Alice', 'Shanghai') + ('Edith', 'Chicago') + ('Edith', 'NYC') + ('Zhao', 'Chicago') + ('Zhao', 'NYC') + ('Zhao', 'Berlin') + ('Zhao', 'Paris') + + Specify outer joins with keyword arguments ``left_default`` and/or + ``right_default``. Here is a full outer join in which unmatched elements + are paired with None. + + >>> identity = lambda x: x + >>> list(join(identity, [1, 2, 3], + ... identity, [2, 3, 4], + ... left_default=None, right_default=None)) + [(2, 2), (3, 3), (None, 4), (1, None)] + + Usually the key arguments are callables to be applied to the sequences. If + the keys are not obviously callable then it is assumed that indexing was + intended, e.g. the following is a legal change + + >>> # result = join(second, friends, first, cities) + >>> result = join(1, friends, 0, cities) # doctest: +SKIP + """ + return _join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + if left_default == no_default and right_default == no_default: + return _inner_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + elif left_default != no_default and right_default == no_default: + return _right_outer_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + elif left_default == no_default and right_default != no_default: + return _left_outer_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + else: + return _outer_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + +cdef class _join: + def __init__(self, + object leftkey, object leftseq, + object rightkey, object rightseq, + object left_default=no_default, + object right_default=no_default): + if not callable(leftkey): + leftkey = getter(leftkey) + if not callable(rightkey): + rightkey = getter(rightkey) + + self.left_default = left_default + self.right_default = right_default + + self.leftkey = leftkey + self.rightkey = rightkey + self.rightseq = iter(rightseq) + + self.d = groupby(leftkey, leftseq) + self.seen_keys = set() + self.matches = () + self.right = None + + self.is_rightseq_exhausted = False + + + def __iter__(self): + return self + + def __next__(self): + cdef PyObject *obj + if not self.is_rightseq_exhausted: + if self.i == len(self.matches): + try: + self.right = next(self.rightseq) + except StopIteration: + if self.right_default is no_default: + raise + self.is_rightseq_exhausted = True + self.keys = iter(self.d) + return next(self) + key = self.rightkey(self.right) + self.seen_keys.add(key) + obj = PyDict_GetItem(self.d, key) + if obj is NULL: + if self.left_default is not no_default: + return (self.left_default, self.right) + else: + return next(self) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + elif self.right_default is not no_default: + if self.i == len(self.matches): + key = next(self.keys) + while key in self.seen_keys: + key = next(self.keys) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right_default) + + +cdef class _right_outer_join(_join): + def __next__(self): + cdef PyObject *obj + if self.i == len(self.matches): + self.right = next(self.rightseq) + key = self.rightkey(self.right) + obj = PyDict_GetItem(self.d, key) + if obj is NULL: + return (self.left_default, self.right) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + +cdef class _outer_join(_join): + def __next__(self): + cdef PyObject *obj + if not self.is_rightseq_exhausted: + if self.i == len(self.matches): + try: + self.right = next(self.rightseq) + except StopIteration: + self.is_rightseq_exhausted = True + self.keys = iter(self.d) + return next(self) + key = self.rightkey(self.right) + self.seen_keys.add(key) + obj = PyDict_GetItem(self.d, key) + if obj is NULL: + return (self.left_default, self.right) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + else: + if self.i == len(self.matches): + key = next(self.keys) + while key in self.seen_keys: + key = next(self.keys) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right_default) + + + +cdef class _left_outer_join(_join): + def __next__(self): + cdef PyObject *obj + if not self.is_rightseq_exhausted: + if self.i == len(self.matches): + obj = NULL + while obj is NULL: + try: + self.right = next(self.rightseq) + except StopIteration: + self.is_rightseq_exhausted = True + self.keys = iter(self.d) + return next(self) + key = self.rightkey(self.right) + self.seen_keys.add(key) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + else: + if self.i == len(self.matches): + key = next(self.keys) + while key in self.seen_keys: + key = next(self.keys) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right_default) + + +cdef class _inner_join(_join): + def __next__(self): + cdef PyObject *obj = NULL + if self.i == len(self.matches): + while obj is NULL: + self.right = next(self.rightseq) + key = self.rightkey(self.right) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + +# I find `_consume` convenient for benchmarking. Perhaps this belongs +# elsewhere, so it is private (leading underscore) and hidden away for now. + +cpdef object _consume(object seq): + """ + Efficiently consume an iterable """ + for _ in seq: + pass diff --git a/cytoolz/tests/test_itertoolz.py b/cytoolz/tests/test_itertoolz.py index 788c69b..681b4d6 100644 --- a/cytoolz/tests/test_itertoolz.py +++ b/cytoolz/tests/test_itertoolz.py @@ -10,7 +10,7 @@ rest, last, cons, frequencies, reduceby, iterate, accumulate, sliding_window, count, partition, - partition_all, take_nth, pluck) + partition_all, take_nth, pluck, join) from cytoolz.compatibility import range, filter from operator import add, mul @@ -277,3 +277,95 @@ def test_pluck(): assert raises(IndexError, lambda: list(pluck(1, [[0]]))) assert raises(KeyError, lambda: list(pluck('name', [{'id': 1}]))) + + +def test_join(): + names = [(1, 'one'), (2, 'two'), (3, 'three')] + fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)] + + def addpair(pair): + return pair[0] + pair[1] + + result = set(starmap(add, join(first, names, second, fruit))) + + expected = set([((1, 'one', 'apple', 1)), + ((1, 'one', 'orange', 1)), + ((2, 'two', 'banana', 2)), + ((2, 'two', 'coconut', 2))]) + + print(result) + print(expected) + assert result == expected + + +def test_key_as_getter(): + squares = [(i, i**2) for i in range(5)] + pows = [(i, i**2, i**3) for i in range(5)] + + assert set(join(0, squares, 0, pows)) == set(join(lambda x: x[0], squares, + lambda x: x[0], pows)) + + get = lambda x: (x[0], x[1]) + assert set(join([0, 1], squares, [0, 1], pows)) == set(join(get, squares, + get, pows)) + + get = lambda x: (x[0],) + assert set(join([0], squares, [0], pows)) == set(join(get, squares, + get, pows)) + + +def test_join_double_repeats(): + names = [(1, 'one'), (2, 'two'), (3, 'three'), (1, 'uno'), (2, 'dos')] + fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)] + + result = set(starmap(add, join(first, names, second, fruit))) + + expected = set([((1, 'one', 'apple', 1)), + ((1, 'one', 'orange', 1)), + ((2, 'two', 'banana', 2)), + ((2, 'two', 'coconut', 2)), + ((1, 'uno', 'apple', 1)), + ((1, 'uno', 'orange', 1)), + ((2, 'dos', 'banana', 2)), + ((2, 'dos', 'coconut', 2))]) + + print(result) + print(expected) + assert result == expected + + +def test_join_missing_element(): + names = [(1, 'one'), (2, 'two'), (3, 'three')] + fruit = [('apple', 5), ('orange', 1)] + + result = list(join(first, names, second, fruit)) + print(result) + result = set(starmap(add, result)) + + expected = set([((1, 'one', 'orange', 1))]) + + assert result == expected + + +def test_left_outer_join(): + result = set(join(identity, [1, 2], identity, [2, 3], left_default=None)) + expected = set([(2, 2), (None, 3)]) + + print(result) + print(expected) + assert result == expected + + +def test_right_outer_join(): + result = set(join(identity, [1, 2], identity, [2, 3], right_default=None)) + expected = set([(2, 2), (1, None)]) + + assert result == expected + + +def test_outer_join(): + result = set(join(identity, [1, 2], identity, [2, 3], + left_default=None, right_default=None)) + expected = set([(2, 2), (1, None), (None, 3)]) + + assert result == expected