From dd2d95f38d794b371f0e2be11a4b38148ce9ba28 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 5 Jun 2014 16:40:27 -0700 Subject: [PATCH 1/5] add join This is ugly and likely inefficient. Help. --- cytoolz/itertoolz.pxd | 17 ++++ cytoolz/itertoolz.pyx | 135 ++++++++++++++++++++++++++++++++ cytoolz/tests/test_itertoolz.py | 95 +++++++++++++++++++++- 3 files changed, 246 insertions(+), 1 deletion(-) diff --git a/cytoolz/itertoolz.pxd b/cytoolz/itertoolz.pxd index 910b7ea..0bd44e5 100644 --- a/cytoolz/itertoolz.pxd +++ b/cytoolz/itertoolz.pxd @@ -148,3 +148,20 @@ cdef class _pluck_list_default: cpdef object pluck(object ind, object seqs, object default=*) + +cdef class join: + cdef Py_ssize_t n + cdef object iterseq + cdef object leftkey + cdef object leftseq + cdef object rightkey + cdef object rightseq + cdef object matches + cdef object right + cdef object key + cdef object d + cdef object d_items + cdef object seen_keys + cdef object is_rightseq_exhausted + cdef object left_default + cdef object right_default diff --git a/cytoolz/itertoolz.pyx b/cytoolz/itertoolz.pyx index de1b598..b90303d 100644 --- a/cytoolz/itertoolz.pyx +++ b/cytoolz/itertoolz.pyx @@ -1061,6 +1061,141 @@ cpdef object pluck(object ind, object seqs, object default=no_default): return _pluck_index_default(ind, seqs, default) +def getter(index): + if isinstance(index, list): + if len(index) == 1: + index = index[0] + return lambda x: (x[index],) + else: + return itemgetter(*index) + else: + return itemgetter(index) + + +cdef class join: + """ Join two sequences on common attributes + + This is a semi-streaming operation. The LEFT sequence is fully evaluated + and placed into memory. The RIGHT sequence is evaluated lazily and so can + be arbitrarily large. + + >>> friends = [('Alice', 'Edith'), + ... ('Alice', 'Zhao'), + ... ('Edith', 'Alice'), + ... ('Zhao', 'Alice'), + ... ('Zhao', 'Edith')] + + >>> cities = [('Alice', 'NYC'), + ... ('Alice', 'Chicago'), + ... ('Dan', 'Syndey'), + ... ('Edith', 'Paris'), + ... ('Edith', 'Berlin'), + ... ('Zhao', 'Shanghai')] + + >>> # Vacation opportunities + >>> # In what cities do people have friends? + >>> result = join(second, friends, + ... first, cities) + >>> for ((a, b), (c, d)) in sorted(unique(result)): + ... print((a, d)) + ('Alice', 'Berlin') + ('Alice', 'Paris') + ('Alice', 'Shanghai') + ('Edith', 'Chicago') + ('Edith', 'NYC') + ('Zhao', 'Chicago') + ('Zhao', 'NYC') + ('Zhao', 'Berlin') + ('Zhao', 'Paris') + + Specify outer joins with keyword arguments ``left_default`` and/or + ``right_default``. Here is a full outer join in which unmatched elements + are paired with None. + + >>> identity = lambda x: x + >>> list(join(identity, [1, 2, 3], + ... identity, [2, 3, 4], + ... left_default=None, right_default=None)) + [(2, 2), (3, 3), (None, 4), (1, None)] + + Usually the key arguments are callables to be applied to the sequences. If + the keys are not obviously callable then it is assumed that indexing was + intended, e.g. the following is a legal change + + >>> # result = join(second, friends, first, cities) + >>> result = join(1, friends, 0, cities) # doctest: +SKIP + """ + def __init__(self, + object leftkey, object leftseq, + object rightkey, object rightseq, + object left_default=no_default, + object right_default=no_default): + if not callable(leftkey): + leftkey = getter(leftkey) + if not callable(rightkey): + rightkey = getter(rightkey) + + self.left_default = left_default + self.right_default = right_default + + self.leftkey = leftkey + self.rightkey = rightkey + self.rightseq = iter(rightseq) + + self.d = groupby(leftkey, leftseq) + self.seen_keys = set() + self.matches = iter(()) + self.right = None + + self.is_rightseq_exhausted = False + + + def __iter__(self): + return self + + def __next__(self): + if not self.is_rightseq_exhausted: + try: + match = next(self.matches) + return (match, self.right) + except StopIteration: # iterator of matches exhausted + try: + item = next(self.rightseq) # get a new item + except StopIteration: # no items, switch to outer join + self.is_rightseq_exhausted = True + if self.right_default is not no_default: + self.d_items = iter(self.d.items()) + self.matches = iter(()) + return next(self) + else: + raise + + key = self.rightkey(item) + self.seen_keys.add(key) + + try: + self.matches = iter(self.d[key]) # get left matches + except KeyError: + if self.left_default is not no_default: + return (self.left_default, item) # outer join + + self.right = item + return next(self) + + else: # we've exhausted the right sequence, lets iterate over unseen + # items on the left + try: + match = next(self.matches) + return (match, self.right_default) + except StopIteration: + key, matches = next(self.d_items) + while(key in self.seen_keys and matches): + key, matches = next(self.d_items) + self.key = key + self.matches = iter(matches) + return next(self) + + # I find `_consume` convenient for benchmarking. Perhaps this belongs # elsewhere, so it is private (leading underscore) and hidden away for now. diff --git a/cytoolz/tests/test_itertoolz.py b/cytoolz/tests/test_itertoolz.py index a3404ff..7a16123 100644 --- a/cytoolz/tests/test_itertoolz.py +++ b/cytoolz/tests/test_itertoolz.py @@ -1,4 +1,5 @@ import itertools +from itertools import starmap from cytoolz.utils import raises from functools import partial from cytoolz.itertoolz import (remove, groupby, merge_sorted, @@ -9,7 +10,7 @@ rest, last, cons, frequencies, reduceby, iterate, accumulate, sliding_window, count, partition, - partition_all, take_nth, pluck) + partition_all, take_nth, pluck, join) from cytoolz.compatibility import range, filter from operator import add, mul @@ -264,3 +265,95 @@ def test_pluck(): assert raises(IndexError, lambda: list(pluck(1, [[0]]))) assert raises(KeyError, lambda: list(pluck('name', [{'id': 1}]))) + + +def test_join(): + names = [(1, 'one'), (2, 'two'), (3, 'three')] + fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)] + + def addpair(pair): + return pair[0] + pair[1] + + result = set(starmap(add, join(first, names, second, fruit))) + + expected = set([((1, 'one', 'apple', 1)), + ((1, 'one', 'orange', 1)), + ((2, 'two', 'banana', 2)), + ((2, 'two', 'coconut', 2))]) + + print(result) + print(expected) + assert result == expected + + +def test_key_as_getter(): + squares = [(i, i**2) for i in range(5)] + pows = [(i, i**2, i**3) for i in range(5)] + + assert set(join(0, squares, 0, pows)) == set(join(lambda x: x[0], squares, + lambda x: x[0], pows)) + + get = lambda x: (x[0], x[1]) + assert set(join([0, 1], squares, [0, 1], pows)) == set(join(get, squares, + get, pows)) + + get = lambda x: (x[0],) + assert set(join([0], squares, [0], pows)) == set(join(get, squares, + get, pows)) + + +def test_join_double_repeats(): + names = [(1, 'one'), (2, 'two'), (3, 'three'), (1, 'uno'), (2, 'dos')] + fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)] + + result = set(starmap(add, join(first, names, second, fruit))) + + expected = set([((1, 'one', 'apple', 1)), + ((1, 'one', 'orange', 1)), + ((2, 'two', 'banana', 2)), + ((2, 'two', 'coconut', 2)), + ((1, 'uno', 'apple', 1)), + ((1, 'uno', 'orange', 1)), + ((2, 'dos', 'banana', 2)), + ((2, 'dos', 'coconut', 2))]) + + print(result) + print(expected) + assert result == expected + + +def test_join_missing_element(): + names = [(1, 'one'), (2, 'two'), (3, 'three')] + fruit = [('apple', 5), ('orange', 1)] + + result = list(join(first, names, second, fruit)) + print(result) + result = set(starmap(add, result)) + + expected = set([((1, 'one', 'orange', 1))]) + + assert result == expected + + +def test_left_outer_join(): + result = set(join(identity, [1, 2], identity, [2, 3], left_default=None)) + expected = set([(2, 2), (None, 3)]) + + print(result) + print(expected) + assert result == expected + + +def test_right_outer_join(): + result = set(join(identity, [1, 2], identity, [2, 3], right_default=None)) + expected = set([(2, 2), (1, None)]) + + assert result == expected + + +def test_outer_join(): + result = set(join(identity, [1, 2], identity, [2, 3], + left_default=None, right_default=None)) + expected = set([(2, 2), (1, None), (None, 3)]) + + assert result == expected From 919a97ca92d10869850df5939eff4d2f44076e97 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Jun 2014 07:37:50 -0700 Subject: [PATCH 2/5] split out different joins into different classes I'm not sure that this is worth it --- cytoolz/itertoolz.pxd | 21 ++++++- cytoolz/itertoolz.pyx | 139 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 4 deletions(-) diff --git a/cytoolz/itertoolz.pxd b/cytoolz/itertoolz.pxd index 0bd44e5..99ec432 100644 --- a/cytoolz/itertoolz.pxd +++ b/cytoolz/itertoolz.pxd @@ -149,7 +149,12 @@ cdef class _pluck_list_default: cpdef object pluck(object ind, object seqs, object default=*) -cdef class join: +cpdef object join(object leftkey, object leftseq, + object rightkey, object rightseq, + object left_default=*, + object right_default=*) + +cdef class _join: cdef Py_ssize_t n cdef object iterseq cdef object leftkey @@ -165,3 +170,17 @@ cdef class join: cdef object is_rightseq_exhausted cdef object left_default cdef object right_default + +cdef class _inner_join(_join): + cdef int i + +cdef class _right_outer_join(_join): + cdef int i + +cdef class _left_outer_join(_join): + cdef int i + cdef object keys + +cdef class _outer_join(_join): + cdef int i + cdef object keys diff --git a/cytoolz/itertoolz.pyx b/cytoolz/itertoolz.pyx index b90303d..0f40ac3 100644 --- a/cytoolz/itertoolz.pyx +++ b/cytoolz/itertoolz.pyx @@ -1071,8 +1071,10 @@ def getter(index): else: return itemgetter(index) - -cdef class join: +cpdef object join(object leftkey, object leftseq, + object rightkey, object rightseq, + object left_default=no_default, + object right_default=no_default): """ Join two sequences on common attributes This is a semi-streaming operation. The LEFT sequence is fully evaluated @@ -1125,6 +1127,20 @@ cdef class join: >>> # result = join(second, friends, first, cities) >>> result = join(1, friends, 0, cities) # doctest: +SKIP """ + if left_default == no_default and right_default == no_default: + return _inner_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + elif left_default != no_default and right_default == no_default: + return _right_outer_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + elif left_default == no_default and right_default != no_default: + return _left_outer_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + else: + return _outer_join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) + +cdef class _join: def __init__(self, object leftkey, object leftseq, object rightkey, object rightseq, @@ -1172,7 +1188,6 @@ cdef class join: key = self.rightkey(item) self.seen_keys.add(key) - try: self.matches = iter(self.d[key]) # get left matches except KeyError: @@ -1196,6 +1211,124 @@ cdef class join: return next(self) +cdef class _right_outer_join(_join): + def __iter__(self): + self.matches = () + return self + + def __next__(self): + cdef PyObject *obj + if self.i == len(self.matches): + self.right = next(self.rightseq) + key = self.rightkey(self.right) + obj = PyDict_GetItem(self.d, key) + if obj is NULL: + return (self.left_default, self.right) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + +cdef class _outer_join(_join): + def __iter__(self): + self.matches = () + return self + + def __next__(self): + cdef PyObject *obj + if not self.is_rightseq_exhausted: + if self.i == len(self.matches): + try: + self.right = next(self.rightseq) + except StopIteration: + self.is_rightseq_exhausted = True + self.keys = iter(self.d) + return next(self) + key = self.rightkey(self.right) + self.seen_keys.add(key) + obj = PyDict_GetItem(self.d, key) + if obj is NULL: + return (self.left_default, self.right) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + else: + if self.i == len(self.matches): + key = next(self.keys) + while key in self.seen_keys: + key = next(self.keys) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right_default) + + + +cdef class _left_outer_join(_join): + def __iter__(self): + self.matches = () + return self + + def __next__(self): + cdef PyObject *obj + if not self.is_rightseq_exhausted: + if self.i == len(self.matches): + obj = NULL + while obj is NULL: + try: + self.right = next(self.rightseq) + except StopIteration: + self.is_rightseq_exhausted = True + self.keys = iter(self.d) + return next(self) + key = self.rightkey(self.right) + self.seen_keys.add(key) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + else: + if self.i == len(self.matches): + key = next(self.keys) + while key in self.seen_keys: + key = next(self.keys) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right_default) + + +cdef class _inner_join(_join): + def __iter__(self): + self.matches = () + return self + + def __next__(self): + cdef PyObject *obj = NULL + if self.i == len(self.matches): + while obj is NULL: + self.right = next(self.rightseq) + key = self.rightkey(self.right) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) + + # I find `_consume` convenient for benchmarking. Perhaps this belongs # elsewhere, so it is private (leading underscore) and hidden away for now. From 1734d7c01cb310a81e2ccd25ac52cc563dec7e68 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Jun 2014 07:55:32 -0700 Subject: [PATCH 3/5] single join class with single __next__ --- cytoolz/itertoolz.pxd | 12 +++---- cytoolz/itertoolz.pyx | 79 ++++++++++++++++++------------------------- 2 files changed, 39 insertions(+), 52 deletions(-) diff --git a/cytoolz/itertoolz.pxd b/cytoolz/itertoolz.pxd index 99ec432..be4798b 100644 --- a/cytoolz/itertoolz.pxd +++ b/cytoolz/itertoolz.pxd @@ -170,17 +170,17 @@ cdef class _join: cdef object is_rightseq_exhausted cdef object left_default cdef object right_default + cdef int i + cdef object keys cdef class _inner_join(_join): - cdef int i + pass cdef class _right_outer_join(_join): - cdef int i + pass cdef class _left_outer_join(_join): - cdef int i - cdef object keys + pass cdef class _outer_join(_join): - cdef int i - cdef object keys + pass diff --git a/cytoolz/itertoolz.pyx b/cytoolz/itertoolz.pyx index 0f40ac3..ef07e6f 100644 --- a/cytoolz/itertoolz.pyx +++ b/cytoolz/itertoolz.pyx @@ -1127,6 +1127,8 @@ cpdef object join(object leftkey, object leftseq, >>> # result = join(second, friends, first, cities) >>> result = join(1, friends, 0, cities) # doctest: +SKIP """ + return _join(leftkey, leftseq, rightkey, rightseq, + left_default, right_default) if left_default == no_default and right_default == no_default: return _inner_join(leftkey, leftseq, rightkey, rightseq, left_default, right_default) @@ -1160,7 +1162,7 @@ cdef class _join: self.d = groupby(leftkey, leftseq) self.seen_keys = set() - self.matches = iter(()) + self.matches = () self.right = None self.is_rightseq_exhausted = False @@ -1170,52 +1172,45 @@ cdef class _join: return self def __next__(self): + cdef PyObject *obj if not self.is_rightseq_exhausted: - try: - match = next(self.matches) - return (match, self.right) - except StopIteration: # iterator of matches exhausted + if self.i == len(self.matches): try: - item = next(self.rightseq) # get a new item - except StopIteration: # no items, switch to outer join - self.is_rightseq_exhausted = True - if self.right_default is not no_default: - self.d_items = iter(self.d.items()) - self.matches = iter(()) - return next(self) - else: + self.right = next(self.rightseq) + except StopIteration: + if self.right_default is no_default: raise - - key = self.rightkey(item) + self.is_rightseq_exhausted = True + self.keys = iter(self.d) + return next(self) + key = self.rightkey(self.right) self.seen_keys.add(key) - try: - self.matches = iter(self.d[key]) # get left matches - except KeyError: + obj = PyDict_GetItem(self.d, key) + if obj is NULL: if self.left_default is not no_default: - return (self.left_default, item) # outer join - - self.right = item - return next(self) + return (self.left_default, self.right) + else: + return next(self) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right) - else: # we've exhausted the right sequence, lets iterate over unseen - # items on the left - try: - match = next(self.matches) - return (match, self.right_default) - except StopIteration: - key, matches = next(self.d_items) - while(key in self.seen_keys and matches): - key, matches = next(self.d_items) - self.key = key - self.matches = iter(matches) - return next(self) + elif self.right_default is not no_default: + if self.i == len(self.matches): + key = next(self.keys) + while key in self.seen_keys: + key = next(self.keys) + obj = PyDict_GetItem(self.d, key) + self.matches = obj + self.i = 0 + match = PyList_GET_ITEM(self.matches, self.i) # skip error checking + self.i += 1 + return (match, self.right_default) cdef class _right_outer_join(_join): - def __iter__(self): - self.matches = () - return self - def __next__(self): cdef PyObject *obj if self.i == len(self.matches): @@ -1232,10 +1227,6 @@ cdef class _right_outer_join(_join): cdef class _outer_join(_join): - def __iter__(self): - self.matches = () - return self - def __next__(self): cdef PyObject *obj if not self.is_rightseq_exhausted: @@ -1272,10 +1263,6 @@ cdef class _outer_join(_join): cdef class _left_outer_join(_join): - def __iter__(self): - self.matches = () - return self - def __next__(self): cdef PyObject *obj if not self.is_rightseq_exhausted: From 78f725ae4b539954695706852949e318fee41085 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Jun 2014 10:35:12 -0700 Subject: [PATCH 4/5] remove loose __iter__ --- cytoolz/itertoolz.pyx | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cytoolz/itertoolz.pyx b/cytoolz/itertoolz.pyx index ef07e6f..2382087 100644 --- a/cytoolz/itertoolz.pyx +++ b/cytoolz/itertoolz.pyx @@ -1298,10 +1298,6 @@ cdef class _left_outer_join(_join): cdef class _inner_join(_join): - def __iter__(self): - self.matches = () - return self - def __next__(self): cdef PyObject *obj = NULL if self.i == len(self.matches): From 5867a32c3ea9cfc4155a7efb2d0b3ead349fb72a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 9 Jun 2014 19:50:09 -0700 Subject: [PATCH 5/5] add join to __init__.pxd --- cytoolz/__init__.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cytoolz/__init__.pxd b/cytoolz/__init__.pxd index 4468950..7e08db8 100644 --- a/cytoolz/__init__.pxd +++ b/cytoolz/__init__.pxd @@ -2,7 +2,7 @@ from cytoolz.itertoolz cimport ( accumulate, c_merge_sorted, cons, count, drop, get, groupby, first, frequencies, interleave, interpose, isdistinct, isiterable, iterate, last, mapcat, nth, partition, partition_all, pluck, reduceby, remove, - rest, second, sliding_window, take, take_nth, unique) + rest, second, sliding_window, take, take_nth, unique, join) from cytoolz.functoolz cimport (