diff --git a/pandas/_libs/join 2.pyx b/pandas/_libs/join 2.pyx new file mode 100644 index 0000000000000..164ed8a5c9227 --- /dev/null +++ b/pandas/_libs/join 2.pyx @@ -0,0 +1,897 @@ +cimport cython +from cython cimport Py_ssize_t +import numpy as np + +cimport numpy as cnp +from numpy cimport ( + int64_t, + intp_t, + ndarray, + uint64_t, +) + +cnp.import_array() + +from pandas._libs.algos import groupsort_indexer + +from pandas._libs.dtypes cimport ( + numeric_object_t, + numeric_t, +) + + +@cython.wraparound(False) +@cython.boundscheck(False) +def inner_join(const intp_t[:] left, const intp_t[:] right, + Py_ssize_t max_groups): + cdef: + Py_ssize_t i, j, k, count = 0 + intp_t[::1] left_sorter, right_sorter + intp_t[::1] left_count, right_count + intp_t[::1] left_indexer, right_indexer + intp_t lc, rc + Py_ssize_t left_pos = 0, right_pos = 0, position = 0 + Py_ssize_t offset + + left_sorter, left_count = groupsort_indexer(left, max_groups) + right_sorter, right_count = groupsort_indexer(right, max_groups) + + with nogil: + # First pass, determine size of result set, do not use the NA group + for i in range(1, max_groups + 1): + lc = left_count[i] + rc = right_count[i] + + if rc > 0 and lc > 0: + count += lc * rc + + left_indexer = np.empty(count, dtype=np.intp) + right_indexer = np.empty(count, dtype=np.intp) + + with nogil: + # exclude the NA group + left_pos = left_count[0] + right_pos = right_count[0] + for i in range(1, max_groups + 1): + lc = left_count[i] + rc = right_count[i] + + if rc > 0 and lc > 0: + for j in range(lc): + offset = position + j * rc + for k in range(rc): + left_indexer[offset + k] = left_pos + j + right_indexer[offset + k] = right_pos + k + position += lc * rc + left_pos += lc + right_pos += rc + + # Will overwrite left/right indexer with the result + _get_result_indexer(left_sorter, left_indexer) + _get_result_indexer(right_sorter, right_indexer) + + return np.asarray(left_indexer), np.asarray(right_indexer) + + +@cython.wraparound(False) +@cython.boundscheck(False) +def left_outer_join(const intp_t[:] left, const intp_t[:] right, + Py_ssize_t max_groups, bint sort=True): + cdef: + Py_ssize_t i, j, k, count = 0 + ndarray[intp_t] rev + intp_t[::1] left_count, right_count + intp_t[::1] left_sorter, right_sorter + intp_t[::1] left_indexer, right_indexer + intp_t lc, rc + Py_ssize_t left_pos = 0, right_pos = 0, position = 0 + Py_ssize_t offset + + left_sorter, left_count = groupsort_indexer(left, max_groups) + right_sorter, right_count = groupsort_indexer(right, max_groups) + + with nogil: + # First pass, determine size of result set, do not use the NA group + for i in range(1, max_groups + 1): + lc = left_count[i] + rc = right_count[i] + + if rc > 0: + count += lc * rc + else: + count += lc + + left_indexer = np.empty(count, dtype=np.intp) + right_indexer = np.empty(count, dtype=np.intp) + + with nogil: + # exclude the NA group + left_pos = left_count[0] + right_pos = right_count[0] + for i in range(1, max_groups + 1): + lc = left_count[i] + rc = right_count[i] + + if rc == 0: + for j in range(lc): + left_indexer[position + j] = left_pos + j + right_indexer[position + j] = -1 + position += lc + else: + for j in range(lc): + offset = position + j * rc + for k in range(rc): + left_indexer[offset + k] = left_pos + j + right_indexer[offset + k] = right_pos + k + position += lc * rc + left_pos += lc + right_pos += rc + + # Will overwrite left/right indexer with the result + _get_result_indexer(left_sorter, left_indexer) + _get_result_indexer(right_sorter, right_indexer) + 
+ if not sort: # if not asked to sort, revert to original order + if len(left) == len(left_indexer): + # no multiple matches for any row on the left + # this is a short-cut to avoid groupsort_indexer + # otherwise, the `else` path also works in this case + rev = np.empty(len(left), dtype=np.intp) + rev.put(np.asarray(left_sorter), np.arange(len(left))) + else: + rev, _ = groupsort_indexer(left_indexer, len(left)) + + return np.asarray(left_indexer).take(rev), np.asarray(right_indexer).take(rev) + else: + return np.asarray(left_indexer), np.asarray(right_indexer) + + +@cython.wraparound(False) +@cython.boundscheck(False) +def full_outer_join(const intp_t[:] left, const intp_t[:] right, + Py_ssize_t max_groups): + cdef: + Py_ssize_t i, j, k, count = 0 + intp_t[::1] left_sorter, right_sorter + intp_t[::1] left_count, right_count + intp_t[::1] left_indexer, right_indexer + intp_t lc, rc + intp_t left_pos = 0, right_pos = 0 + Py_ssize_t offset, position = 0 + + left_sorter, left_count = groupsort_indexer(left, max_groups) + right_sorter, right_count = groupsort_indexer(right, max_groups) + + with nogil: + # First pass, determine size of result set, do not use the NA group + for i in range(1, max_groups + 1): + lc = left_count[i] + rc = right_count[i] + + if rc > 0 and lc > 0: + count += lc * rc + else: + count += lc + rc + + left_indexer = np.empty(count, dtype=np.intp) + right_indexer = np.empty(count, dtype=np.intp) + + with nogil: + # exclude the NA group + left_pos = left_count[0] + right_pos = right_count[0] + for i in range(1, max_groups + 1): + lc = left_count[i] + rc = right_count[i] + + if rc == 0: + for j in range(lc): + left_indexer[position + j] = left_pos + j + right_indexer[position + j] = -1 + position += lc + elif lc == 0: + for j in range(rc): + left_indexer[position + j] = -1 + right_indexer[position + j] = right_pos + j + position += rc + else: + for j in range(lc): + offset = position + j * rc + for k in range(rc): + left_indexer[offset + k] = left_pos + j + right_indexer[offset + k] = right_pos + k + position += lc * rc + left_pos += lc + right_pos += rc + + # Will overwrite left/right indexer with the result + _get_result_indexer(left_sorter, left_indexer) + _get_result_indexer(right_sorter, right_indexer) + + return np.asarray(left_indexer), np.asarray(right_indexer) + + +@cython.wraparound(False) +@cython.boundscheck(False) +cdef void _get_result_indexer(intp_t[::1] sorter, intp_t[::1] indexer) noexcept nogil: + """NOTE: overwrites indexer with the result to avoid allocating another array""" + cdef: + Py_ssize_t i, n, idx + + if len(sorter) > 0: + # cython-only equivalent to + # `res = algos.take_nd(sorter, indexer, fill_value=-1)` + n = indexer.shape[0] + for i in range(n): + idx = indexer[i] + if idx == -1: + indexer[i] = -1 + else: + indexer[i] = sorter[idx] + else: + # length-0 case + indexer[:] = -1 + + +@cython.wraparound(False) +@cython.boundscheck(False) +def ffill_indexer(const intp_t[:] indexer) -> np.ndarray: + cdef: + Py_ssize_t i, n = len(indexer) + ndarray[intp_t] result + intp_t val, last_obs + + result = np.empty(n, dtype=np.intp) + last_obs = -1 + + for i in range(n): + val = indexer[i] + if val == -1: + result[i] = last_obs + else: + result[i] = val + last_obs = val + + return result + + +# ---------------------------------------------------------------------- +# left_join_indexer, inner_join_indexer, outer_join_indexer +# ---------------------------------------------------------------------- + +# Joins on ordered, unique indices + +# right might 
contain non-unique values + +@cython.wraparound(False) +@cython.boundscheck(False) +def left_join_indexer_unique( + ndarray[numeric_object_t] left, + ndarray[numeric_object_t] right +): + """ + Both left and right are strictly monotonic increasing. + """ + cdef: + Py_ssize_t i, j, nleft, nright + ndarray[intp_t] indexer + numeric_object_t rval + + i = 0 + j = 0 + nleft = len(left) + nright = len(right) + + indexer = np.empty(nleft, dtype=np.intp) + while True: + if i == nleft: + break + + if j == nright: + indexer[i] = -1 + i += 1 + continue + + rval = right[j] + + while i < nleft - 1 and left[i] == rval: + indexer[i] = j + i += 1 + + if left[i] == rval: + indexer[i] = j + i += 1 + while i < nleft - 1 and left[i] == rval: + indexer[i] = j + i += 1 + j += 1 + elif left[i] > rval: + indexer[i] = -1 + j += 1 + else: + indexer[i] = -1 + i += 1 + return indexer + + +@cython.wraparound(False) +@cython.boundscheck(False) +def left_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t] right): + """ + Two-pass algorithm for monotonic indexes. Handles many-to-one merges. + + Both left and right are monotonic increasing, but at least one of them + is non-unique (if both were unique we'd use left_join_indexer_unique). + """ + cdef: + Py_ssize_t i, j, nright, nleft, count + numeric_object_t lval, rval + ndarray[intp_t] lindexer, rindexer + ndarray[numeric_object_t] result + + nleft = len(left) + nright = len(right) + + # First pass is to find the size 'count' of our output indexers. + i = 0 + j = 0 + count = 0 + if nleft > 0: + while i < nleft: + if j == nright: + count += nleft - i + break + + lval = left[i] + rval = right[j] + + if lval == rval: + # This block is identical across + # left_join_indexer, inner_join_indexer, outer_join_indexer + count += 1 + if i < nleft - 1: + if j < nright - 1 and right[j + 1] == rval: + j += 1 + else: + i += 1 + if left[i] != rval: + j += 1 + elif j < nright - 1: + j += 1 + if lval != right[j]: + i += 1 + else: + # end of the road + break + elif lval < rval: + count += 1 + i += 1 + else: + j += 1 + + # do it again now that result size is known + + lindexer = np.empty(count, dtype=np.intp) + rindexer = np.empty(count, dtype=np.intp) + result = np.empty(count, dtype=left.dtype) + + i = 0 + j = 0 + count = 0 + if nleft > 0: + while i < nleft: + if j == nright: + while i < nleft: + lindexer[count] = i + rindexer[count] = -1 + result[count] = left[i] + i += 1 + count += 1 + break + + lval = left[i] + rval = right[j] + + if lval == rval: + lindexer[count] = i + rindexer[count] = j + result[count] = lval + count += 1 + if i < nleft - 1: + if j < nright - 1 and right[j + 1] == rval: + j += 1 + else: + i += 1 + if left[i] != rval: + j += 1 + elif j < nright - 1: + j += 1 + if lval != right[j]: + i += 1 + else: + # end of the road + break + elif lval < rval: + # i.e. lval not in right; we keep for left_join_indexer + lindexer[count] = i + rindexer[count] = -1 + result[count] = lval + count += 1 + i += 1 + else: + # i.e. rval not in left; we discard for left_join_indexer + j += 1 + + return result, lindexer, rindexer + + +@cython.wraparound(False) +@cython.boundscheck(False) +def inner_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t] right): + """ + Two-pass algorithm for monotonic indexes. Handles many-to-one merges. + + Both left and right are monotonic increasing but not necessarily unique. 
+ """ + cdef: + Py_ssize_t i, j, nright, nleft, count + numeric_object_t lval, rval + ndarray[intp_t] lindexer, rindexer + ndarray[numeric_object_t] result + + nleft = len(left) + nright = len(right) + + # First pass is to find the size 'count' of our output indexers. + i = 0 + j = 0 + count = 0 + if nleft > 0 and nright > 0: + while True: + if i == nleft: + break + if j == nright: + break + + lval = left[i] + rval = right[j] + if lval == rval: + count += 1 + if i < nleft - 1: + if j < nright - 1 and right[j + 1] == rval: + j += 1 + else: + i += 1 + if left[i] != rval: + j += 1 + elif j < nright - 1: + j += 1 + if lval != right[j]: + i += 1 + else: + # end of the road + break + elif lval < rval: + # i.e. lval not in right; we discard for inner_indexer + i += 1 + else: + # i.e. rval not in left; we discard for inner_indexer + j += 1 + + # do it again now that result size is known + + lindexer = np.empty(count, dtype=np.intp) + rindexer = np.empty(count, dtype=np.intp) + result = np.empty(count, dtype=left.dtype) + + i = 0 + j = 0 + count = 0 + if nleft > 0 and nright > 0: + while True: + if i == nleft: + break + if j == nright: + break + + lval = left[i] + rval = right[j] + if lval == rval: + lindexer[count] = i + rindexer[count] = j + result[count] = lval + count += 1 + if i < nleft - 1: + if j < nright - 1 and right[j + 1] == rval: + j += 1 + else: + i += 1 + if left[i] != rval: + j += 1 + elif j < nright - 1: + j += 1 + if lval != right[j]: + i += 1 + else: + # end of the road + break + elif lval < rval: + # i.e. lval not in right; we discard for inner_indexer + i += 1 + else: + # i.e. rval not in left; we discard for inner_indexer + j += 1 + + return result, lindexer, rindexer + + +@cython.wraparound(False) +@cython.boundscheck(False) +def outer_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t] right): + """ + Both left and right are monotonic increasing but not necessarily unique. + """ + cdef: + Py_ssize_t i, j, nright, nleft, count + numeric_object_t lval, rval + ndarray[intp_t] lindexer, rindexer + ndarray[numeric_object_t] result + + nleft = len(left) + nright = len(right) + + # First pass is to find the size 'count' of our output indexers. 
+ # count will be length of left plus the number of elements of right not in + # left (counting duplicates) + i = 0 + j = 0 + count = 0 + if nleft == 0: + count = nright + elif nright == 0: + count = nleft + else: + while True: + if i == nleft: + count += nright - j + break + if j == nright: + count += nleft - i + break + + lval = left[i] + rval = right[j] + if lval == rval: + count += 1 + if i < nleft - 1: + if j < nright - 1 and right[j + 1] == rval: + j += 1 + else: + i += 1 + if left[i] != rval: + j += 1 + elif j < nright - 1: + j += 1 + if lval != right[j]: + i += 1 + else: + # end of the road + break + elif lval < rval: + count += 1 + i += 1 + else: + count += 1 + j += 1 + + lindexer = np.empty(count, dtype=np.intp) + rindexer = np.empty(count, dtype=np.intp) + result = np.empty(count, dtype=left.dtype) + + # do it again, but populate the indexers / result + + i = 0 + j = 0 + count = 0 + if nleft == 0: + for j in range(nright): + lindexer[j] = -1 + rindexer[j] = j + result[j] = right[j] + elif nright == 0: + for i in range(nleft): + lindexer[i] = i + rindexer[i] = -1 + result[i] = left[i] + else: + while True: + if i == nleft: + while j < nright: + lindexer[count] = -1 + rindexer[count] = j + result[count] = right[j] + count += 1 + j += 1 + break + if j == nright: + while i < nleft: + lindexer[count] = i + rindexer[count] = -1 + result[count] = left[i] + count += 1 + i += 1 + break + + lval = left[i] + rval = right[j] + + if lval == rval: + lindexer[count] = i + rindexer[count] = j + result[count] = lval + count += 1 + if i < nleft - 1: + if j < nright - 1 and right[j + 1] == rval: + j += 1 + else: + i += 1 + if left[i] != rval: + j += 1 + elif j < nright - 1: + j += 1 + if lval != right[j]: + i += 1 + else: + # end of the road + break + elif lval < rval: + # i.e. lval not in right; we keep for outer_join_indexer + lindexer[count] = i + rindexer[count] = -1 + result[count] = lval + count += 1 + i += 1 + else: + # i.e. 
rval not in left; we keep for outer_join_indexer + lindexer[count] = -1 + rindexer[count] = j + result[count] = rval + count += 1 + j += 1 + + return result, lindexer, rindexer + + +# ---------------------------------------------------------------------- +# asof_join_by +# ---------------------------------------------------------------------- + +from pandas._libs.hashtable cimport ( + HashTable, + Int64HashTable, + PyObjectHashTable, + UInt64HashTable, +) + +ctypedef fused by_t: + object + int64_t + uint64_t + + +def asof_join_backward_on_X_by_Y(ndarray[numeric_t] left_values, + ndarray[numeric_t] right_values, + ndarray[by_t] left_by_values, + ndarray[by_t] right_by_values, + bint allow_exact_matches=True, + tolerance=None, + bint use_hashtable=True): + + cdef: + Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos + ndarray[intp_t] left_indexer, right_indexer + bint has_tolerance = False + numeric_t tolerance_ = 0 + numeric_t diff = 0 + HashTable hash_table + by_t by_value + + # if we are using tolerance, set our objects + if tolerance is not None: + has_tolerance = True + tolerance_ = tolerance + + left_size = len(left_values) + right_size = len(right_values) + + left_indexer = np.empty(left_size, dtype=np.intp) + right_indexer = np.empty(left_size, dtype=np.intp) + + if use_hashtable: + if by_t is object: + hash_table = PyObjectHashTable(right_size) + elif by_t is int64_t: + hash_table = Int64HashTable(right_size) + elif by_t is uint64_t: + hash_table = UInt64HashTable(right_size) + + right_pos = 0 + for left_pos in range(left_size): + # restart right_pos if it went negative in a previous iteration + if right_pos < 0: + right_pos = 0 + + # find last position in right whose value is less than left's + if allow_exact_matches: + while (right_pos < right_size and + right_values[right_pos] <= left_values[left_pos]): + if use_hashtable: + hash_table.set_item(right_by_values[right_pos], right_pos) + right_pos += 1 + else: + while (right_pos < right_size and + right_values[right_pos] < left_values[left_pos]): + if use_hashtable: + hash_table.set_item(right_by_values[right_pos], right_pos) + right_pos += 1 + right_pos -= 1 + + # save positions as the desired index + if use_hashtable: + by_value = left_by_values[left_pos] + found_right_pos = (hash_table.get_item(by_value) + if by_value in hash_table else -1) + else: + found_right_pos = right_pos + + left_indexer[left_pos] = left_pos + right_indexer[left_pos] = found_right_pos + + # if needed, verify that tolerance is met + if has_tolerance and found_right_pos != -1: + diff = left_values[left_pos] - right_values[found_right_pos] + if diff > tolerance_: + right_indexer[left_pos] = -1 + + return left_indexer, right_indexer + + +def asof_join_forward_on_X_by_Y(ndarray[numeric_t] left_values, + ndarray[numeric_t] right_values, + ndarray[by_t] left_by_values, + ndarray[by_t] right_by_values, + bint allow_exact_matches=1, + tolerance=None, + bint use_hashtable=True): + + cdef: + Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos + ndarray[intp_t] left_indexer, right_indexer + bint has_tolerance = False + numeric_t tolerance_ = 0 + numeric_t diff = 0 + HashTable hash_table + by_t by_value + + # if we are using tolerance, set our objects + if tolerance is not None: + has_tolerance = True + tolerance_ = tolerance + + left_size = len(left_values) + right_size = len(right_values) + + left_indexer = np.empty(left_size, dtype=np.intp) + right_indexer = np.empty(left_size, dtype=np.intp) + + if use_hashtable: + if by_t is 
object: + hash_table = PyObjectHashTable(right_size) + elif by_t is int64_t: + hash_table = Int64HashTable(right_size) + elif by_t is uint64_t: + hash_table = UInt64HashTable(right_size) + + right_pos = right_size - 1 + for left_pos in range(left_size - 1, -1, -1): + # restart right_pos if it went over in a previous iteration + if right_pos == right_size: + right_pos = right_size - 1 + + # find first position in right whose value is greater than left's + if allow_exact_matches: + while (right_pos >= 0 and + right_values[right_pos] >= left_values[left_pos]): + if use_hashtable: + hash_table.set_item(right_by_values[right_pos], right_pos) + right_pos -= 1 + else: + while (right_pos >= 0 and + right_values[right_pos] > left_values[left_pos]): + if use_hashtable: + hash_table.set_item(right_by_values[right_pos], right_pos) + right_pos -= 1 + right_pos += 1 + + # save positions as the desired index + if use_hashtable: + by_value = left_by_values[left_pos] + found_right_pos = (hash_table.get_item(by_value) + if by_value in hash_table else -1) + else: + found_right_pos = (right_pos + if right_pos != right_size else -1) + + left_indexer[left_pos] = left_pos + right_indexer[left_pos] = found_right_pos + + # if needed, verify that tolerance is met + if has_tolerance and found_right_pos != -1: + diff = right_values[found_right_pos] - left_values[left_pos] + if diff > tolerance_: + right_indexer[left_pos] = -1 + + return left_indexer, right_indexer + + +def asof_join_nearest_on_X_by_Y(ndarray[numeric_t] left_values, + ndarray[numeric_t] right_values, + ndarray[by_t] left_by_values, + ndarray[by_t] right_by_values, + bint allow_exact_matches=True, + tolerance=None, + bint use_hashtable=True): + + cdef: + ndarray[intp_t] bli, bri, fli, fri + + ndarray[intp_t] left_indexer, right_indexer + Py_ssize_t left_size, i + numeric_t bdiff, fdiff + + # search both forward and backward + # TODO(cython3): + # Bug in beta1 preventing Cython from choosing + # right specialization when one fused memview is None + # Doesn't matter what type we choose + # (nothing happens anyways since it is None) + # GH 51640 + if left_by_values is not None and left_by_values.dtype != object: + by_dtype = f"{left_by_values.dtype}_t" + else: + by_dtype = object + bli, bri = asof_join_backward_on_X_by_Y[f"{left_values.dtype}_t", by_dtype]( + left_values, + right_values, + left_by_values, + right_by_values, + allow_exact_matches, + tolerance, + use_hashtable + ) + fli, fri = asof_join_forward_on_X_by_Y[f"{left_values.dtype}_t", by_dtype]( + left_values, + right_values, + left_by_values, + right_by_values, + allow_exact_matches, + tolerance, + use_hashtable + ) + + # choose the smaller timestamp + left_size = len(left_values) + left_indexer = np.empty(left_size, dtype=np.intp) + right_indexer = np.empty(left_size, dtype=np.intp) + + for i in range(len(bri)): + # choose timestamp from right with smaller difference + if bri[i] != -1 and fri[i] != -1: + bdiff = left_values[bli[i]] - right_values[bri[i]] + fdiff = right_values[fri[i]] - left_values[fli[i]] + right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i] + else: + right_indexer[i] = bri[i] if bri[i] != -1 else fri[i] + left_indexer[i] = bli[i] + + return left_indexer, right_indexer diff --git a/pandas/core/computation/pytables 2.py b/pandas/core/computation/pytables 2.py new file mode 100644 index 0000000000000..e836ea20ede83 --- /dev/null +++ b/pandas/core/computation/pytables 2.py @@ -0,0 +1,654 @@ +""" manage PyTables query interface via Expressions """ +from __future__ import 
annotations + +import ast +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from pandas._libs.tslibs import ( + Timedelta, + Timestamp, +) +from pandas.errors import UndefinedVariableError + +from pandas.core.dtypes.common import is_list_like + +import pandas.core.common as com +from pandas.core.computation import ( + expr, + ops, + scope as _scope, +) +from pandas.core.computation.common import ensure_decoded +from pandas.core.computation.expr import BaseExprVisitor +from pandas.core.computation.ops import is_term +from pandas.core.construction import extract_array +from pandas.core.indexes.base import Index + +from pandas.io.formats.printing import ( + pprint_thing, + pprint_thing_encoded, +) + +if TYPE_CHECKING: + from pandas._typing import npt + + +class PyTablesScope(_scope.Scope): + __slots__ = ("queryables",) + + queryables: dict[str, Any] + + def __init__( + self, + level: int, + global_dict=None, + local_dict=None, + queryables: dict[str, Any] | None = None, + ) -> None: + super().__init__(level + 1, global_dict=global_dict, local_dict=local_dict) + self.queryables = queryables or {} + + +class Term(ops.Term): + env: PyTablesScope + + def __new__(cls, name, env, side=None, encoding=None): + if isinstance(name, str): + klass = cls + else: + klass = Constant + return object.__new__(klass) + + def __init__(self, name, env: PyTablesScope, side=None, encoding=None) -> None: + super().__init__(name, env, side=side, encoding=encoding) + + def _resolve_name(self): + # must be a queryables + if self.side == "left": + # Note: The behavior of __new__ ensures that self.name is a str here + if self.name not in self.env.queryables: + raise NameError(f"name {repr(self.name)} is not defined") + return self.name + + # resolve the rhs (and allow it to be None) + try: + return self.env.resolve(self.name, is_local=False) + except UndefinedVariableError: + return self.name + + # read-only property overwriting read/write property + @property # type: ignore[misc] + def value(self): + return self._value + + +class Constant(Term): + def __init__(self, value, env: PyTablesScope, side=None, encoding=None) -> None: + assert isinstance(env, PyTablesScope), type(env) + super().__init__(value, env, side=side, encoding=encoding) + + def _resolve_name(self): + return self._name + + +class BinOp(ops.BinOp): + _max_selectors = 31 + + op: str + queryables: dict[str, Any] + condition: str | None + + def __init__(self, op: str, lhs, rhs, queryables: dict[str, Any], encoding) -> None: + super().__init__(op, lhs, rhs) + self.queryables = queryables + self.encoding = encoding + self.condition = None + + def _disallow_scalar_only_bool_ops(self) -> None: + pass + + def prune(self, klass): + def pr(left, right): + """create and return a new specialized BinOp from myself""" + if left is None: + return right + elif right is None: + return left + + k = klass + if isinstance(left, ConditionBinOp): + if isinstance(right, ConditionBinOp): + k = JointConditionBinOp + elif isinstance(left, k): + return left + elif isinstance(right, k): + return right + + elif isinstance(left, FilterBinOp): + if isinstance(right, FilterBinOp): + k = JointFilterBinOp + elif isinstance(left, k): + return left + elif isinstance(right, k): + return right + + return k( + self.op, left, right, queryables=self.queryables, encoding=self.encoding + ).evaluate() + + left, right = self.lhs, self.rhs + + if is_term(left) and is_term(right): + res = pr(left.value, right.value) + elif not is_term(left) and 
is_term(right): + res = pr(left.prune(klass), right.value) + elif is_term(left) and not is_term(right): + res = pr(left.value, right.prune(klass)) + elif not (is_term(left) or is_term(right)): + res = pr(left.prune(klass), right.prune(klass)) + + return res + + def conform(self, rhs): + """inplace conform rhs""" + if not is_list_like(rhs): + rhs = [rhs] + if isinstance(rhs, np.ndarray): + rhs = rhs.ravel() + return rhs + + @property + def is_valid(self) -> bool: + """return True if this is a valid field""" + return self.lhs in self.queryables + + @property + def is_in_table(self) -> bool: + """ + return True if this is a valid column name for generation (e.g. an + actual column in the table) + """ + return self.queryables.get(self.lhs) is not None + + @property + def kind(self): + """the kind of my field""" + return getattr(self.queryables.get(self.lhs), "kind", None) + + @property + def meta(self): + """the meta of my field""" + return getattr(self.queryables.get(self.lhs), "meta", None) + + @property + def metadata(self): + """the metadata of my field""" + return getattr(self.queryables.get(self.lhs), "metadata", None) + + def generate(self, v) -> str: + """create and return the op string for this TermValue""" + val = v.tostring(self.encoding) + return f"({self.lhs} {self.op} {val})" + + def convert_value(self, v) -> TermValue: + """ + convert the expression that is in the term to something that is + accepted by pytables + """ + + def stringify(value): + if self.encoding is not None: + return pprint_thing_encoded(value, encoding=self.encoding) + return pprint_thing(value) + + kind = ensure_decoded(self.kind) + meta = ensure_decoded(self.meta) + if kind in ("datetime64", "datetime"): + if isinstance(v, (int, float)): + v = stringify(v) + v = ensure_decoded(v) + v = Timestamp(v).as_unit("ns") + if v.tz is not None: + v = v.tz_convert("UTC") + return TermValue(v, v._value, kind) + elif kind in ("timedelta64", "timedelta"): + if isinstance(v, str): + v = Timedelta(v) + else: + v = Timedelta(v, unit="s") + v = v.as_unit("ns")._value + return TermValue(int(v), v, kind) + elif meta == "category": + metadata = extract_array(self.metadata, extract_numpy=True) + result: npt.NDArray[np.intp] | np.intp | int + if v not in metadata: + result = -1 + else: + result = metadata.searchsorted(v, side="left") + return TermValue(result, result, "integer") + elif kind == "integer": + from decimal import ( + Decimal, + InvalidOperation, + ) + + try: + v_dec = Decimal(v) + v = int(v_dec.to_integral_exact(rounding="ROUND_HALF_EVEN")) + except InvalidOperation: + raise ValueError("could not convert string to ") + return TermValue(v, v, kind) + elif kind == "float": + v = float(v) + return TermValue(v, v, kind) + elif kind == "bool": + if isinstance(v, str): + v = v.strip().lower() not in [ + "false", + "f", + "no", + "n", + "none", + "0", + "[]", + "{}", + "", + ] + else: + v = bool(v) + return TermValue(v, v, kind) + elif isinstance(v, str): + # string quoting + return TermValue(v, stringify(v), "string") + else: + raise TypeError(f"Cannot compare {v} of type {type(v)} to {kind} column") + + def convert_values(self) -> None: + pass + + +class FilterBinOp(BinOp): + filter: tuple[Any, Any, Index] | None = None + + def __repr__(self) -> str: + if self.filter is None: + return "Filter: Not Initialized" + return pprint_thing(f"[Filter : [{self.filter[0]}] -> [{self.filter[1]}]") + + def invert(self): + """invert the filter""" + if self.filter is not None: + self.filter = ( + self.filter[0], + 
self.generate_filter_op(invert=True), + self.filter[2], + ) + return self + + def format(self): + """return the actual filter format""" + return [self.filter] + + def evaluate(self): + if not self.is_valid: + raise ValueError(f"query term is not valid [{self}]") + + rhs = self.conform(self.rhs) + values = list(rhs) + + if self.is_in_table: + # if too many values to create the expression, use a filter instead + if self.op in ["==", "!="] and len(values) > self._max_selectors: + filter_op = self.generate_filter_op() + self.filter = (self.lhs, filter_op, Index(values)) + + return self + return None + + # equality conditions + if self.op in ["==", "!="]: + filter_op = self.generate_filter_op() + self.filter = (self.lhs, filter_op, Index(values)) + + else: + raise TypeError( + f"passing a filterable condition to a non-table indexer [{self}]" + ) + + return self + + def generate_filter_op(self, invert: bool = False): + if (self.op == "!=" and not invert) or (self.op == "==" and invert): + return lambda axis, vals: ~axis.isin(vals) + else: + return lambda axis, vals: axis.isin(vals) + + +class JointFilterBinOp(FilterBinOp): + def format(self): + raise NotImplementedError("unable to collapse Joint Filters") + + def evaluate(self): + return self + + +class ConditionBinOp(BinOp): + def __repr__(self) -> str: + return pprint_thing(f"[Condition : [{self.condition}]]") + + def invert(self): + """invert the condition""" + # if self.condition is not None: + # self.condition = "~(%s)" % self.condition + # return self + raise NotImplementedError( + "cannot use an invert condition when passing to numexpr" + ) + + def format(self): + """return the actual ne format""" + return self.condition + + def evaluate(self): + if not self.is_valid: + raise ValueError(f"query term is not valid [{self}]") + + # convert values if we are in the table + if not self.is_in_table: + return None + + rhs = self.conform(self.rhs) + values = [self.convert_value(v) for v in rhs] + + # equality conditions + if self.op in ["==", "!="]: + # too many values to create the expression? 
+ if len(values) <= self._max_selectors: + vs = [self.generate(v) for v in values] + self.condition = f"({' | '.join(vs)})" + + # use a filter after reading + else: + return None + else: + self.condition = self.generate(values[0]) + + return self + + +class JointConditionBinOp(ConditionBinOp): + def evaluate(self): + self.condition = f"({self.lhs.condition} {self.op} {self.rhs.condition})" + return self + + +class UnaryOp(ops.UnaryOp): + def prune(self, klass): + if self.op != "~": + raise NotImplementedError("UnaryOp only support invert type ops") + + operand = self.operand + operand = operand.prune(klass) + + if operand is not None and ( + issubclass(klass, ConditionBinOp) + and operand.condition is not None + or not issubclass(klass, ConditionBinOp) + and issubclass(klass, FilterBinOp) + and operand.filter is not None + ): + return operand.invert() + return None + + +class PyTablesExprVisitor(BaseExprVisitor): + const_type = Constant + term_type = Term + + def __init__(self, env, engine, parser, **kwargs) -> None: + super().__init__(env, engine, parser) + for bin_op in self.binary_ops: + bin_node = self.binary_op_nodes_map[bin_op] + setattr( + self, + f"visit_{bin_node}", + lambda node, bin_op=bin_op: partial(BinOp, bin_op, **kwargs), + ) + + def visit_UnaryOp(self, node, **kwargs): + if isinstance(node.op, (ast.Not, ast.Invert)): + return UnaryOp("~", self.visit(node.operand)) + elif isinstance(node.op, ast.USub): + return self.const_type(-self.visit(node.operand).value, self.env) + elif isinstance(node.op, ast.UAdd): + raise NotImplementedError("Unary addition not supported") + + def visit_Index(self, node, **kwargs): + return self.visit(node.value).value + + def visit_Assign(self, node, **kwargs): + cmpr = ast.Compare( + ops=[ast.Eq()], left=node.targets[0], comparators=[node.value] + ) + return self.visit(cmpr) + + def visit_Subscript(self, node, **kwargs): + # only allow simple subscripts + + value = self.visit(node.value) + slobj = self.visit(node.slice) + try: + value = value.value + except AttributeError: + pass + + if isinstance(slobj, Term): + # In py39 np.ndarray lookups with Term containing int raise + slobj = slobj.value + + try: + return self.const_type(value[slobj], self.env) + except TypeError as err: + raise ValueError( + f"cannot subscript {repr(value)} with {repr(slobj)}" + ) from err + + def visit_Attribute(self, node, **kwargs): + attr = node.attr + value = node.value + + ctx = type(node.ctx) + if ctx == ast.Load: + # resolve the value + resolved = self.visit(value) + + # try to get the value to see if we are another expression + try: + resolved = resolved.value + except AttributeError: + pass + + try: + return self.term_type(getattr(resolved, attr), self.env) + except AttributeError: + # something like datetime.datetime where scope is overridden + if isinstance(value, ast.Name) and value.id == attr: + return resolved + + raise ValueError(f"Invalid Attribute context {ctx.__name__}") + + def translate_In(self, op): + return ast.Eq() if isinstance(op, ast.In) else op + + def _rewrite_membership_op(self, node, left, right): + return self.visit(node.op), node.op, left, right + + +def _validate_where(w): + """ + Validate that the where statement is of the right type. + + The type may either be String, Expr, or list-like of Exprs. + + Parameters + ---------- + w : String term expression, Expr, or list-like of Exprs. + + Returns + ------- + where : The original where clause if the check was successful. 
+ + Raises + ------ + TypeError : An invalid data type was passed in for w (e.g. dict). + """ + if not (isinstance(w, (PyTablesExpr, str)) or is_list_like(w)): + raise TypeError( + "where must be passed as a string, PyTablesExpr, " + "or list-like of PyTablesExpr" + ) + + return w + + +class PyTablesExpr(expr.Expr): + """ + Hold a pytables-like expression, comprised of possibly multiple 'terms'. + + Parameters + ---------- + where : string term expression, PyTablesExpr, or list-like of PyTablesExprs + queryables : a "kinds" map (dict of column name -> kind), or None if column + is non-indexable + encoding : an encoding that will encode the query terms + + Returns + ------- + a PyTablesExpr object + + Examples + -------- + 'index>=date' + "columns=['A', 'D']" + 'columns=A' + 'columns==A' + "~(columns=['A','B'])" + 'index>df.index[3] & string="bar"' + '(index>df.index[3] & index<=df.index[6]) | string="bar"' + "ts>=Timestamp('2012-02-01')" + "major_axis>=20130101" + """ + + _visitor: PyTablesExprVisitor | None + env: PyTablesScope + expr: str + + def __init__( + self, + where, + queryables: dict[str, Any] | None = None, + encoding=None, + scope_level: int = 0, + ) -> None: + where = _validate_where(where) + + self.encoding = encoding + self.condition = None + self.filter = None + self.terms = None + self._visitor = None + + # capture the environment if needed + local_dict: _scope.DeepChainMap[Any, Any] | None = None + + if isinstance(where, PyTablesExpr): + local_dict = where.env.scope + _where = where.expr + + elif is_list_like(where): + where = list(where) + for idx, w in enumerate(where): + if isinstance(w, PyTablesExpr): + local_dict = w.env.scope + else: + where[idx] = _validate_where(w) + _where = " & ".join([f"({w})" for w in com.flatten(where)]) + else: + # _validate_where ensures we otherwise have a string + _where = where + + self.expr = _where + self.env = PyTablesScope(scope_level + 1, local_dict=local_dict) + + if queryables is not None and isinstance(self.expr, str): + self.env.queryables.update(queryables) + self._visitor = PyTablesExprVisitor( + self.env, + queryables=queryables, + parser="pytables", + engine="pytables", + encoding=encoding, + ) + self.terms = self.parse() + + def __repr__(self) -> str: + if self.terms is not None: + return pprint_thing(self.terms) + return pprint_thing(self.expr) + + def evaluate(self): + """create and return the numexpr condition and filter""" + try: + self.condition = self.terms.prune(ConditionBinOp) + except AttributeError as err: + raise ValueError( + f"cannot process expression [{self.expr}], [{self}] " + "is not a valid condition" + ) from err + try: + self.filter = self.terms.prune(FilterBinOp) + except AttributeError as err: + raise ValueError( + f"cannot process expression [{self.expr}], [{self}] " + "is not a valid filter" + ) from err + + return self.condition, self.filter + + +class TermValue: + """hold a term value the we use to construct a condition/filter""" + + def __init__(self, value, converted, kind: str) -> None: + assert isinstance(kind, str), kind + self.value = value + self.converted = converted + self.kind = kind + + def tostring(self, encoding) -> str: + """quote the string if not encoded else encode and return""" + if self.kind == "string": + if encoding is not None: + return str(self.converted) + return f'"{self.converted}"' + elif self.kind == "float": + # python 2 str(float) is not always + # round-trippable so use repr() + return repr(self.converted) + return str(self.converted) + + +def maybe_expression(s) -> 
bool: + """loose checking if s is a pytables-acceptable expression""" + if not isinstance(s, str): + return False + operations = PyTablesExprVisitor.binary_ops + PyTablesExprVisitor.unary_ops + ("=",) + + # make sure we have an op at least + return any(op in s for op in operations) diff --git a/pandas/core/groupby/grouper 2.py b/pandas/core/groupby/grouper 2.py new file mode 100644 index 0000000000000..764b74f81e7ef --- /dev/null +++ b/pandas/core/groupby/grouper 2.py @@ -0,0 +1,1065 @@ +""" +Provide user facing operators for doing the split part of the +split-apply-combine paradigm. +""" +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + final, +) +import warnings + +import numpy as np + +from pandas._config import using_copy_on_write + +from pandas._libs import lib +from pandas.errors import InvalidIndexError +from pandas.util._decorators import cache_readonly +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.common import ( + is_list_like, + is_scalar, +) +from pandas.core.dtypes.dtypes import CategoricalDtype + +from pandas.core import algorithms +from pandas.core.arrays import ( + Categorical, + ExtensionArray, +) +import pandas.core.common as com +from pandas.core.frame import DataFrame +from pandas.core.groupby import ops +from pandas.core.groupby.categorical import recode_for_groupby +from pandas.core.indexes.api import ( + CategoricalIndex, + Index, + MultiIndex, +) +from pandas.core.series import Series + +from pandas.io.formats.printing import pprint_thing + +if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Iterator, + ) + + from pandas._typing import ( + ArrayLike, + Axis, + NDFrameT, + npt, + ) + + from pandas.core.generic import NDFrame + + +class Grouper: + """ + A Grouper allows the user to specify a groupby instruction for an object. + + This specification will select a column via the key parameter, or if the + level and/or axis parameters are given, a level of the index of the target + object. + + If `axis` and/or `level` are passed as keywords to both `Grouper` and + `groupby`, the values passed to `Grouper` take precedence. + + Parameters + ---------- + key : str, defaults to None + Groupby key, which selects the grouping column of the target. + level : name/number, defaults to None + The level for the target index. + freq : str / frequency object, defaults to None + This will groupby the specified frequency if the target selection + (via key or level) is a datetime-like object. For full specification + of available frequencies, please see `here + `_. + axis : str, int, defaults to 0 + Number/name of the axis. + sort : bool, default to False + Whether to sort the resulting labels. + closed : {'left' or 'right'} + Closed end of interval. Only when `freq` parameter is passed. + label : {'left' or 'right'} + Interval boundary to use for labeling. + Only when `freq` parameter is passed. + convention : {'start', 'end', 'e', 's'} + If grouper is PeriodIndex and `freq` parameter is passed. + + origin : Timestamp or str, default 'start_day' + The timestamp on which to adjust the grouping. The timezone of origin must + match the timezone of the index. + If string, must be one of the following: + + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + + .. 
versionadded:: 1.3.0 + + offset : Timedelta or str, default is None + An offset timedelta added to the origin. + + dropna : bool, default True + If True, and if group keys contain NA values, NA values together with + row/column will be dropped. If False, NA values will also be treated as + the key in groups. + + .. versionadded:: 1.2.0 + + Returns + ------- + Grouper or pandas.api.typing.TimeGrouper + A TimeGrouper is returned if ``freq`` is not ``None``. Otherwise, a Grouper + is returned. + + Examples + -------- + Syntactic sugar for ``df.groupby('A')`` + + >>> df = pd.DataFrame( + ... { + ... "Animal": ["Falcon", "Parrot", "Falcon", "Falcon", "Parrot"], + ... "Speed": [100, 5, 200, 300, 15], + ... } + ... ) + >>> df + Animal Speed + 0 Falcon 100 + 1 Parrot 5 + 2 Falcon 200 + 3 Falcon 300 + 4 Parrot 15 + >>> df.groupby(pd.Grouper(key="Animal")).mean() + Speed + Animal + Falcon 200.0 + Parrot 10.0 + + Specify a resample operation on the column 'Publish date' + + >>> df = pd.DataFrame( + ... { + ... "Publish date": [ + ... pd.Timestamp("2000-01-02"), + ... pd.Timestamp("2000-01-02"), + ... pd.Timestamp("2000-01-09"), + ... pd.Timestamp("2000-01-16") + ... ], + ... "ID": [0, 1, 2, 3], + ... "Price": [10, 20, 30, 40] + ... } + ... ) + >>> df + Publish date ID Price + 0 2000-01-02 0 10 + 1 2000-01-02 1 20 + 2 2000-01-09 2 30 + 3 2000-01-16 3 40 + >>> df.groupby(pd.Grouper(key="Publish date", freq="1W")).mean() + ID Price + Publish date + 2000-01-02 0.5 15.0 + 2000-01-09 2.0 30.0 + 2000-01-16 3.0 40.0 + + If you want to adjust the start of the bins based on a fixed timestamp: + + >>> start, end = '2000-10-01 23:30:00', '2000-10-02 00:30:00' + >>> rng = pd.date_range(start, end, freq='7min') + >>> ts = pd.Series(np.arange(len(rng)) * 3, index=rng) + >>> ts + 2000-10-01 23:30:00 0 + 2000-10-01 23:37:00 3 + 2000-10-01 23:44:00 6 + 2000-10-01 23:51:00 9 + 2000-10-01 23:58:00 12 + 2000-10-02 00:05:00 15 + 2000-10-02 00:12:00 18 + 2000-10-02 00:19:00 21 + 2000-10-02 00:26:00 24 + Freq: 7T, dtype: int64 + + >>> ts.groupby(pd.Grouper(freq='17min')).sum() + 2000-10-01 23:14:00 0 + 2000-10-01 23:31:00 9 + 2000-10-01 23:48:00 21 + 2000-10-02 00:05:00 54 + 2000-10-02 00:22:00 24 + Freq: 17T, dtype: int64 + + >>> ts.groupby(pd.Grouper(freq='17min', origin='epoch')).sum() + 2000-10-01 23:18:00 0 + 2000-10-01 23:35:00 18 + 2000-10-01 23:52:00 27 + 2000-10-02 00:09:00 39 + 2000-10-02 00:26:00 24 + Freq: 17T, dtype: int64 + + >>> ts.groupby(pd.Grouper(freq='17min', origin='2000-01-01')).sum() + 2000-10-01 23:24:00 3 + 2000-10-01 23:41:00 15 + 2000-10-01 23:58:00 45 + 2000-10-02 00:15:00 45 + Freq: 17T, dtype: int64 + + If you want to adjust the start of the bins with an `offset` Timedelta, the two + following lines are equivalent: + + >>> ts.groupby(pd.Grouper(freq='17min', origin='start')).sum() + 2000-10-01 23:30:00 9 + 2000-10-01 23:47:00 21 + 2000-10-02 00:04:00 54 + 2000-10-02 00:21:00 24 + Freq: 17T, dtype: int64 + + >>> ts.groupby(pd.Grouper(freq='17min', offset='23h30min')).sum() + 2000-10-01 23:30:00 9 + 2000-10-01 23:47:00 21 + 2000-10-02 00:04:00 54 + 2000-10-02 00:21:00 24 + Freq: 17T, dtype: int64 + + To replace the use of the deprecated `base` argument, you can now use `offset`, + in this example it is equivalent to have `base=2`: + + >>> ts.groupby(pd.Grouper(freq='17min', offset='2min')).sum() + 2000-10-01 23:16:00 0 + 2000-10-01 23:33:00 9 + 2000-10-01 23:50:00 36 + 2000-10-02 00:07:00 39 + 2000-10-02 00:24:00 24 + Freq: 17T, dtype: int64 + """ + + sort: bool + dropna: bool + _gpr_index: Index 
| None + _grouper: Index | None + + _attributes: tuple[str, ...] = ("key", "level", "freq", "axis", "sort", "dropna") + + def __new__(cls, *args, **kwargs): + if kwargs.get("freq") is not None: + from pandas.core.resample import TimeGrouper + + cls = TimeGrouper + return super().__new__(cls) + + def __init__( + self, + key=None, + level=None, + freq=None, + axis: Axis | lib.NoDefault = lib.no_default, + sort: bool = False, + dropna: bool = True, + ) -> None: + if type(self) is Grouper: + # i.e. not TimeGrouper + if axis is not lib.no_default: + warnings.warn( + "Grouper axis keyword is deprecated and will be removed in a " + "future version. To group on axis=1, use obj.T.groupby(...) " + "instead", + FutureWarning, + stacklevel=find_stack_level(), + ) + else: + axis = 0 + if axis is lib.no_default: + axis = 0 + + self.key = key + self.level = level + self.freq = freq + self.axis = axis + self.sort = sort + self.dropna = dropna + + self._grouper_deprecated = None + self._indexer_deprecated = None + self._obj_deprecated = None + self._gpr_index = None + self.binner = None + self._grouper = None + self._indexer = None + + def _get_grouper( + self, obj: NDFrameT, validate: bool = True + ) -> tuple[ops.BaseGrouper, NDFrameT]: + """ + Parameters + ---------- + obj : Series or DataFrame + validate : bool, default True + if True, validate the grouper + + Returns + ------- + a tuple of grouper, obj (possibly sorted) + """ + obj, _, _ = self._set_grouper(obj) + grouper, _, obj = get_grouper( + obj, + [self.key], + axis=self.axis, + level=self.level, + sort=self.sort, + validate=validate, + dropna=self.dropna, + ) + # Without setting this, subsequent lookups to .groups raise + # error: Incompatible types in assignment (expression has type "BaseGrouper", + # variable has type "None") + self._grouper_deprecated = grouper # type: ignore[assignment] + + return grouper, obj + + @final + def _set_grouper( + self, obj: NDFrame, sort: bool = False, *, gpr_index: Index | None = None + ): + """ + given an object and the specifications, setup the internal grouper + for this particular specification + + Parameters + ---------- + obj : Series or DataFrame + sort : bool, default False + whether the resulting grouper should be sorted + gpr_index : Index or None, default None + + Returns + ------- + NDFrame + Index + np.ndarray[np.intp] | None + """ + assert obj is not None + + indexer = None + + if self.key is not None and self.level is not None: + raise ValueError("The Grouper cannot specify both a key and a level!") + + # Keep self._grouper value before overriding + if self._grouper is None: + # TODO: What are we assuming about subsequent calls? + self._grouper = gpr_index + self._indexer = self._indexer_deprecated + + # the key must be a valid info item + if self.key is not None: + key = self.key + # The 'on' is already defined + if getattr(gpr_index, "name", None) == key and isinstance(obj, Series): + # Sometimes self._grouper will have been resorted while + # obj has not. In this case there is a mismatch when we + # call self._grouper.take(obj.index) so we need to undo the sorting + # before we call _grouper.take. 
+ assert self._grouper is not None + if self._indexer is not None: + reverse_indexer = self._indexer.argsort() + unsorted_ax = self._grouper.take(reverse_indexer) + ax = unsorted_ax.take(obj.index) + else: + ax = self._grouper.take(obj.index) + else: + if key not in obj._info_axis: + raise KeyError(f"The grouper name {key} is not found") + ax = Index(obj[key], name=key) + + else: + ax = obj._get_axis(self.axis) + if self.level is not None: + level = self.level + + # if a level is given it must be a mi level or + # equivalent to the axis name + if isinstance(ax, MultiIndex): + level = ax._get_level_number(level) + ax = Index(ax._get_level_values(level), name=ax.names[level]) + + else: + if level not in (0, ax.name): + raise ValueError(f"The level {level} is not valid") + + # possibly sort + if (self.sort or sort) and not ax.is_monotonic_increasing: + # use stable sort to support first, last, nth + # TODO: why does putting na_position="first" fix datetimelike cases? + indexer = self._indexer_deprecated = ax.array.argsort( + kind="mergesort", na_position="first" + ) + ax = ax.take(indexer) + obj = obj.take(indexer, axis=self.axis) + + # error: Incompatible types in assignment (expression has type + # "NDFrameT", variable has type "None") + self._obj_deprecated = obj # type: ignore[assignment] + self._gpr_index = ax + return obj, ax, indexer + + @final + @property + def ax(self) -> Index: + warnings.warn( + f"{type(self).__name__}.ax is deprecated and will be removed in a " + "future version. Use Resampler.ax instead", + FutureWarning, + stacklevel=find_stack_level(), + ) + index = self._gpr_index + if index is None: + raise ValueError("_set_grouper must be called before ax is accessed") + return index + + @final + @property + def indexer(self): + warnings.warn( + f"{type(self).__name__}.indexer is deprecated and will be removed " + "in a future version. Use Resampler.indexer instead.", + FutureWarning, + stacklevel=find_stack_level(), + ) + return self._indexer_deprecated + + @final + @property + def obj(self): + warnings.warn( + f"{type(self).__name__}.obj is deprecated and will be removed " + "in a future version. Use GroupBy.indexer instead.", + FutureWarning, + stacklevel=find_stack_level(), + ) + return self._obj_deprecated + + @final + @property + def grouper(self): + warnings.warn( + f"{type(self).__name__}.grouper is deprecated and will be removed " + "in a future version. Use GroupBy.grouper instead.", + FutureWarning, + stacklevel=find_stack_level(), + ) + return self._grouper_deprecated + + @final + @property + def groups(self): + warnings.warn( + f"{type(self).__name__}.groups is deprecated and will be removed " + "in a future version. 
Use GroupBy.groups instead.", + FutureWarning, + stacklevel=find_stack_level(), + ) + # error: "None" has no attribute "groups" + return self._grouper_deprecated.groups # type: ignore[attr-defined] + + @final + def __repr__(self) -> str: + attrs_list = ( + f"{attr_name}={repr(getattr(self, attr_name))}" + for attr_name in self._attributes + if getattr(self, attr_name) is not None + ) + attrs = ", ".join(attrs_list) + cls_name = type(self).__name__ + return f"{cls_name}({attrs})" + + +@final +class Grouping: + """ + Holds the grouping information for a single key + + Parameters + ---------- + index : Index + grouper : + obj : DataFrame or Series + name : Label + level : + observed : bool, default False + If we are a Categorical, use the observed values + in_axis : if the Grouping is a column in self.obj and hence among + Groupby.exclusions list + dropna : bool, default True + Whether to drop NA groups. + uniques : Array-like, optional + When specified, will be used for unique values. Enables including empty groups + in the result for a BinGrouper. Must not contain duplicates. + + Attributes + ------- + indices : dict + Mapping of {group -> index_list} + codes : ndarray + Group codes + group_index : Index or None + unique groups + groups : dict + Mapping of {group -> label_list} + """ + + _codes: npt.NDArray[np.signedinteger] | None = None + _group_index: Index | None = None + _all_grouper: Categorical | None + _orig_cats: Index | None + _index: Index + + def __init__( + self, + index: Index, + grouper=None, + obj: NDFrame | None = None, + level=None, + sort: bool = True, + observed: bool = False, + in_axis: bool = False, + dropna: bool = True, + uniques: ArrayLike | None = None, + ) -> None: + self.level = level + self._orig_grouper = grouper + grouping_vector = _convert_grouper(index, grouper) + self._all_grouper = None + self._orig_cats = None + self._index = index + self._sort = sort + self.obj = obj + self._observed = observed + self.in_axis = in_axis + self._dropna = dropna + self._uniques = uniques + + # we have a single grouper which may be a myriad of things, + # some of which are dependent on the passing in level + + ilevel = self._ilevel + if ilevel is not None: + # In extant tests, the new self.grouping_vector matches + # `index.get_level_values(ilevel)` whenever + # mapper is None and isinstance(index, MultiIndex) + if isinstance(index, MultiIndex): + index_level = index.get_level_values(ilevel) + else: + index_level = index + + if grouping_vector is None: + grouping_vector = index_level + else: + mapper = grouping_vector + grouping_vector = index_level.map(mapper) + + # a passed Grouper like, directly get the grouper in the same way + # as single grouper groupby, use the group_info to get codes + elif isinstance(grouping_vector, Grouper): + # get the new grouper; we already have disambiguated + # what key/level refer to exactly, don't need to + # check again as we have by this point converted these + # to an actual value (rather than a pd.Grouper) + assert self.obj is not None # for mypy + newgrouper, newobj = grouping_vector._get_grouper(self.obj, validate=False) + self.obj = newobj + + if isinstance(newgrouper, ops.BinGrouper): + # TODO: can we unwrap this and get a tighter typing + # for self.grouping_vector? + grouping_vector = newgrouper + else: + # ops.BaseGrouper + # TODO: 2023-02-03 no test cases with len(newgrouper.groupings) > 1. + # If that were to occur, would we be throwing out information? 
+ # error: Cannot determine type of "grouping_vector" [has-type] + ng = newgrouper.groupings[0].grouping_vector # type: ignore[has-type] + # use Index instead of ndarray so we can recover the name + grouping_vector = Index(ng, name=newgrouper.result_index.name) + + elif not isinstance( + grouping_vector, (Series, Index, ExtensionArray, np.ndarray) + ): + # no level passed + if getattr(grouping_vector, "ndim", 1) != 1: + t = str(type(grouping_vector)) + raise ValueError(f"Grouper for '{t}' not 1-dimensional") + + grouping_vector = index.map(grouping_vector) + + if not ( + hasattr(grouping_vector, "__len__") + and len(grouping_vector) == len(index) + ): + grper = pprint_thing(grouping_vector) + errmsg = ( + "Grouper result violates len(labels) == " + f"len(data)\nresult: {grper}" + ) + raise AssertionError(errmsg) + + if isinstance(grouping_vector, np.ndarray): + if grouping_vector.dtype.kind in "mM": + # if we have a date/time-like grouper, make sure that we have + # Timestamps like + # TODO 2022-10-08 we only have one test that gets here and + # values are already in nanoseconds in that case. + grouping_vector = Series(grouping_vector).to_numpy() + elif isinstance(getattr(grouping_vector, "dtype", None), CategoricalDtype): + # a passed Categorical + self._orig_cats = grouping_vector.categories + grouping_vector, self._all_grouper = recode_for_groupby( + grouping_vector, sort, observed + ) + + self.grouping_vector = grouping_vector + + def __repr__(self) -> str: + return f"Grouping({self.name})" + + def __iter__(self) -> Iterator: + return iter(self.indices) + + @cache_readonly + def _passed_categorical(self) -> bool: + dtype = getattr(self.grouping_vector, "dtype", None) + return isinstance(dtype, CategoricalDtype) + + @cache_readonly + def name(self) -> Hashable: + ilevel = self._ilevel + if ilevel is not None: + return self._index.names[ilevel] + + if isinstance(self._orig_grouper, (Index, Series)): + return self._orig_grouper.name + + elif isinstance(self.grouping_vector, ops.BaseGrouper): + return self.grouping_vector.result_index.name + + elif isinstance(self.grouping_vector, Index): + return self.grouping_vector.name + + # otherwise we have ndarray or ExtensionArray -> no name + return None + + @cache_readonly + def _ilevel(self) -> int | None: + """ + If necessary, converted index level name to index level position. + """ + level = self.level + if level is None: + return None + if not isinstance(level, int): + index = self._index + if level not in index.names: + raise AssertionError(f"Level {level} not in index") + return index.names.index(level) + return level + + @property + def ngroups(self) -> int: + return len(self.group_index) + + @cache_readonly + def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]: + # we have a list of groupers + if isinstance(self.grouping_vector, ops.BaseGrouper): + return self.grouping_vector.indices + + values = Categorical(self.grouping_vector) + return values._reverse_indexer() + + @property + def codes(self) -> npt.NDArray[np.signedinteger]: + return self._codes_and_uniques[0] + + @cache_readonly + def group_arraylike(self) -> ArrayLike: + """ + Analogous to result_index, but holding an ArrayLike to ensure + we can retain ExtensionDtypes. 
+ """ + if self._all_grouper is not None: + # retain dtype for categories, including unobserved ones + return self.result_index._values + + elif self._passed_categorical: + return self.group_index._values + + return self._codes_and_uniques[1] + + @cache_readonly + def result_index(self) -> Index: + # result_index retains dtype for categories, including unobserved ones, + # which group_index does not + if self._all_grouper is not None: + group_idx = self.group_index + assert isinstance(group_idx, CategoricalIndex) + cats = self._orig_cats + # set_categories is dynamically added + return group_idx.set_categories(cats) # type: ignore[attr-defined] + return self.group_index + + @cache_readonly + def group_index(self) -> Index: + codes, uniques = self._codes_and_uniques + if not self._dropna and self._passed_categorical: + assert isinstance(uniques, Categorical) + if self._sort and (codes == len(uniques)).any(): + # Add NA value on the end when sorting + uniques = Categorical.from_codes( + np.append(uniques.codes, [-1]), uniques.categories, validate=False + ) + elif len(codes) > 0: + # Need to determine proper placement of NA value when not sorting + cat = self.grouping_vector + na_idx = (cat.codes < 0).argmax() + if cat.codes[na_idx] < 0: + # count number of unique codes that comes before the nan value + na_unique_idx = algorithms.nunique_ints(cat.codes[:na_idx]) + new_codes = np.insert(uniques.codes, na_unique_idx, -1) + uniques = Categorical.from_codes( + new_codes, uniques.categories, validate=False + ) + return Index._with_infer(uniques, name=self.name) + + @cache_readonly + def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]: + uniques: ArrayLike + if self._passed_categorical: + # we make a CategoricalIndex out of the cat grouper + # preserving the categories / ordered attributes; + # doesn't (yet - GH#46909) handle dropna=False + cat = self.grouping_vector + categories = cat.categories + + if self._observed: + ucodes = algorithms.unique1d(cat.codes) + ucodes = ucodes[ucodes != -1] + if self._sort: + ucodes = np.sort(ucodes) + else: + ucodes = np.arange(len(categories)) + + uniques = Categorical.from_codes( + codes=ucodes, categories=categories, ordered=cat.ordered, validate=False + ) + + codes = cat.codes + if not self._dropna: + na_mask = codes < 0 + if np.any(na_mask): + if self._sort: + # Replace NA codes with `largest code + 1` + na_code = len(categories) + codes = np.where(na_mask, na_code, codes) + else: + # Insert NA code into the codes based on first appearance + # A negative code must exist, no need to check codes[na_idx] < 0 + na_idx = na_mask.argmax() + # count number of unique codes that comes before the nan value + na_code = algorithms.nunique_ints(codes[:na_idx]) + codes = np.where(codes >= na_code, codes + 1, codes) + codes = np.where(na_mask, na_code, codes) + + if not self._observed: + uniques = uniques.reorder_categories(self._orig_cats) + + return codes, uniques + + elif isinstance(self.grouping_vector, ops.BaseGrouper): + # we have a list of groupers + codes = self.grouping_vector.codes_info + uniques = self.grouping_vector.result_index._values + elif self._uniques is not None: + # GH#50486 Code grouping_vector using _uniques; allows + # including uniques that are not present in grouping_vector. 
+ cat = Categorical(self.grouping_vector, categories=self._uniques) + codes = cat.codes + uniques = self._uniques + else: + # GH35667, replace dropna=False with use_na_sentinel=False + # error: Incompatible types in assignment (expression has type "Union[ + # ndarray[Any, Any], Index]", variable has type "Categorical") + codes, uniques = algorithms.factorize( # type: ignore[assignment] + self.grouping_vector, sort=self._sort, use_na_sentinel=self._dropna + ) + return codes, uniques + + @cache_readonly + def groups(self) -> dict[Hashable, np.ndarray]: + cats = Categorical.from_codes(self.codes, self.group_index, validate=False) + return self._index.groupby(cats) + + +def get_grouper( + obj: NDFrameT, + key=None, + axis: Axis = 0, + level=None, + sort: bool = True, + observed: bool = False, + validate: bool = True, + dropna: bool = True, +) -> tuple[ops.BaseGrouper, frozenset[Hashable], NDFrameT]: + """ + Create and return a BaseGrouper, which is an internal + mapping of how to create the grouper indexers. + This may be composed of multiple Grouping objects, indicating + multiple groupers + + Groupers are ultimately index mappings. They can originate as: + index mappings, keys to columns, functions, or Groupers + + Groupers enable local references to axis,level,sort, while + the passed in axis, level, and sort are 'global'. + + This routine tries to figure out what the passing in references + are and then creates a Grouping for each one, combined into + a BaseGrouper. + + If observed & we have a categorical grouper, only show the observed + values. + + If validate, then check for key/level overlaps. + + """ + group_axis = obj._get_axis(axis) + + # validate that the passed single level is compatible with the passed + # axis of the object + if level is not None: + # TODO: These if-block and else-block are almost same. + # MultiIndex instance check is removable, but it seems that there are + # some processes only for non-MultiIndex in else-block, + # eg. `obj.index.name != level`. We have to consider carefully whether + # these are applicable for MultiIndex. Even if these are applicable, + # we need to check if it makes no side effect to subsequent processes + # on the outside of this condition. + # (GH 17621) + if isinstance(group_axis, MultiIndex): + if is_list_like(level) and len(level) == 1: + level = level[0] + + if key is None and is_scalar(level): + # Get the level values from group_axis + key = group_axis.get_level_values(level) + level = None + + else: + # allow level to be a length-one list-like object + # (e.g., level=[0]) + # GH 13901 + if is_list_like(level): + nlevels = len(level) + if nlevels == 1: + level = level[0] + elif nlevels == 0: + raise ValueError("No group keys passed!") + else: + raise ValueError("multiple levels only valid with MultiIndex") + + if isinstance(level, str): + if obj._get_axis(axis).name != level: + raise ValueError( + f"level name {level} is not the name " + f"of the {obj._get_axis_name(axis)}" + ) + elif level > 0 or level < -1: + raise ValueError("level > 0 or level < -1 only valid with MultiIndex") + + # NOTE: `group_axis` and `group_axis.get_level_values(level)` + # are same in this section. 
+ level = None + key = group_axis + + # a passed-in Grouper, directly convert + if isinstance(key, Grouper): + grouper, obj = key._get_grouper(obj, validate=False) + if key.key is None: + return grouper, frozenset(), obj + else: + return grouper, frozenset({key.key}), obj + + # already have a BaseGrouper, just return it + elif isinstance(key, ops.BaseGrouper): + return key, frozenset(), obj + + if not isinstance(key, list): + keys = [key] + match_axis_length = False + else: + keys = key + match_axis_length = len(keys) == len(group_axis) + + # what are we after, exactly? + any_callable = any(callable(g) or isinstance(g, dict) for g in keys) + any_groupers = any(isinstance(g, (Grouper, Grouping)) for g in keys) + any_arraylike = any( + isinstance(g, (list, tuple, Series, Index, np.ndarray)) for g in keys + ) + + # is this an index replacement? + if ( + not any_callable + and not any_arraylike + and not any_groupers + and match_axis_length + and level is None + ): + if isinstance(obj, DataFrame): + all_in_columns_index = all( + g in obj.columns or g in obj.index.names for g in keys + ) + else: + assert isinstance(obj, Series) + all_in_columns_index = all(g in obj.index.names for g in keys) + + if not all_in_columns_index: + keys = [com.asarray_tuplesafe(keys)] + + if isinstance(level, (tuple, list)): + if key is None: + keys = [None] * len(level) + levels = level + else: + levels = [level] * len(keys) + + groupings: list[Grouping] = [] + exclusions: set[Hashable] = set() + + # if the actual grouper should be obj[key] + def is_in_axis(key) -> bool: + if not _is_label_like(key): + if obj.ndim == 1: + return False + + # items -> .columns for DataFrame, .index for Series + items = obj.axes[-1] + try: + items.get_loc(key) + except (KeyError, TypeError, InvalidIndexError): + # TypeError shows up here if we pass e.g. an Index + return False + + return True + + # if the grouper is obj[name] + def is_in_obj(gpr) -> bool: + if not hasattr(gpr, "name"): + return False + if using_copy_on_write(): + # For the CoW case, we check the references to determine if the + # series is part of the object + try: + obj_gpr_column = obj[gpr.name] + except (KeyError, IndexError, InvalidIndexError): + return False + if isinstance(gpr, Series) and isinstance(obj_gpr_column, Series): + return gpr._mgr.references_same_values( # type: ignore[union-attr] + obj_gpr_column._mgr, 0 # type: ignore[arg-type] + ) + return False + try: + return gpr is obj[gpr.name] + except (KeyError, IndexError, InvalidIndexError): + # IndexError reached in e.g. test_skip_group_keys when we pass + # lambda here + # InvalidIndexError raised on key-types inappropriate for index, + # e.g. 
DatetimeIndex.get_loc(tuple()) + return False + + for gpr, level in zip(keys, levels): + if is_in_obj(gpr): # df.groupby(df['name']) + in_axis = True + exclusions.add(gpr.name) + + elif is_in_axis(gpr): # df.groupby('name') + if obj.ndim != 1 and gpr in obj: + if validate: + obj._check_label_or_level_ambiguity(gpr, axis=axis) + in_axis, name, gpr = True, gpr, obj[gpr] + if gpr.ndim != 1: + # non-unique columns; raise here to get the name in the + # exception message + raise ValueError(f"Grouper for '{name}' not 1-dimensional") + exclusions.add(name) + elif obj._is_level_reference(gpr, axis=axis): + in_axis, level, gpr = False, gpr, None + else: + raise KeyError(gpr) + elif isinstance(gpr, Grouper) and gpr.key is not None: + # Add key to exclusions + exclusions.add(gpr.key) + in_axis = True + else: + in_axis = False + + # create the Grouping + # allow us to passing the actual Grouping as the gpr + ping = ( + Grouping( + group_axis, + gpr, + obj=obj, + level=level, + sort=sort, + observed=observed, + in_axis=in_axis, + dropna=dropna, + ) + if not isinstance(gpr, Grouping) + else gpr + ) + + groupings.append(ping) + + if len(groupings) == 0 and len(obj): + raise ValueError("No group keys passed!") + if len(groupings) == 0: + groupings.append(Grouping(Index([], dtype="int"), np.array([], dtype=np.intp))) + + # create the internals grouper + grouper = ops.BaseGrouper(group_axis, groupings, sort=sort, dropna=dropna) + return grouper, frozenset(exclusions), obj + + +def _is_label_like(val) -> bool: + return isinstance(val, (str, tuple)) or (val is not None and is_scalar(val)) + + +def _convert_grouper(axis: Index, grouper): + if isinstance(grouper, dict): + return grouper.get + elif isinstance(grouper, Series): + if grouper.index.equals(axis): + return grouper._values + else: + return grouper.reindex(axis)._values + elif isinstance(grouper, MultiIndex): + return grouper._values + elif isinstance(grouper, (list, tuple, Index, Categorical, np.ndarray)): + if len(grouper) != len(axis): + raise ValueError("Grouper and axis must be same length") + + if isinstance(grouper, (list, tuple)): + grouper = com.asarray_tuplesafe(grouper) + return grouper + else: + return grouper diff --git a/pandas/core/indexes/range 2.py b/pandas/core/indexes/range 2.py new file mode 100644 index 0000000000000..ca415d2089ecf --- /dev/null +++ b/pandas/core/indexes/range 2.py @@ -0,0 +1,1107 @@ +from __future__ import annotations + +from collections.abc import ( + Hashable, + Iterator, +) +from datetime import timedelta +import operator +from sys import getsizeof +from typing import ( + TYPE_CHECKING, + Any, + Callable, + cast, +) + +import numpy as np + +from pandas._libs import ( + index as libindex, + lib, +) +from pandas._libs.algos import unique_deltas +from pandas._libs.lib import no_default +from pandas.compat.numpy import function as nv +from pandas.util._decorators import ( + cache_readonly, + doc, +) + +from pandas.core.dtypes.common import ( + ensure_platform_int, + ensure_python_int, + is_float, + is_integer, + is_scalar, + is_signed_integer_dtype, +) +from pandas.core.dtypes.generic import ABCTimedeltaIndex + +from pandas.core import ops +import pandas.core.common as com +from pandas.core.construction import extract_array +import pandas.core.indexes.base as ibase +from pandas.core.indexes.base import ( + Index, + maybe_extract_name, +) +from pandas.core.ops.common import unpack_zerodim_and_defer + +if TYPE_CHECKING: + from pandas._typing import ( + Dtype, + NaPosition, + Self, + npt, + ) +_empty_range = range(0) 
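As a rough illustration of the memory saving described in the ``RangeIndex`` docstring below (sizes are indicative and platform-dependent; the sketch only assumes the public ``nbytes`` attribute):

    import numpy as np
    import pandas as pd

    # A RangeIndex stores only start/stop/step, so its footprint does not
    # grow with length, unlike an Index backed by a materialized int64 array.
    lazy = pd.RangeIndex(1_000_000)
    materialized = pd.Index(np.arange(1_000_000, dtype=np.int64))
    assert lazy.nbytes < materialized.nbytes
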
+_dtype_int64 = np.dtype(np.int64) + + +class RangeIndex(Index): + """ + Immutable Index implementing a monotonic integer range. + + RangeIndex is a memory-saving special case of an Index limited to representing + monotonic ranges with a 64-bit dtype. Using RangeIndex may in some instances + improve computing speed. + + This is the default index type used + by DataFrame and Series when no explicit index is provided by the user. + + Parameters + ---------- + start : int (default: 0), range, or other RangeIndex instance + If int and "stop" is not given, interpreted as "stop" instead. + stop : int (default: 0) + step : int (default: 1) + dtype : np.int64 + Unused, accepted for homogeneity with other index types. + copy : bool, default False + Unused, accepted for homogeneity with other index types. + name : object, optional + Name to be stored in the index. + + Attributes + ---------- + start + stop + step + + Methods + ------- + from_range + + See Also + -------- + Index : The base pandas Index type. + + Examples + -------- + >>> list(pd.RangeIndex(5)) + [0, 1, 2, 3, 4] + + >>> list(pd.RangeIndex(-2, 4)) + [-2, -1, 0, 1, 2, 3] + + >>> list(pd.RangeIndex(0, 10, 2)) + [0, 2, 4, 6, 8] + + >>> list(pd.RangeIndex(2, -10, -3)) + [2, -1, -4, -7] + + >>> list(pd.RangeIndex(0)) + [] + + >>> list(pd.RangeIndex(1, 0)) + [] + """ + + _typ = "rangeindex" + _dtype_validation_metadata = (is_signed_integer_dtype, "signed integer") + _range: range + _values: np.ndarray + + @property + def _engine_type(self) -> type[libindex.Int64Engine]: + return libindex.Int64Engine + + # -------------------------------------------------------------------- + # Constructors + + def __new__( + cls, + start=None, + stop=None, + step=None, + dtype: Dtype | None = None, + copy: bool = False, + name: Hashable | None = None, + ) -> RangeIndex: + cls._validate_dtype(dtype) + name = maybe_extract_name(name, start, cls) + + # RangeIndex + if isinstance(start, RangeIndex): + return start.copy(name=name) + elif isinstance(start, range): + return cls._simple_new(start, name=name) + + # validate the arguments + if com.all_none(start, stop, step): + raise TypeError("RangeIndex(...) must be called with integers") + + start = ensure_python_int(start) if start is not None else 0 + + if stop is None: + start, stop = 0, start + else: + stop = ensure_python_int(stop) + + step = ensure_python_int(step) if step is not None else 1 + if step == 0: + raise ValueError("Step must not be zero") + + rng = range(start, stop, step) + return cls._simple_new(rng, name=name) + + @classmethod + def from_range(cls, data: range, name=None, dtype: Dtype | None = None) -> Self: + """ + Create :class:`pandas.RangeIndex` from a ``range`` object. + + Returns + ------- + RangeIndex + + Examples + -------- + >>> pd.RangeIndex.from_range(range(5)) + RangeIndex(start=0, stop=5, step=1) + + >>> pd.RangeIndex.from_range(range(2, -10, -3)) + RangeIndex(start=2, stop=-10, step=-3) + """ + if not isinstance(data, range): + raise TypeError( + f"{cls.__name__}(...) 
must be called with object coercible to a " + f"range, {repr(data)} was passed" + ) + cls._validate_dtype(dtype) + return cls._simple_new(data, name=name) + + # error: Argument 1 of "_simple_new" is incompatible with supertype "Index"; + # supertype defines the argument type as + # "Union[ExtensionArray, ndarray[Any, Any]]" [override] + @classmethod + def _simple_new( # type: ignore[override] + cls, values: range, name: Hashable | None = None + ) -> Self: + result = object.__new__(cls) + + assert isinstance(values, range) + + result._range = values + result._name = name + result._cache = {} + result._reset_identity() + result._references = None + return result + + @classmethod + def _validate_dtype(cls, dtype: Dtype | None) -> None: + if dtype is None: + return + + validation_func, expected = cls._dtype_validation_metadata + if not validation_func(dtype): + raise ValueError( + f"Incorrect `dtype` passed: expected {expected}, received {dtype}" + ) + + # -------------------------------------------------------------------- + + # error: Return type "Type[Index]" of "_constructor" incompatible with return + # type "Type[RangeIndex]" in supertype "Index" + @cache_readonly + def _constructor(self) -> type[Index]: # type: ignore[override] + """return the class to use for construction""" + return Index + + # error: Signature of "_data" incompatible with supertype "Index" + @cache_readonly + def _data(self) -> np.ndarray: # type: ignore[override] + """ + An int array that for performance reasons is created only when needed. + + The constructed array is saved in ``_cache``. + """ + return np.arange(self.start, self.stop, self.step, dtype=np.int64) + + def _get_data_as_items(self): + """return a list of tuples of start, stop, step""" + rng = self._range + return [("start", rng.start), ("stop", rng.stop), ("step", rng.step)] + + def __reduce__(self): + d = {"name": self._name} + d.update(dict(self._get_data_as_items())) + return ibase._new_Index, (type(self), d), None + + # -------------------------------------------------------------------- + # Rendering Methods + + def _format_attrs(self): + """ + Return a list of tuples of the (attr, formatted_value) + """ + attrs = self._get_data_as_items() + if self._name is not None: + attrs.append(("name", ibase.default_pprint(self._name))) + return attrs + + def _format_data(self, name=None): + # we are formatting thru the attributes + return None + + def _format_with_header(self, header: list[str], na_rep: str) -> list[str]: + # Equivalent to Index implementation, but faster + if not len(self._range): + return header + first_val_str = str(self._range[0]) + last_val_str = str(self._range[-1]) + max_length = max(len(first_val_str), len(last_val_str)) + + return header + [f"{x:<{max_length}}" for x in self._range] + + # -------------------------------------------------------------------- + + @property + def start(self) -> int: + """ + The value of the `start` parameter (``0`` if this was not supplied). + + Examples + -------- + >>> idx = pd.RangeIndex(5) + >>> idx.start + 0 + + >>> idx = pd.RangeIndex(2, -10, -3) + >>> idx.start + 2 + """ + # GH 25710 + return self._range.start + + @property + def stop(self) -> int: + """ + The value of the `stop` parameter. + + Examples + -------- + >>> idx = pd.RangeIndex(5) + >>> idx.stop + 5 + + >>> idx = pd.RangeIndex(2, -10, -3) + >>> idx.stop + -10 + """ + return self._range.stop + + @property + def step(self) -> int: + """ + The value of the `step` parameter (``1`` if this was not supplied). 
+ + Examples + -------- + >>> idx = pd.RangeIndex(5) + >>> idx.step + 1 + + >>> idx = pd.RangeIndex(2, -10, -3) + >>> idx.step + -3 + + Even if :class:`pandas.RangeIndex` is empty, ``step`` is still ``1`` if + not supplied. + + >>> idx = pd.RangeIndex(1, 0) + >>> idx.step + 1 + """ + # GH 25710 + return self._range.step + + @cache_readonly + def nbytes(self) -> int: + """ + Return the number of bytes in the underlying data. + """ + rng = self._range + return getsizeof(rng) + sum( + getsizeof(getattr(rng, attr_name)) + for attr_name in ["start", "stop", "step"] + ) + + def memory_usage(self, deep: bool = False) -> int: + """ + Memory usage of my values + + Parameters + ---------- + deep : bool + Introspect the data deeply, interrogate + `object` dtypes for system-level memory consumption + + Returns + ------- + bytes used + + Notes + ----- + Memory usage does not include memory consumed by elements that + are not components of the array if deep=False + + See Also + -------- + numpy.ndarray.nbytes + """ + return self.nbytes + + @property + def dtype(self) -> np.dtype: + return _dtype_int64 + + @property + def is_unique(self) -> bool: + """return if the index has unique values""" + return True + + @cache_readonly + def is_monotonic_increasing(self) -> bool: + return self._range.step > 0 or len(self) <= 1 + + @cache_readonly + def is_monotonic_decreasing(self) -> bool: + return self._range.step < 0 or len(self) <= 1 + + def __contains__(self, key: Any) -> bool: + hash(key) + try: + key = ensure_python_int(key) + except TypeError: + return False + return key in self._range + + @property + def inferred_type(self) -> str: + return "integer" + + # -------------------------------------------------------------------- + # Indexing Methods + + @doc(Index.get_loc) + def get_loc(self, key): + if is_integer(key) or (is_float(key) and key.is_integer()): + new_key = int(key) + try: + return self._range.index(new_key) + except ValueError as err: + raise KeyError(key) from err + if isinstance(key, Hashable): + raise KeyError(key) + self._check_indexing_error(key) + raise KeyError(key) + + def _get_indexer( + self, + target: Index, + method: str | None = None, + limit: int | None = None, + tolerance=None, + ) -> npt.NDArray[np.intp]: + if com.any_not_none(method, tolerance, limit): + return super()._get_indexer( + target, method=method, tolerance=tolerance, limit=limit + ) + + if self.step > 0: + start, stop, step = self.start, self.stop, self.step + else: + # GH 28678: work on reversed range for simplicity + reverse = self._range[::-1] + start, stop, step = reverse.start, reverse.stop, reverse.step + + target_array = np.asarray(target) + locs = target_array - start + valid = (locs % step == 0) & (locs >= 0) & (target_array < stop) + locs[~valid] = -1 + locs[valid] = locs[valid] / step + + if step != self.step: + # We reversed this range: transform to original locs + locs[valid] = len(self) - 1 - locs[valid] + return ensure_platform_int(locs) + + @cache_readonly + def _should_fallback_to_positional(self) -> bool: + """ + Should an integer key be treated as positional? 
+ """ + return False + + # -------------------------------------------------------------------- + + def tolist(self) -> list[int]: + return list(self._range) + + @doc(Index.__iter__) + def __iter__(self) -> Iterator[int]: + yield from self._range + + @doc(Index._shallow_copy) + def _shallow_copy(self, values, name: Hashable = no_default): + name = self._name if name is no_default else name + + if values.dtype.kind == "f": + return Index(values, name=name, dtype=np.float64) + # GH 46675 & 43885: If values is equally spaced, return a + # more memory-compact RangeIndex instead of Index with 64-bit dtype + unique_diffs = unique_deltas(values) + if len(unique_diffs) == 1 and unique_diffs[0] != 0: + diff = unique_diffs[0] + new_range = range(values[0], values[-1] + diff, diff) + return type(self)._simple_new(new_range, name=name) + else: + return self._constructor._simple_new(values, name=name) + + def _view(self) -> Self: + result = type(self)._simple_new(self._range, name=self._name) + result._cache = self._cache + return result + + @doc(Index.copy) + def copy(self, name: Hashable | None = None, deep: bool = False) -> Self: + name = self._validate_names(name=name, deep=deep)[0] + new_index = self._rename(name=name) + return new_index + + def _minmax(self, meth: str): + no_steps = len(self) - 1 + if no_steps == -1: + return np.nan + elif (meth == "min" and self.step > 0) or (meth == "max" and self.step < 0): + return self.start + + return self.start + self.step * no_steps + + def min(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: + """The minimum value of the RangeIndex""" + nv.validate_minmax_axis(axis) + nv.validate_min(args, kwargs) + return self._minmax("min") + + def max(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: + """The maximum value of the RangeIndex""" + nv.validate_minmax_axis(axis) + nv.validate_max(args, kwargs) + return self._minmax("max") + + def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]: + """ + Returns the indices that would sort the index and its + underlying data. + + Returns + ------- + np.ndarray[np.intp] + + See Also + -------- + numpy.ndarray.argsort + """ + ascending = kwargs.pop("ascending", True) # EA compat + kwargs.pop("kind", None) # e.g. "mergesort" is irrelevant + nv.validate_argsort(args, kwargs) + + if self._range.step > 0: + result = np.arange(len(self), dtype=np.intp) + else: + result = np.arange(len(self) - 1, -1, -1, dtype=np.intp) + + if not ascending: + result = result[::-1] + return result + + def factorize( + self, + sort: bool = False, + use_na_sentinel: bool = True, + ) -> tuple[npt.NDArray[np.intp], RangeIndex]: + codes = np.arange(len(self), dtype=np.intp) + uniques = self + if sort and self.step < 0: + codes = codes[::-1] + uniques = uniques[::-1] + return codes, uniques + + def equals(self, other: object) -> bool: + """ + Determines if two Index objects contain the same elements. 
+ """ + if isinstance(other, RangeIndex): + return self._range == other._range + return super().equals(other) + + def sort_values( + self, + return_indexer: bool = False, + ascending: bool = True, + na_position: NaPosition = "last", + key: Callable | None = None, + ): + if key is not None: + return super().sort_values( + return_indexer=return_indexer, + ascending=ascending, + na_position=na_position, + key=key, + ) + else: + sorted_index = self + inverse_indexer = False + if ascending: + if self.step < 0: + sorted_index = self[::-1] + inverse_indexer = True + else: + if self.step > 0: + sorted_index = self[::-1] + inverse_indexer = True + + if return_indexer: + if inverse_indexer: + rng = range(len(self) - 1, -1, -1) + else: + rng = range(len(self)) + return sorted_index, RangeIndex(rng) + else: + return sorted_index + + # -------------------------------------------------------------------- + # Set Operations + + def _intersection(self, other: Index, sort: bool = False): + # caller is responsible for checking self and other are both non-empty + + if not isinstance(other, RangeIndex): + return super()._intersection(other, sort=sort) + + first = self._range[::-1] if self.step < 0 else self._range + second = other._range[::-1] if other.step < 0 else other._range + + # check whether intervals intersect + # deals with in- and decreasing ranges + int_low = max(first.start, second.start) + int_high = min(first.stop, second.stop) + if int_high <= int_low: + return self._simple_new(_empty_range) + + # Method hint: linear Diophantine equation + # solve intersection problem + # performance hint: for identical step sizes, could use + # cheaper alternative + gcd, s, _ = self._extended_gcd(first.step, second.step) + + # check whether element sets intersect + if (first.start - second.start) % gcd: + return self._simple_new(_empty_range) + + # calculate parameters for the RangeIndex describing the + # intersection disregarding the lower bounds + tmp_start = first.start + (second.start - first.start) * first.step // gcd * s + new_step = first.step * second.step // gcd + new_range = range(tmp_start, int_high, new_step) + new_index = self._simple_new(new_range) + + # adjust index to limiting interval + new_start = new_index._min_fitting_element(int_low) + new_range = range(new_start, new_index.stop, new_index.step) + new_index = self._simple_new(new_range) + + if (self.step < 0 and other.step < 0) is not (new_index.step < 0): + new_index = new_index[::-1] + + if sort is None: + new_index = new_index.sort_values() + + return new_index + + def _min_fitting_element(self, lower_limit: int) -> int: + """Returns the smallest element greater than or equal to the limit""" + no_steps = -(-(lower_limit - self.start) // abs(self.step)) + return self.start + abs(self.step) * no_steps + + def _extended_gcd(self, a: int, b: int) -> tuple[int, int, int]: + """ + Extended Euclidean algorithms to solve Bezout's identity: + a*x + b*y = gcd(x, y) + Finds one particular solution for x, y: s, t + Returns: gcd, s, t + """ + s, old_s = 0, 1 + t, old_t = 1, 0 + r, old_r = b, a + while r: + quotient = old_r // r + old_r, r = r, old_r - quotient * r + old_s, s = s, old_s - quotient * s + old_t, t = t, old_t - quotient * t + return old_r, old_s, old_t + + def _range_in_self(self, other: range) -> bool: + """Check if other range is contained in self""" + # https://stackoverflow.com/a/32481015 + if not other: + return True + if not self._range: + return False + if len(other) > 1 and other.step % self._range.step: + return False + return 
other.start in self._range and other[-1] in self._range + + def _union(self, other: Index, sort: bool | None): + """ + Form the union of two Index objects and sorts if possible + + Parameters + ---------- + other : Index or array-like + + sort : bool or None, default None + Whether to sort (monotonically increasing) the resulting index. + ``sort=None|True`` returns a ``RangeIndex`` if possible or a sorted + ``Index`` with a int64 dtype if not. + ``sort=False`` can return a ``RangeIndex`` if self is monotonically + increasing and other is fully contained in self. Otherwise, returns + an unsorted ``Index`` with an int64 dtype. + + Returns + ------- + union : Index + """ + if isinstance(other, RangeIndex): + if sort in (None, True) or ( + sort is False and self.step > 0 and self._range_in_self(other._range) + ): + # GH 47557: Can still return a RangeIndex + # if other range in self and sort=False + start_s, step_s = self.start, self.step + end_s = self.start + self.step * (len(self) - 1) + start_o, step_o = other.start, other.step + end_o = other.start + other.step * (len(other) - 1) + if self.step < 0: + start_s, step_s, end_s = end_s, -step_s, start_s + if other.step < 0: + start_o, step_o, end_o = end_o, -step_o, start_o + if len(self) == 1 and len(other) == 1: + step_s = step_o = abs(self.start - other.start) + elif len(self) == 1: + step_s = step_o + elif len(other) == 1: + step_o = step_s + start_r = min(start_s, start_o) + end_r = max(end_s, end_o) + if step_o == step_s: + if ( + (start_s - start_o) % step_s == 0 + and (start_s - end_o) <= step_s + and (start_o - end_s) <= step_s + ): + return type(self)(start_r, end_r + step_s, step_s) + if ( + (step_s % 2 == 0) + and (abs(start_s - start_o) == step_s / 2) + and (abs(end_s - end_o) == step_s / 2) + ): + # e.g. range(0, 10, 2) and range(1, 11, 2) + # but not range(0, 20, 4) and range(1, 21, 4) GH#44019 + return type(self)(start_r, end_r + step_s / 2, step_s / 2) + + elif step_o % step_s == 0: + if ( + (start_o - start_s) % step_s == 0 + and (start_o + step_s >= start_s) + and (end_o - step_s <= end_s) + ): + return type(self)(start_r, end_r + step_s, step_s) + elif step_s % step_o == 0: + if ( + (start_s - start_o) % step_o == 0 + and (start_s + step_o >= start_o) + and (end_s - step_o <= end_o) + ): + return type(self)(start_r, end_r + step_o, step_o) + + return super()._union(other, sort=sort) + + def _difference(self, other, sort=None): + # optimized set operation if we have another RangeIndex + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other, result_name = self._convert_can_do_setop(other) + + if not isinstance(other, RangeIndex): + return super()._difference(other, sort=sort) + + if sort is not False and self.step < 0: + return self[::-1]._difference(other) + + res_name = ops.get_op_result_name(self, other) + + first = self._range[::-1] if self.step < 0 else self._range + overlap = self.intersection(other) + if overlap.step < 0: + overlap = overlap[::-1] + + if len(overlap) == 0: + return self.rename(name=res_name) + if len(overlap) == len(self): + return self[:0].rename(res_name) + + # overlap.step will always be a multiple of self.step (see _intersection) + + if len(overlap) == 1: + if overlap[0] == self[0]: + return self[1:] + + elif overlap[0] == self[-1]: + return self[:-1] + + elif len(self) == 3 and overlap[0] == self[1]: + return self[::2] + + else: + return super()._difference(other, sort=sort) + + elif len(overlap) == 2 and overlap[0] == first[0] and overlap[-1] == first[-1]: + # e.g. 
range(-8, 20, 7) and range(13, -9, -3) + return self[1:-1] + + if overlap.step == first.step: + if overlap[0] == first.start: + # The difference is everything after the intersection + new_rng = range(overlap[-1] + first.step, first.stop, first.step) + elif overlap[-1] == first[-1]: + # The difference is everything before the intersection + new_rng = range(first.start, overlap[0], first.step) + elif overlap._range == first[1:-1]: + # e.g. range(4) and range(1, 3) + step = len(first) - 1 + new_rng = first[::step] + else: + # The difference is not range-like + # e.g. range(1, 10, 1) and range(3, 7, 1) + return super()._difference(other, sort=sort) + + else: + # We must have len(self) > 1, bc we ruled out above + # len(overlap) == 0 and len(overlap) == len(self) + assert len(self) > 1 + + if overlap.step == first.step * 2: + if overlap[0] == first[0] and overlap[-1] in (first[-1], first[-2]): + # e.g. range(1, 10, 1) and range(1, 10, 2) + new_rng = first[1::2] + + elif overlap[0] == first[1] and overlap[-1] in (first[-1], first[-2]): + # e.g. range(1, 10, 1) and range(2, 10, 2) + new_rng = first[::2] + + else: + # We can get here with e.g. range(20) and range(0, 10, 2) + return super()._difference(other, sort=sort) + + else: + # e.g. range(10) and range(0, 10, 3) + return super()._difference(other, sort=sort) + + new_index = type(self)._simple_new(new_rng, name=res_name) + if first is not self._range: + new_index = new_index[::-1] + + return new_index + + def symmetric_difference( + self, other, result_name: Hashable | None = None, sort=None + ): + if not isinstance(other, RangeIndex) or sort is not None: + return super().symmetric_difference(other, result_name, sort) + + left = self.difference(other) + right = other.difference(self) + result = left.union(right) + + if result_name is not None: + result = result.rename(result_name) + return result + + # -------------------------------------------------------------------- + + # error: Return type "Index" of "delete" incompatible with return type + # "RangeIndex" in supertype "Index" + def delete(self, loc) -> Index: # type: ignore[override] + # In some cases we can retain RangeIndex, see also + # DatetimeTimedeltaMixin._get_delete_Freq + if is_integer(loc): + if loc in (0, -len(self)): + return self[1:] + if loc in (-1, len(self) - 1): + return self[:-1] + if len(self) == 3 and loc in (1, -2): + return self[::2] + + elif lib.is_list_like(loc): + slc = lib.maybe_indices_to_slice(np.asarray(loc, dtype=np.intp), len(self)) + + if isinstance(slc, slice): + # defer to RangeIndex._difference, which is optimized to return + # a RangeIndex whenever possible + other = self[slc] + return self.difference(other, sort=False) + + return super().delete(loc) + + def insert(self, loc: int, item) -> Index: + if len(self) and (is_integer(item) or is_float(item)): + # We can retain RangeIndex is inserting at the beginning or end, + # or right in the middle. + rng = self._range + if loc == 0 and item == self[0] - self.step: + new_rng = range(rng.start - rng.step, rng.stop, rng.step) + return type(self)._simple_new(new_rng, name=self._name) + + elif loc == len(self) and item == self[-1] + self.step: + new_rng = range(rng.start, rng.stop + rng.step, rng.step) + return type(self)._simple_new(new_rng, name=self._name) + + elif len(self) == 2 and item == self[0] + self.step / 2: + # e.g. 
inserting 1 into [0, 2] + step = int(self.step / 2) + new_rng = range(self.start, self.stop, step) + return type(self)._simple_new(new_rng, name=self._name) + + return super().insert(loc, item) + + def _concat(self, indexes: list[Index], name: Hashable) -> Index: + """ + Overriding parent method for the case of all RangeIndex instances. + + When all members of "indexes" are of type RangeIndex: result will be + RangeIndex if possible, Index with a int64 dtype otherwise. E.g.: + indexes = [RangeIndex(3), RangeIndex(3, 6)] -> RangeIndex(6) + indexes = [RangeIndex(3), RangeIndex(4, 6)] -> Index([0,1,2,4,5], dtype='int64') + """ + if not all(isinstance(x, RangeIndex) for x in indexes): + return super()._concat(indexes, name) + + elif len(indexes) == 1: + return indexes[0] + + rng_indexes = cast(list[RangeIndex], indexes) + + start = step = next_ = None + + # Filter the empty indexes + non_empty_indexes = [obj for obj in rng_indexes if len(obj)] + + for obj in non_empty_indexes: + rng = obj._range + + if start is None: + # This is set by the first non-empty index + start = rng.start + if step is None and len(rng) > 1: + step = rng.step + elif step is None: + # First non-empty index had only one element + if rng.start == start: + values = np.concatenate([x._values for x in rng_indexes]) + result = self._constructor(values) + return result.rename(name) + + step = rng.start - start + + non_consecutive = (step != rng.step and len(rng) > 1) or ( + next_ is not None and rng.start != next_ + ) + if non_consecutive: + result = self._constructor( + np.concatenate([x._values for x in rng_indexes]) + ) + return result.rename(name) + + if step is not None: + next_ = rng[-1] + step + + if non_empty_indexes: + # Get the stop value from "next" or alternatively + # from the last non-empty index + stop = non_empty_indexes[-1].stop if next_ is None else next_ + return RangeIndex(start, stop, step).rename(name) + + # Here all "indexes" had 0 length, i.e. were empty. + # In this case return an empty range index. + return RangeIndex(0, 0).rename(name) + + def __len__(self) -> int: + """ + return the length of the RangeIndex + """ + return len(self._range) + + @property + def size(self) -> int: + return len(self) + + def __getitem__(self, key): + """ + Conserve RangeIndex type for scalar and slice keys. + """ + if isinstance(key, slice): + return self._getitem_slice(key) + elif is_integer(key): + new_key = int(key) + try: + return self._range[new_key] + except IndexError as err: + raise IndexError( + f"index {key} is out of bounds for axis 0 with size {len(self)}" + ) from err + elif is_scalar(key): + raise IndexError( + "only integers, slices (`:`), " + "ellipsis (`...`), numpy.newaxis (`None`) " + "and integer or boolean " + "arrays are valid indices" + ) + return super().__getitem__(key) + + def _getitem_slice(self, slobj: slice) -> Self: + """ + Fastpath for __getitem__ when we know we have a slice. 
+ """ + res = self._range[slobj] + return type(self)._simple_new(res, name=self._name) + + @unpack_zerodim_and_defer("__floordiv__") + def __floordiv__(self, other): + if is_integer(other) and other != 0: + if len(self) == 0 or self.start % other == 0 and self.step % other == 0: + start = self.start // other + step = self.step // other + stop = start + len(self) * step + new_range = range(start, stop, step or 1) + return self._simple_new(new_range, name=self._name) + if len(self) == 1: + start = self.start // other + new_range = range(start, start + 1, 1) + return self._simple_new(new_range, name=self._name) + + return super().__floordiv__(other) + + # -------------------------------------------------------------------- + # Reductions + + def all(self, *args, **kwargs) -> bool: + return 0 not in self._range + + def any(self, *args, **kwargs) -> bool: + return any(self._range) + + # -------------------------------------------------------------------- + + def _cmp_method(self, other, op): + if isinstance(other, RangeIndex) and self._range == other._range: + # Both are immutable so if ._range attr. are equal, shortcut is possible + return super()._cmp_method(self, op) + return super()._cmp_method(other, op) + + def _arith_method(self, other, op): + """ + Parameters + ---------- + other : Any + op : callable that accepts 2 params + perform the binary op + """ + + if isinstance(other, ABCTimedeltaIndex): + # Defer to TimedeltaIndex implementation + return NotImplemented + elif isinstance(other, (timedelta, np.timedelta64)): + # GH#19333 is_integer evaluated True on timedelta64, + # so we need to catch these explicitly + return super()._arith_method(other, op) + elif lib.is_np_dtype(getattr(other, "dtype", None), "m"): + # Must be an np.ndarray; GH#22390 + return super()._arith_method(other, op) + + if op in [ + operator.pow, + ops.rpow, + operator.mod, + ops.rmod, + operator.floordiv, + ops.rfloordiv, + divmod, + ops.rdivmod, + ]: + return super()._arith_method(other, op) + + step: Callable | None = None + if op in [operator.mul, ops.rmul, operator.truediv, ops.rtruediv]: + step = op + + # TODO: if other is a RangeIndex we may have more efficient options + right = extract_array(other, extract_numpy=True, extract_range=True) + left = self + + try: + # apply if we have an override + if step: + with np.errstate(all="ignore"): + rstep = step(left.step, right) + + # we don't have a representable op + # so return a base index + if not is_integer(rstep) or not rstep: + raise ValueError + + # GH#53255 + else: + rstep = -left.step if op == ops.rsub else left.step + + with np.errstate(all="ignore"): + rstart = op(left.start, right) + rstop = op(left.stop, right) + + res_name = ops.get_op_result_name(self, other) + result = type(self)(rstart, rstop, rstep, name=res_name) + + # for compat with numpy / Index with int64 dtype + # even if we can represent as a RangeIndex, return + # as a float64 Index if we have float-like descriptors + if not all(is_integer(x) for x in [rstart, rstop, rstep]): + result = result.astype("float64") + + return result + + except (ValueError, TypeError, ZeroDivisionError): + # test_arithmetic_explicit_conversions + return super()._arith_method(other, op) diff --git a/pandas/core/reshape/merge 2.py b/pandas/core/reshape/merge 2.py new file mode 100644 index 0000000000000..a904f4d9fbe13 --- /dev/null +++ b/pandas/core/reshape/merge 2.py @@ -0,0 +1,2650 @@ +""" +SQL-style merge routines +""" +from __future__ import annotations + +from collections.abc import ( + Hashable, + Sequence, 
+) +import datetime +from functools import partial +import string +from typing import ( + TYPE_CHECKING, + Literal, + cast, + final, +) +import uuid +import warnings + +import numpy as np + +from pandas._libs import ( + Timedelta, + hashtable as libhashtable, + join as libjoin, + lib, +) +from pandas._libs.lib import is_range_indexer +from pandas._typing import ( + AnyArrayLike, + ArrayLike, + IndexLabel, + JoinHow, + MergeHow, + Shape, + Suffixes, + npt, +) +from pandas.errors import MergeError +from pandas.util._decorators import ( + Appender, + Substitution, + cache_readonly, +) +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.cast import find_common_type +from pandas.core.dtypes.common import ( + ensure_int64, + ensure_object, + is_bool, + is_bool_dtype, + is_float_dtype, + is_integer, + is_integer_dtype, + is_list_like, + is_number, + is_numeric_dtype, + is_object_dtype, + needs_i8_conversion, +) +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) +from pandas.core.dtypes.missing import ( + isna, + na_value_for_dtype, +) + +from pandas import ( + ArrowDtype, + Categorical, + Index, + MultiIndex, + Series, +) +import pandas.core.algorithms as algos +from pandas.core.arrays import ( + ArrowExtensionArray, + BaseMaskedArray, + ExtensionArray, +) +from pandas.core.arrays._mixins import NDArrayBackedExtensionArray +import pandas.core.common as com +from pandas.core.construction import ( + ensure_wrapped_if_datetimelike, + extract_array, +) +from pandas.core.frame import _merge_doc +from pandas.core.indexes.api import default_index +from pandas.core.sorting import is_int64_overflow_possible + +if TYPE_CHECKING: + from pandas import DataFrame + from pandas.core import groupby + from pandas.core.arrays import DatetimeArray + +_factorizers = { + np.int64: libhashtable.Int64Factorizer, + np.longlong: libhashtable.Int64Factorizer, + np.int32: libhashtable.Int32Factorizer, + np.int16: libhashtable.Int16Factorizer, + np.int8: libhashtable.Int8Factorizer, + np.uint64: libhashtable.UInt64Factorizer, + np.uint32: libhashtable.UInt32Factorizer, + np.uint16: libhashtable.UInt16Factorizer, + np.uint8: libhashtable.UInt8Factorizer, + np.bool_: libhashtable.UInt8Factorizer, + np.float64: libhashtable.Float64Factorizer, + np.float32: libhashtable.Float32Factorizer, + np.complex64: libhashtable.Complex64Factorizer, + np.complex128: libhashtable.Complex128Factorizer, + np.object_: libhashtable.ObjectFactorizer, +} + +# See https://github.com/pandas-dev/pandas/issues/52451 +if np.intc is not np.int32: + _factorizers[np.intc] = libhashtable.Int64Factorizer + +_known = (np.ndarray, ExtensionArray, Index, ABCSeries) + + +@Substitution("\nleft : DataFrame or named Series") +@Appender(_merge_doc, indents=0) +def merge( + left: DataFrame | Series, + right: DataFrame | Series, + how: MergeHow = "inner", + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool | None = None, + indicator: str | bool = False, + validate: str | None = None, +) -> DataFrame: + left_df = _validate_operand(left) + right_df = _validate_operand(right) + if how == "cross": + return _cross_merge( + left_df, + right_df, + on=on, + left_on=left_on, + right_on=right_on, + 
left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + indicator=indicator, + validate=validate, + copy=copy, + ) + else: + op = _MergeOperation( + left_df, + right_df, + how=how, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + indicator=indicator, + validate=validate, + ) + return op.get_result(copy=copy) + + +def _cross_merge( + left: DataFrame, + right: DataFrame, + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool | None = None, + indicator: str | bool = False, + validate: str | None = None, +) -> DataFrame: + """ + See merge.__doc__ with how='cross' + """ + + if ( + left_index + or right_index + or right_on is not None + or left_on is not None + or on is not None + ): + raise MergeError( + "Can not pass on, right_on, left_on or set right_index=True or " + "left_index=True" + ) + + cross_col = f"_cross_{uuid.uuid4()}" + left = left.assign(**{cross_col: 1}) + right = right.assign(**{cross_col: 1}) + + left_on = right_on = [cross_col] + + res = merge( + left, + right, + how="inner", + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + indicator=indicator, + validate=validate, + copy=copy, + ) + del res[cross_col] + return res + + +def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces): + """ + groupby & merge; we are always performing a left-by type operation + + Parameters + ---------- + by: field to group + left: DataFrame + right: DataFrame + merge_pieces: function for merging + """ + pieces = [] + if not isinstance(by, (list, tuple)): + by = [by] + + lby = left.groupby(by, sort=False) + rby: groupby.DataFrameGroupBy | None = None + + # if we can groupby the rhs + # then we can get vastly better perf + if all(item in right.columns for item in by): + rby = right.groupby(by, sort=False) + + for key, lhs in lby.grouper.get_iterator(lby._selected_obj, axis=lby.axis): + if rby is None: + rhs = right + else: + try: + rhs = right.take(rby.indices[key]) + except KeyError: + # key doesn't exist in left + lcols = lhs.columns.tolist() + cols = lcols + [r for r in right.columns if r not in set(lcols)] + merged = lhs.reindex(columns=cols) + merged.index = range(len(merged)) + pieces.append(merged) + continue + + merged = merge_pieces(lhs, rhs) + + # make sure join keys are in the merged + # TODO, should merge_pieces do this? + merged[by] = key + + pieces.append(merged) + + # preserve the original order + # if we have a missing piece this can be reset + from pandas.core.reshape.concat import concat + + result = concat(pieces, ignore_index=True) + result = result.reindex(columns=pieces[0].columns, copy=False) + return result, lby + + +def merge_ordered( + left: DataFrame, + right: DataFrame, + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_by=None, + right_by=None, + fill_method: str | None = None, + suffixes: Suffixes = ("_x", "_y"), + how: JoinHow = "outer", +) -> DataFrame: + """ + Perform a merge for ordered data with optional filling/interpolation. + + Designed for ordered data like time series data. Optionally + perform group-wise merge (see examples). 
+ + Parameters + ---------- + left : DataFrame or named Series + right : DataFrame or named Series + on : label or list + Field names to join on. Must be found in both DataFrames. + left_on : label or list, or array-like + Field names to join on in left DataFrame. Can be a vector or list of + vectors of the length of the DataFrame to use a particular vector as + the join key instead of columns. + right_on : label or list, or array-like + Field names to join on in right DataFrame or vector/list of vectors per + left_on docs. + left_by : column name or list of column names + Group left DataFrame by group columns and merge piece by piece with + right DataFrame. Must be None if either left or right are a Series. + right_by : column name or list of column names + Group right DataFrame by group columns and merge piece by piece with + left DataFrame. Must be None if either left or right are a Series. + fill_method : {'ffill', None}, default None + Interpolation method for data. + suffixes : list-like, default is ("_x", "_y") + A length-2 sequence where each element is optionally a string + indicating the suffix to add to overlapping column names in + `left` and `right` respectively. Pass a value of `None` instead + of a string to indicate that the column name from `left` or + `right` should be left as-is, with no suffix. At least one of the + values must not be None. + + how : {'left', 'right', 'outer', 'inner'}, default 'outer' + * left: use only keys from left frame (SQL: left outer join) + * right: use only keys from right frame (SQL: right outer join) + * outer: use union of keys from both frames (SQL: full outer join) + * inner: use intersection of keys from both frames (SQL: inner join). + + Returns + ------- + DataFrame + The merged DataFrame output type will be the same as + 'left', if it is a subclass of DataFrame. + + See Also + -------- + merge : Merge with a database-style join. + merge_asof : Merge on nearest keys. + + Examples + -------- + >>> from pandas import merge_ordered + >>> df1 = pd.DataFrame( + ... { + ... "key": ["a", "c", "e", "a", "c", "e"], + ... "lvalue": [1, 2, 3, 1, 2, 3], + ... "group": ["a", "a", "a", "b", "b", "b"] + ... } + ... 
) + >>> df1 + key lvalue group + 0 a 1 a + 1 c 2 a + 2 e 3 a + 3 a 1 b + 4 c 2 b + 5 e 3 b + + >>> df2 = pd.DataFrame({"key": ["b", "c", "d"], "rvalue": [1, 2, 3]}) + >>> df2 + key rvalue + 0 b 1 + 1 c 2 + 2 d 3 + + >>> merge_ordered(df1, df2, fill_method="ffill", left_by="group") + key lvalue group rvalue + 0 a 1 a NaN + 1 b 1 a 1.0 + 2 c 2 a 2.0 + 3 d 2 a 3.0 + 4 e 3 a 3.0 + 5 a 1 b NaN + 6 b 1 b 1.0 + 7 c 2 b 2.0 + 8 d 2 b 3.0 + 9 e 3 b 3.0 + """ + + def _merger(x, y) -> DataFrame: + # perform the ordered merge operation + op = _OrderedMerge( + x, + y, + on=on, + left_on=left_on, + right_on=right_on, + suffixes=suffixes, + fill_method=fill_method, + how=how, + ) + return op.get_result() + + if left_by is not None and right_by is not None: + raise ValueError("Can only group either left or right frames") + if left_by is not None: + if isinstance(left_by, str): + left_by = [left_by] + check = set(left_by).difference(left.columns) + if len(check) != 0: + raise KeyError(f"{check} not found in left columns") + result, _ = _groupby_and_merge(left_by, left, right, lambda x, y: _merger(x, y)) + elif right_by is not None: + if isinstance(right_by, str): + right_by = [right_by] + check = set(right_by).difference(right.columns) + if len(check) != 0: + raise KeyError(f"{check} not found in right columns") + result, _ = _groupby_and_merge( + right_by, right, left, lambda x, y: _merger(y, x) + ) + else: + result = _merger(left, right) + return result + + +def merge_asof( + left: DataFrame | Series, + right: DataFrame | Series, + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_index: bool = False, + right_index: bool = False, + by=None, + left_by=None, + right_by=None, + suffixes: Suffixes = ("_x", "_y"), + tolerance: int | Timedelta | None = None, + allow_exact_matches: bool = True, + direction: str = "backward", +) -> DataFrame: + """ + Perform a merge by key distance. + + This is similar to a left-join except that we match on nearest + key rather than equal keys. Both DataFrames must be sorted by the key. + + For each row in the left DataFrame: + + - A "backward" search selects the last row in the right DataFrame whose + 'on' key is less than or equal to the left's key. + + - A "forward" search selects the first row in the right DataFrame whose + 'on' key is greater than or equal to the left's key. + + - A "nearest" search selects the row in the right DataFrame whose 'on' + key is closest in absolute distance to the left's key. + + Optionally match on equivalent keys with 'by' before searching with 'on'. + + Parameters + ---------- + left : DataFrame or named Series + right : DataFrame or named Series + on : label + Field name to join on. Must be found in both DataFrames. + The data MUST be ordered. Furthermore this must be a numeric column, + such as datetimelike, integer, or float. On or left_on/right_on + must be given. + left_on : label + Field name to join on in left DataFrame. + right_on : label + Field name to join on in right DataFrame. + left_index : bool + Use the index of the left DataFrame as the join key. + right_index : bool + Use the index of the right DataFrame as the join key. + by : column name or list of column names + Match on these columns before performing merge operation. + left_by : column name + Field names to match on in the left DataFrame. + right_by : column name + Field names to match on in the right DataFrame. + suffixes : 2-length sequence (tuple, list, ...) 
+ Suffix to apply to overlapping column names in the left and right + side, respectively. + tolerance : int or Timedelta, optional, default None + Select asof tolerance within this range; must be compatible + with the merge index. + allow_exact_matches : bool, default True + + - If True, allow matching with the same 'on' value + (i.e. less-than-or-equal-to / greater-than-or-equal-to) + - If False, don't match the same 'on' value + (i.e., strictly less-than / strictly greater-than). + + direction : 'backward' (default), 'forward', or 'nearest' + Whether to search for prior, subsequent, or closest matches. + + Returns + ------- + DataFrame + + See Also + -------- + merge : Merge with a database-style join. + merge_ordered : Merge with optional filling/interpolation. + + Examples + -------- + >>> left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + >>> left + a left_val + 0 1 a + 1 5 b + 2 10 c + + >>> right = pd.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]}) + >>> right + a right_val + 0 1 1 + 1 2 2 + 2 3 3 + 3 6 6 + 4 7 7 + + >>> pd.merge_asof(left, right, on="a") + a left_val right_val + 0 1 a 1 + 1 5 b 3 + 2 10 c 7 + + >>> pd.merge_asof(left, right, on="a", allow_exact_matches=False) + a left_val right_val + 0 1 a NaN + 1 5 b 3.0 + 2 10 c 7.0 + + >>> pd.merge_asof(left, right, on="a", direction="forward") + a left_val right_val + 0 1 a 1.0 + 1 5 b 6.0 + 2 10 c NaN + + >>> pd.merge_asof(left, right, on="a", direction="nearest") + a left_val right_val + 0 1 a 1 + 1 5 b 6 + 2 10 c 7 + + We can use indexed DataFrames as well. + + >>> left = pd.DataFrame({"left_val": ["a", "b", "c"]}, index=[1, 5, 10]) + >>> left + left_val + 1 a + 5 b + 10 c + + >>> right = pd.DataFrame({"right_val": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7]) + >>> right + right_val + 1 1 + 2 2 + 3 3 + 6 6 + 7 7 + + >>> pd.merge_asof(left, right, left_index=True, right_index=True) + left_val right_val + 1 a 1 + 5 b 3 + 10 c 7 + + Here is a real-world times-series example + + >>> quotes = pd.DataFrame( + ... { + ... "time": [ + ... pd.Timestamp("2016-05-25 13:30:00.023"), + ... pd.Timestamp("2016-05-25 13:30:00.023"), + ... pd.Timestamp("2016-05-25 13:30:00.030"), + ... pd.Timestamp("2016-05-25 13:30:00.041"), + ... pd.Timestamp("2016-05-25 13:30:00.048"), + ... pd.Timestamp("2016-05-25 13:30:00.049"), + ... pd.Timestamp("2016-05-25 13:30:00.072"), + ... pd.Timestamp("2016-05-25 13:30:00.075") + ... ], + ... "ticker": [ + ... "GOOG", + ... "MSFT", + ... "MSFT", + ... "MSFT", + ... "GOOG", + ... "AAPL", + ... "GOOG", + ... "MSFT" + ... ], + ... "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + ... "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03] + ... } + ... ) + >>> quotes + time ticker bid ask + 0 2016-05-25 13:30:00.023 GOOG 720.50 720.93 + 1 2016-05-25 13:30:00.023 MSFT 51.95 51.96 + 2 2016-05-25 13:30:00.030 MSFT 51.97 51.98 + 3 2016-05-25 13:30:00.041 MSFT 51.99 52.00 + 4 2016-05-25 13:30:00.048 GOOG 720.50 720.93 + 5 2016-05-25 13:30:00.049 AAPL 97.99 98.01 + 6 2016-05-25 13:30:00.072 GOOG 720.50 720.88 + 7 2016-05-25 13:30:00.075 MSFT 52.01 52.03 + + >>> trades = pd.DataFrame( + ... { + ... "time": [ + ... pd.Timestamp("2016-05-25 13:30:00.023"), + ... pd.Timestamp("2016-05-25 13:30:00.038"), + ... pd.Timestamp("2016-05-25 13:30:00.048"), + ... pd.Timestamp("2016-05-25 13:30:00.048"), + ... pd.Timestamp("2016-05-25 13:30:00.048") + ... ], + ... "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + ... 
"price": [51.95, 51.95, 720.77, 720.92, 98.0], + ... "quantity": [75, 155, 100, 100, 100] + ... } + ... ) + >>> trades + time ticker price quantity + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 + 2 2016-05-25 13:30:00.048 GOOG 720.77 100 + 3 2016-05-25 13:30:00.048 GOOG 720.92 100 + 4 2016-05-25 13:30:00.048 AAPL 98.00 100 + + By default we are taking the asof of the quotes + + >>> pd.merge_asof(trades, quotes, on="time", by="ticker") + time ticker price quantity bid ask + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 51.95 51.96 + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 51.97 51.98 + 2 2016-05-25 13:30:00.048 GOOG 720.77 100 720.50 720.93 + 3 2016-05-25 13:30:00.048 GOOG 720.92 100 720.50 720.93 + 4 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN + + We only asof within 2ms between the quote time and the trade time + + >>> pd.merge_asof( + ... trades, quotes, on="time", by="ticker", tolerance=pd.Timedelta("2ms") + ... ) + time ticker price quantity bid ask + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 51.95 51.96 + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 NaN NaN + 2 2016-05-25 13:30:00.048 GOOG 720.77 100 720.50 720.93 + 3 2016-05-25 13:30:00.048 GOOG 720.92 100 720.50 720.93 + 4 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN + + We only asof within 10ms between the quote time and the trade time + and we exclude exact matches on time. However *prior* data will + propagate forward + + >>> pd.merge_asof( + ... trades, + ... quotes, + ... on="time", + ... by="ticker", + ... tolerance=pd.Timedelta("10ms"), + ... allow_exact_matches=False + ... ) + time ticker price quantity bid ask + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 NaN NaN + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 51.97 51.98 + 2 2016-05-25 13:30:00.048 GOOG 720.77 100 NaN NaN + 3 2016-05-25 13:30:00.048 GOOG 720.92 100 NaN NaN + 4 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN + """ + op = _AsOfMerge( + left, + right, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + by=by, + left_by=left_by, + right_by=right_by, + suffixes=suffixes, + how="asof", + tolerance=tolerance, + allow_exact_matches=allow_exact_matches, + direction=direction, + ) + return op.get_result() + + +# TODO: transformations?? +# TODO: only copy DataFrames when modification necessary +class _MergeOperation: + """ + Perform a database (SQL) merge operation between two DataFrame or Series + objects using either columns as keys or their row indexes + """ + + _merge_type = "merge" + how: MergeHow | Literal["asof"] + on: IndexLabel | None + # left_on/right_on may be None when passed, but in validate_specification + # get replaced with non-None. 
+ left_on: Sequence[Hashable | AnyArrayLike] + right_on: Sequence[Hashable | AnyArrayLike] + left_index: bool + right_index: bool + sort: bool + suffixes: Suffixes + copy: bool + indicator: str | bool + validate: str | None + join_names: list[Hashable] + right_join_keys: list[ArrayLike] + left_join_keys: list[ArrayLike] + + def __init__( + self, + left: DataFrame | Series, + right: DataFrame | Series, + how: MergeHow | Literal["asof"] = "inner", + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = True, + suffixes: Suffixes = ("_x", "_y"), + indicator: str | bool = False, + validate: str | None = None, + ) -> None: + _left = _validate_operand(left) + _right = _validate_operand(right) + self.left = self.orig_left = _left + self.right = self.orig_right = _right + self.how = how + + self.on = com.maybe_make_list(on) + + self.suffixes = suffixes + self.sort = sort + + self.left_index = left_index + self.right_index = right_index + + self.indicator = indicator + + if not is_bool(left_index): + raise ValueError( + f"left_index parameter must be of type bool, not {type(left_index)}" + ) + if not is_bool(right_index): + raise ValueError( + f"right_index parameter must be of type bool, not {type(right_index)}" + ) + + # GH 40993: raise when merging between different levels; enforced in 2.0 + if _left.columns.nlevels != _right.columns.nlevels: + msg = ( + "Not allowed to merge between different levels. " + f"({_left.columns.nlevels} levels on the left, " + f"{_right.columns.nlevels} on the right)" + ) + raise MergeError(msg) + + self.left_on, self.right_on = self._validate_left_right_on(left_on, right_on) + + ( + self.left_join_keys, + self.right_join_keys, + self.join_names, + left_drop, + right_drop, + ) = self._get_merge_keys() + + if left_drop: + self.left = self.left._drop_labels_or_levels(left_drop) + + if right_drop: + self.right = self.right._drop_labels_or_levels(right_drop) + + self._maybe_require_matching_dtypes(self.left_join_keys, self.right_join_keys) + self._validate_tolerance(self.left_join_keys) + + # validate the merge keys dtypes. We may need to coerce + # to avoid incompatible dtypes + self._maybe_coerce_merge_keys() + + # If argument passed to validate, + # check if columns specified as unique + # are in fact unique. + if validate is not None: + self._validate_validate_kwd(validate) + + def _maybe_require_matching_dtypes( + self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike] + ) -> None: + # Overridden by AsOfMerge + pass + + def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: + # Overridden by AsOfMerge + pass + + @final + def _reindex_and_concat( + self, + join_index: Index, + left_indexer: npt.NDArray[np.intp] | None, + right_indexer: npt.NDArray[np.intp] | None, + copy: bool | None, + ) -> DataFrame: + """ + reindex along index and concat along columns. + """ + # Take views so we do not alter the originals + left = self.left[:] + right = self.right[:] + + llabels, rlabels = _items_overlap_with_suffix( + self.left._info_axis, self.right._info_axis, self.suffixes + ) + + if left_indexer is not None and not is_range_indexer(left_indexer, len(left)): + # Pinning the index here (and in the right code just below) is not + # necessary, but makes the `.take` more performant if we have e.g. + # a MultiIndex for left.index. 
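+ # Note (added): row-take the left manager along left_indexer; -1 (no-match)
+ # positions are filled with NA and join_index becomes the new row axis.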
+ lmgr = left._mgr.reindex_indexer( + join_index, + left_indexer, + axis=1, + copy=False, + only_slice=True, + allow_dups=True, + use_na_proxy=True, + ) + left = left._constructor_from_mgr(lmgr, axes=lmgr.axes) + left.index = join_index + + if right_indexer is not None and not is_range_indexer( + right_indexer, len(right) + ): + rmgr = right._mgr.reindex_indexer( + join_index, + right_indexer, + axis=1, + copy=False, + only_slice=True, + allow_dups=True, + use_na_proxy=True, + ) + right = right._constructor_from_mgr(rmgr, axes=rmgr.axes) + right.index = join_index + + from pandas import concat + + left.columns = llabels + right.columns = rlabels + result = concat([left, right], axis=1, copy=copy) + return result + + def get_result(self, copy: bool | None = True) -> DataFrame: + if self.indicator: + self.left, self.right = self._indicator_pre_merge(self.left, self.right) + + join_index, left_indexer, right_indexer = self._get_join_info() + + result = self._reindex_and_concat( + join_index, left_indexer, right_indexer, copy=copy + ) + result = result.__finalize__(self, method=self._merge_type) + + if self.indicator: + result = self._indicator_post_merge(result) + + self._maybe_add_join_keys(result, left_indexer, right_indexer) + + self._maybe_restore_index_levels(result) + + return result.__finalize__(self, method="merge") + + @final + @cache_readonly + def _indicator_name(self) -> str | None: + if isinstance(self.indicator, str): + return self.indicator + elif isinstance(self.indicator, bool): + return "_merge" if self.indicator else None + else: + raise ValueError( + "indicator option can only accept boolean or string arguments" + ) + + @final + def _indicator_pre_merge( + self, left: DataFrame, right: DataFrame + ) -> tuple[DataFrame, DataFrame]: + columns = left.columns.union(right.columns) + + for i in ["_left_indicator", "_right_indicator"]: + if i in columns: + raise ValueError( + "Cannot use `indicator=True` option when " + f"data contains a column named {i}" + ) + if self._indicator_name in columns: + raise ValueError( + "Cannot use name of an existing column for indicator column" + ) + + left = left.copy() + right = right.copy() + + left["_left_indicator"] = 1 + left["_left_indicator"] = left["_left_indicator"].astype("int8") + + right["_right_indicator"] = 2 + right["_right_indicator"] = right["_right_indicator"].astype("int8") + + return left, right + + @final + def _indicator_post_merge(self, result: DataFrame) -> DataFrame: + result["_left_indicator"] = result["_left_indicator"].fillna(0) + result["_right_indicator"] = result["_right_indicator"].fillna(0) + + result[self._indicator_name] = Categorical( + (result["_left_indicator"] + result["_right_indicator"]), + categories=[1, 2, 3], + ) + result[self._indicator_name] = result[ + self._indicator_name + ].cat.rename_categories(["left_only", "right_only", "both"]) + + result = result.drop(labels=["_left_indicator", "_right_indicator"], axis=1) + return result + + @final + def _maybe_restore_index_levels(self, result: DataFrame) -> None: + """ + Restore index levels specified as `on` parameters + + Here we check for cases where `self.left_on` and `self.right_on` pairs + each reference an index level in their respective DataFrames. The + joined columns corresponding to these pairs are then restored to the + index of `result`. + + **Note:** This method has side effects. 
It modifies `result` in-place + + Parameters + ---------- + result: DataFrame + merge result + + Returns + ------- + None + """ + names_to_restore = [] + for name, left_key, right_key in zip( + self.join_names, self.left_on, self.right_on + ): + if ( + # Argument 1 to "_is_level_reference" of "NDFrame" has incompatible + # type "Union[Hashable, ExtensionArray, Index, Series]"; expected + # "Hashable" + self.orig_left._is_level_reference(left_key) # type: ignore[arg-type] + # Argument 1 to "_is_level_reference" of "NDFrame" has incompatible + # type "Union[Hashable, ExtensionArray, Index, Series]"; expected + # "Hashable" + and self.orig_right._is_level_reference( + right_key # type: ignore[arg-type] + ) + and left_key == right_key + and name not in result.index.names + ): + names_to_restore.append(name) + + if names_to_restore: + result.set_index(names_to_restore, inplace=True) + + @final + def _maybe_add_join_keys( + self, + result: DataFrame, + left_indexer: npt.NDArray[np.intp] | None, + right_indexer: npt.NDArray[np.intp] | None, + ) -> None: + left_has_missing = None + right_has_missing = None + + assert all(isinstance(x, _known) for x in self.left_join_keys) + + keys = zip(self.join_names, self.left_on, self.right_on) + for i, (name, lname, rname) in enumerate(keys): + if not _should_fill(lname, rname): + continue + + take_left, take_right = None, None + + if name in result: + if left_indexer is not None and right_indexer is not None: + if name in self.left: + if left_has_missing is None: + left_has_missing = (left_indexer == -1).any() + + if left_has_missing: + take_right = self.right_join_keys[i] + + if result[name].dtype != self.left[name].dtype: + take_left = self.left[name]._values + + elif name in self.right: + if right_has_missing is None: + right_has_missing = (right_indexer == -1).any() + + if right_has_missing: + take_left = self.left_join_keys[i] + + if result[name].dtype != self.right[name].dtype: + take_right = self.right[name]._values + + elif left_indexer is not None: + take_left = self.left_join_keys[i] + take_right = self.right_join_keys[i] + + if take_left is not None or take_right is not None: + if take_left is None: + lvals = result[name]._values + else: + # TODO: can we pin down take_left's type earlier? + take_left = extract_array(take_left, extract_numpy=True) + lfill = na_value_for_dtype(take_left.dtype) + lvals = algos.take_nd(take_left, left_indexer, fill_value=lfill) + + if take_right is None: + rvals = result[name]._values + else: + # TODO: can we pin down take_right's type earlier? 
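+ # mirror of the take_left branch above: take the right key values and
+ # fill positions with no right match (-1 in right_indexer) with the
+ # NA value appropriate for that dtype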
+ taker = extract_array(take_right, extract_numpy=True) + rfill = na_value_for_dtype(taker.dtype) + rvals = algos.take_nd(taker, right_indexer, fill_value=rfill) + + # if we have an all missing left_indexer + # make sure to just use the right values or vice-versa + mask_left = left_indexer == -1 + # error: Item "bool" of "Union[Any, bool]" has no attribute "all" + if mask_left.all(): # type: ignore[union-attr] + key_col = Index(rvals) + result_dtype = rvals.dtype + elif right_indexer is not None and (right_indexer == -1).all(): + key_col = Index(lvals) + result_dtype = lvals.dtype + else: + key_col = Index(lvals).where(~mask_left, rvals) + result_dtype = find_common_type([lvals.dtype, rvals.dtype]) + if ( + lvals.dtype.kind == "M" + and rvals.dtype.kind == "M" + and result_dtype.kind == "O" + ): + # TODO(non-nano) Workaround for common_type not dealing + # with different resolutions + result_dtype = key_col.dtype + + if result._is_label_reference(name): + result[name] = result._constructor_sliced( + key_col, dtype=result_dtype, index=result.index + ) + elif result._is_level_reference(name): + if isinstance(result.index, MultiIndex): + key_col.name = name + idx_list = [ + result.index.get_level_values(level_name) + if level_name != name + else key_col + for level_name in result.index.names + ] + + result.set_index(idx_list, inplace=True) + else: + result.index = Index(key_col, name=name) + else: + result.insert(i, name or f"key_{i}", key_col) + + def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + """return the join indexers""" + return get_join_indexers( + self.left_join_keys, self.right_join_keys, sort=self.sort, how=self.how + ) + + @final + def _get_join_info( + self, + ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]: + # make mypy happy + assert self.how != "cross" + left_ax = self.left.index + right_ax = self.right.index + + if self.left_index and self.right_index and self.how != "asof": + join_index, left_indexer, right_indexer = left_ax.join( + right_ax, how=self.how, return_indexers=True, sort=self.sort + ) + + elif self.right_index and self.how == "left": + join_index, left_indexer, right_indexer = _left_join_on_index( + left_ax, right_ax, self.left_join_keys, sort=self.sort + ) + + elif self.left_index and self.how == "right": + join_index, right_indexer, left_indexer = _left_join_on_index( + right_ax, left_ax, self.right_join_keys, sort=self.sort + ) + else: + (left_indexer, right_indexer) = self._get_join_indexers() + + if self.right_index: + if len(self.left) > 0: + join_index = self._create_join_index( + left_ax, + right_ax, + left_indexer, + how="right", + ) + else: + join_index = right_ax.take(right_indexer) + elif self.left_index: + if self.how == "asof": + # GH#33463 asof should always behave like a left merge + join_index = self._create_join_index( + left_ax, + right_ax, + left_indexer, + how="left", + ) + + elif len(self.right) > 0: + join_index = self._create_join_index( + right_ax, + left_ax, + right_indexer, + how="left", + ) + else: + join_index = left_ax.take(left_indexer) + else: + join_index = default_index(len(left_indexer)) + + return join_index, left_indexer, right_indexer + + @final + def _create_join_index( + self, + index: Index, + other_index: Index, + indexer: npt.NDArray[np.intp], + how: JoinHow = "left", + ) -> Index: + """ + Create a join index by rearranging one index to match another + + Parameters + ---------- + index : Index being rearranged + other_index : Index used to supply values not 
found in index + indexer : np.ndarray[np.intp] how to rearrange index + how : str + Replacement is only necessary if indexer based on other_index. + + Returns + ------- + Index + """ + if self.how in (how, "outer") and not isinstance(other_index, MultiIndex): + # if final index requires values in other_index but not target + # index, indexer may hold missing (-1) values, causing Index.take + # to take the final value in target index. So, we set the last + # element to be the desired fill value. We do not use allow_fill + # and fill_value because it throws a ValueError on integer indices + mask = indexer == -1 + if np.any(mask): + fill_value = na_value_for_dtype(index.dtype, compat=False) + index = index.append(Index([fill_value])) + return index.take(indexer) + + @final + def _get_merge_keys( + self, + ) -> tuple[ + list[ArrayLike], + list[ArrayLike], + list[Hashable], + list[Hashable], + list[Hashable], + ]: + """ + Returns + ------- + left_keys, right_keys, join_names, left_drop, right_drop + """ + left_keys: list[ArrayLike] = [] + right_keys: list[ArrayLike] = [] + join_names: list[Hashable] = [] + right_drop: list[Hashable] = [] + left_drop: list[Hashable] = [] + + left, right = self.left, self.right + + is_lkey = lambda x: isinstance(x, _known) and len(x) == len(left) + is_rkey = lambda x: isinstance(x, _known) and len(x) == len(right) + + # Note that pd.merge_asof() has separate 'on' and 'by' parameters. A + # user could, for example, request 'left_index' and 'left_by'. In a + # regular pd.merge(), users cannot specify both 'left_index' and + # 'left_on'. (Instead, users have a MultiIndex). That means the + # self.left_on in this function is always empty in a pd.merge(), but + # a pd.merge_asof(left_index=True, left_by=...) will result in a + # self.left_on array with a None in the middle of it. This requires + # a work-around as designated in the code below. + # See _validate_left_right_on() for where this happens. + + # ugh, spaghetti re #733 + if _any(self.left_on) and _any(self.right_on): + for lk, rk in zip(self.left_on, self.right_on): + lk = extract_array(lk, extract_numpy=True) + rk = extract_array(rk, extract_numpy=True) + if is_lkey(lk): + lk = cast(ArrayLike, lk) + left_keys.append(lk) + if is_rkey(rk): + rk = cast(ArrayLike, rk) + right_keys.append(rk) + join_names.append(None) # what to do? + else: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + rk = cast(Hashable, rk) + if rk is not None: + right_keys.append(right._get_label_or_level_values(rk)) + join_names.append(rk) + else: + # work-around for merge_asof(right_index=True) + right_keys.append(right.index._values) + join_names.append(right.index.name) + else: + if not is_rkey(rk): + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + rk = cast(Hashable, rk) + if rk is not None: + right_keys.append(right._get_label_or_level_values(rk)) + else: + # work-around for merge_asof(right_index=True) + right_keys.append(right.index._values) + if lk is not None and lk == rk: # FIXME: what about other NAs? 
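+ # both sides name the same key, so the result only needs one key
+ # column; schedule the redundant one for dropping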
+ # avoid key upcast in corner case (length-0) + lk = cast(Hashable, lk) + if len(left) > 0: + right_drop.append(rk) + else: + left_drop.append(lk) + else: + rk = cast(ArrayLike, rk) + right_keys.append(rk) + if lk is not None: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + lk = cast(Hashable, lk) + left_keys.append(left._get_label_or_level_values(lk)) + join_names.append(lk) + else: + # work-around for merge_asof(left_index=True) + left_keys.append(left.index._values) + join_names.append(left.index.name) + elif _any(self.left_on): + for k in self.left_on: + if is_lkey(k): + k = extract_array(k, extract_numpy=True) + k = cast(ArrayLike, k) + left_keys.append(k) + join_names.append(None) + else: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + k = cast(Hashable, k) + left_keys.append(left._get_label_or_level_values(k)) + join_names.append(k) + if isinstance(self.right.index, MultiIndex): + right_keys = [ + lev._values.take(lev_codes) + for lev, lev_codes in zip( + self.right.index.levels, self.right.index.codes + ) + ] + else: + right_keys = [self.right.index._values] + elif _any(self.right_on): + for k in self.right_on: + k = extract_array(k, extract_numpy=True) + if is_rkey(k): + k = cast(ArrayLike, k) + right_keys.append(k) + join_names.append(None) + else: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + k = cast(Hashable, k) + right_keys.append(right._get_label_or_level_values(k)) + join_names.append(k) + if isinstance(self.left.index, MultiIndex): + left_keys = [ + lev._values.take(lev_codes) + for lev, lev_codes in zip( + self.left.index.levels, self.left.index.codes + ) + ] + else: + left_keys = [self.left.index._values] + + return left_keys, right_keys, join_names, left_drop, right_drop + + @final + def _maybe_coerce_merge_keys(self) -> None: + # we have valid merges but we may have to further + # coerce these if they are originally incompatible types + # + # for example if these are categorical, but are not dtype_equal + # or if we have object and integer dtypes + + for lk, rk, name in zip( + self.left_join_keys, self.right_join_keys, self.join_names + ): + if (len(lk) and not len(rk)) or (not len(lk) and len(rk)): + continue + + lk = extract_array(lk, extract_numpy=True) + rk = extract_array(rk, extract_numpy=True) + + lk_is_cat = isinstance(lk.dtype, CategoricalDtype) + rk_is_cat = isinstance(rk.dtype, CategoricalDtype) + lk_is_object = is_object_dtype(lk.dtype) + rk_is_object = is_object_dtype(rk.dtype) + + # if either left or right is a categorical + # then the must match exactly in categories & ordered + if lk_is_cat and rk_is_cat: + lk = cast(Categorical, lk) + rk = cast(Categorical, rk) + if lk._categories_match_up_to_permutation(rk): + continue + + elif lk_is_cat or rk_is_cat: + pass + + elif lk.dtype == rk.dtype: + continue + + msg = ( + f"You are trying to merge on {lk.dtype} and {rk.dtype} columns " + f"for key '{name}'. If you wish to proceed you should use pd.concat" + ) + + # if we are numeric, then allow differing + # kinds to proceed, eg. 
int64 and int8, int and float + # further if we are object, but we infer to + # the same, then proceed + if is_numeric_dtype(lk.dtype) and is_numeric_dtype(rk.dtype): + if lk.dtype.kind == rk.dtype.kind: + continue + + # check whether ints and floats + if is_integer_dtype(rk.dtype) and is_float_dtype(lk.dtype): + # GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int + with np.errstate(invalid="ignore"): + # error: Argument 1 to "astype" of "ndarray" has incompatible + # type "Union[ExtensionDtype, Any, dtype[Any]]"; expected + # "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]" + casted = lk.astype(rk.dtype) # type: ignore[arg-type] + + mask = ~np.isnan(lk) + match = lk == casted + if not match[mask].all(): + warnings.warn( + "You are merging on int and float " + "columns where the float values " + "are not equal to their int representation.", + UserWarning, + stacklevel=find_stack_level(), + ) + continue + + if is_float_dtype(rk.dtype) and is_integer_dtype(lk.dtype): + # GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int + with np.errstate(invalid="ignore"): + # error: Argument 1 to "astype" of "ndarray" has incompatible + # type "Union[ExtensionDtype, Any, dtype[Any]]"; expected + # "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]" + casted = rk.astype(lk.dtype) # type: ignore[arg-type] + + mask = ~np.isnan(rk) + match = rk == casted + if not match[mask].all(): + warnings.warn( + "You are merging on int and float " + "columns where the float values " + "are not equal to their int representation.", + UserWarning, + stacklevel=find_stack_level(), + ) + continue + + # let's infer and see if we are ok + if lib.infer_dtype(lk, skipna=False) == lib.infer_dtype( + rk, skipna=False + ): + continue + + # Check if we are trying to merge on obviously + # incompatible dtypes GH 9780, GH 15800 + + # bool values are coerced to object + elif (lk_is_object and is_bool_dtype(rk.dtype)) or ( + is_bool_dtype(lk.dtype) and rk_is_object + ): + pass + + # object values are allowed to be merged + elif (lk_is_object and is_numeric_dtype(rk.dtype)) or ( + is_numeric_dtype(lk.dtype) and rk_is_object + ): + inferred_left = lib.infer_dtype(lk, skipna=False) + inferred_right = lib.infer_dtype(rk, skipna=False) + bool_types = ["integer", "mixed-integer", "boolean", "empty"] + string_types = ["string", "unicode", "mixed", "bytes", "empty"] + + # inferred bool + if inferred_left in bool_types and inferred_right in bool_types: + pass + + # unless we are merging non-string-like with string-like + elif ( + inferred_left in string_types and inferred_right not in string_types + ) or ( + inferred_right in string_types and inferred_left not in string_types + ): + raise ValueError(msg) + + # datetimelikes must match exactly + elif needs_i8_conversion(lk.dtype) and not needs_i8_conversion(rk.dtype): + raise ValueError(msg) + elif not needs_i8_conversion(lk.dtype) and needs_i8_conversion(rk.dtype): + raise ValueError(msg) + elif isinstance(lk.dtype, DatetimeTZDtype) and not isinstance( + rk.dtype, DatetimeTZDtype + ): + raise ValueError(msg) + elif not isinstance(lk.dtype, DatetimeTZDtype) and isinstance( + rk.dtype, DatetimeTZDtype + ): + raise ValueError(msg) + elif ( + isinstance(lk.dtype, DatetimeTZDtype) + and isinstance(rk.dtype, DatetimeTZDtype) + ) or (lk.dtype.kind == "M" and rk.dtype.kind == "M"): + # allows datetime with different resolutions + continue + + elif lk_is_object and rk_is_object: + continue + + # Houston, we have a problem! 
+ # let's coerce to object if the dtypes aren't + # categorical, otherwise coerce to the category + # dtype. If we coerced categories to object, + # then we would lose type information on some + # columns, and end up trying to merge + # incompatible dtypes. See GH 16900. + if name in self.left.columns: + typ = cast(Categorical, lk).categories.dtype if lk_is_cat else object + self.left = self.left.copy() + self.left[name] = self.left[name].astype(typ) + if name in self.right.columns: + typ = cast(Categorical, rk).categories.dtype if rk_is_cat else object + self.right = self.right.copy() + self.right[name] = self.right[name].astype(typ) + + def _validate_left_right_on(self, left_on, right_on): + left_on = com.maybe_make_list(left_on) + right_on = com.maybe_make_list(right_on) + + # Hm, any way to make this logic less complicated?? + if self.on is None and left_on is None and right_on is None: + if self.left_index and self.right_index: + left_on, right_on = (), () + elif self.left_index: + raise MergeError("Must pass right_on or right_index=True") + elif self.right_index: + raise MergeError("Must pass left_on or left_index=True") + else: + # use the common columns + left_cols = self.left.columns + right_cols = self.right.columns + common_cols = left_cols.intersection(right_cols) + if len(common_cols) == 0: + raise MergeError( + "No common columns to perform merge on. " + f"Merge options: left_on={left_on}, " + f"right_on={right_on}, " + f"left_index={self.left_index}, " + f"right_index={self.right_index}" + ) + if ( + not left_cols.join(common_cols, how="inner").is_unique + or not right_cols.join(common_cols, how="inner").is_unique + ): + raise MergeError(f"Data columns not unique: {repr(common_cols)}") + left_on = right_on = common_cols + elif self.on is not None: + if left_on is not None or right_on is not None: + raise MergeError( + 'Can only pass argument "on" OR "left_on" ' + 'and "right_on", not a combination of both.' + ) + if self.left_index or self.right_index: + raise MergeError( + 'Can only pass argument "on" OR "left_index" ' + 'and "right_index", not a combination of both.' + ) + left_on = right_on = self.on + elif left_on is not None: + if self.left_index: + raise MergeError( + 'Can only pass argument "left_on" OR "left_index" not both.' + ) + if not self.right_index and right_on is None: + raise MergeError('Must pass "right_on" OR "right_index".') + n = len(left_on) + if self.right_index: + if len(left_on) != self.right.index.nlevels: + raise ValueError( + "len(left_on) must equal the number " + 'of levels in the index of "right"' + ) + right_on = [None] * n + elif right_on is not None: + if self.right_index: + raise MergeError( + 'Can only pass argument "right_on" OR "right_index" not both.' 
+ ) + if not self.left_index and left_on is None: + raise MergeError('Must pass "left_on" OR "left_index".') + n = len(right_on) + if self.left_index: + if len(right_on) != self.left.index.nlevels: + raise ValueError( + "len(right_on) must equal the number " + 'of levels in the index of "left"' + ) + left_on = [None] * n + if len(right_on) != len(left_on): + raise ValueError("len(right_on) must equal len(left_on)") + + return left_on, right_on + + @final + def _validate_validate_kwd(self, validate: str) -> None: + # Check uniqueness of each + if self.left_index: + left_unique = self.orig_left.index.is_unique + else: + left_unique = MultiIndex.from_arrays(self.left_join_keys).is_unique + + if self.right_index: + right_unique = self.orig_right.index.is_unique + else: + right_unique = MultiIndex.from_arrays(self.right_join_keys).is_unique + + # Check data integrity + if validate in ["one_to_one", "1:1"]: + if not left_unique and not right_unique: + raise MergeError( + "Merge keys are not unique in either left " + "or right dataset; not a one-to-one merge" + ) + if not left_unique: + raise MergeError( + "Merge keys are not unique in left dataset; not a one-to-one merge" + ) + if not right_unique: + raise MergeError( + "Merge keys are not unique in right dataset; not a one-to-one merge" + ) + + elif validate in ["one_to_many", "1:m"]: + if not left_unique: + raise MergeError( + "Merge keys are not unique in left dataset; not a one-to-many merge" + ) + + elif validate in ["many_to_one", "m:1"]: + if not right_unique: + raise MergeError( + "Merge keys are not unique in right dataset; " + "not a many-to-one merge" + ) + + elif validate in ["many_to_many", "m:m"]: + pass + + else: + raise ValueError( + f'"{validate}" is not a valid argument. ' + "Valid arguments are:\n" + '- "1:1"\n' + '- "1:m"\n' + '- "m:1"\n' + '- "m:m"\n' + '- "one_to_one"\n' + '- "one_to_many"\n' + '- "many_to_one"\n' + '- "many_to_many"' + ) + + +def get_join_indexers( + left_keys: list[ArrayLike], + right_keys: list[ArrayLike], + sort: bool = False, + how: MergeHow | Literal["asof"] = "inner", +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + """ + + Parameters + ---------- + left_keys : list[ndarray, ExtensionArray, Index, Series] + right_keys : list[ndarray, ExtensionArray, Index, Series] + sort : bool, default False + how : {'inner', 'outer', 'left', 'right'}, default 'inner' + + Returns + ------- + np.ndarray[np.intp] + Indexer into the left_keys. + np.ndarray[np.intp] + Indexer into the right_keys. + """ + assert len(left_keys) == len( + right_keys + ), "left_keys and right_keys must be the same length" + + # fast-path for empty left/right + left_n = len(left_keys[0]) + right_n = len(right_keys[0]) + if left_n == 0: + if how in ["left", "inner", "cross"]: + return _get_empty_indexer() + elif not sort and how in ["right", "outer"]: + return _get_no_sort_one_missing_indexer(right_n, True) + elif right_n == 0: + if how in ["right", "inner", "cross"]: + return _get_empty_indexer() + elif not sort and how in ["left", "outer"]: + return _get_no_sort_one_missing_indexer(left_n, False) + + # get left & right join labels and num. of levels at each location + mapped = ( + _factorize_keys(left_keys[n], right_keys[n], sort=sort, how=how) + for n in range(len(left_keys)) + ) + zipped = zip(*mapped) + llab, rlab, shape = (list(x) for x in zipped) + + # get flat i8 keys from label lists + lkey, rkey = _get_join_keys(llab, rlab, tuple(shape), sort) + + # factorize keys to a dense i8 space + # `count` is the num. 
of unique keys + # set(lkey) | set(rkey) == range(count) + + lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort, how=how) + # preserve left frame order if how == 'left' and sort == False + kwargs = {} + if how in ("left", "right"): + kwargs["sort"] = sort + join_func = { + "inner": libjoin.inner_join, + "left": libjoin.left_outer_join, + "right": lambda x, y, count, **kwargs: libjoin.left_outer_join( + y, x, count, **kwargs + )[::-1], + "outer": libjoin.full_outer_join, + }[how] + + # error: Cannot call function of unknown type + return join_func(lkey, rkey, count, **kwargs) # type: ignore[operator] + + +def restore_dropped_levels_multijoin( + left: MultiIndex, + right: MultiIndex, + dropped_level_names, + join_index: Index, + lindexer: npt.NDArray[np.intp], + rindexer: npt.NDArray[np.intp], +) -> tuple[list[Index], npt.NDArray[np.intp], list[Hashable]]: + """ + *this is an internal non-public method* + + Returns the levels, labels and names of a multi-index to multi-index join. + Depending on the type of join, this method restores the appropriate + dropped levels of the joined multi-index. + The method relies on lindexer, rindexer which hold the index positions of + left and right, where a join was feasible + + Parameters + ---------- + left : MultiIndex + left index + right : MultiIndex + right index + dropped_level_names : str array + list of non-common level names + join_index : Index + the index of the join between the + common levels of left and right + lindexer : np.ndarray[np.intp] + left indexer + rindexer : np.ndarray[np.intp] + right indexer + + Returns + ------- + levels : list of Index + levels of combined multiindexes + labels : np.ndarray[np.intp] + labels of combined multiindexes + names : List[Hashable] + names of combined multiindex levels + + """ + + def _convert_to_multiindex(index: Index) -> MultiIndex: + if isinstance(index, MultiIndex): + return index + else: + return MultiIndex.from_arrays([index._values], names=[index.name]) + + # For multi-multi joins with one overlapping level, + # the returned index if of type Index + # Assure that join_index is of type MultiIndex + # so that dropped levels can be appended + join_index = _convert_to_multiindex(join_index) + + join_levels = join_index.levels + join_codes = join_index.codes + join_names = join_index.names + + # Iterate through the levels that must be restored + for dropped_level_name in dropped_level_names: + if dropped_level_name in left.names: + idx = left + indexer = lindexer + else: + idx = right + indexer = rindexer + + # The index of the level name to be restored + name_idx = idx.names.index(dropped_level_name) + + restore_levels = idx.levels[name_idx] + # Inject -1 in the codes list where a join was not possible + # IOW indexer[i]=-1 + codes = idx.codes[name_idx] + if indexer is None: + restore_codes = codes + else: + restore_codes = algos.take_nd(codes, indexer, fill_value=-1) + + # error: Cannot determine type of "__add__" + join_levels = join_levels + [restore_levels] # type: ignore[has-type] + join_codes = join_codes + [restore_codes] + join_names = join_names + [dropped_level_name] + + return join_levels, join_codes, join_names + + +class _OrderedMerge(_MergeOperation): + _merge_type = "ordered_merge" + + def __init__( + self, + left: DataFrame | Series, + right: DataFrame | Series, + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_index: bool = False, + right_index: bool = False, + suffixes: Suffixes = ("_x", "_y"), + fill_method: 
str | None = None, + how: JoinHow | Literal["asof"] = "outer", + ) -> None: + self.fill_method = fill_method + _MergeOperation.__init__( + self, + left, + right, + on=on, + left_on=left_on, + left_index=left_index, + right_index=right_index, + right_on=right_on, + how=how, + suffixes=suffixes, + sort=True, # factorize sorts + ) + + def get_result(self, copy: bool | None = True) -> DataFrame: + join_index, left_indexer, right_indexer = self._get_join_info() + + left_join_indexer: npt.NDArray[np.intp] | None + right_join_indexer: npt.NDArray[np.intp] | None + + if self.fill_method == "ffill": + if left_indexer is None: + raise TypeError("left_indexer cannot be None") + left_indexer = cast("npt.NDArray[np.intp]", left_indexer) + right_indexer = cast("npt.NDArray[np.intp]", right_indexer) + left_join_indexer = libjoin.ffill_indexer(left_indexer) + right_join_indexer = libjoin.ffill_indexer(right_indexer) + else: + left_join_indexer = left_indexer + right_join_indexer = right_indexer + + result = self._reindex_and_concat( + join_index, left_join_indexer, right_join_indexer, copy=copy + ) + self._maybe_add_join_keys(result, left_indexer, right_indexer) + + return result + + +def _asof_by_function(direction: str): + name = f"asof_join_{direction}_on_X_by_Y" + return getattr(libjoin, name, None) + + +class _AsOfMerge(_OrderedMerge): + _merge_type = "asof_merge" + + def __init__( + self, + left: DataFrame | Series, + right: DataFrame | Series, + on: IndexLabel | None = None, + left_on: IndexLabel | None = None, + right_on: IndexLabel | None = None, + left_index: bool = False, + right_index: bool = False, + by=None, + left_by=None, + right_by=None, + suffixes: Suffixes = ("_x", "_y"), + how: Literal["asof"] = "asof", + tolerance=None, + allow_exact_matches: bool = True, + direction: str = "backward", + ) -> None: + self.by = by + self.left_by = left_by + self.right_by = right_by + self.tolerance = tolerance + self.allow_exact_matches = allow_exact_matches + self.direction = direction + + # check 'direction' is valid + if self.direction not in ["backward", "forward", "nearest"]: + raise MergeError(f"direction invalid: {self.direction}") + + # validate allow_exact_matches + if not is_bool(self.allow_exact_matches): + msg = ( + "allow_exact_matches must be boolean, " + f"passed {self.allow_exact_matches}" + ) + raise MergeError(msg) + + _OrderedMerge.__init__( + self, + left, + right, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + how=how, + suffixes=suffixes, + fill_method=None, + ) + + def _validate_left_right_on(self, left_on, right_on): + left_on, right_on = super()._validate_left_right_on(left_on, right_on) + + # we only allow on to be a single item for on + if len(left_on) != 1 and not self.left_index: + raise MergeError("can only asof on a key for left") + + if len(right_on) != 1 and not self.right_index: + raise MergeError("can only asof on a key for right") + + if self.left_index and isinstance(self.left.index, MultiIndex): + raise MergeError("left can only have one index") + + if self.right_index and isinstance(self.right.index, MultiIndex): + raise MergeError("right can only have one index") + + # set 'by' columns + if self.by is not None: + if self.left_by is not None or self.right_by is not None: + raise MergeError("Can only pass by OR left_by and right_by") + self.left_by = self.right_by = self.by + if self.left_by is None and self.right_by is not None: + raise MergeError("missing left_by") + if self.left_by is not None and 
self.right_by is None: + raise MergeError("missing right_by") + + # GH#29130 Check that merge keys do not have dtype object + if not self.left_index: + left_on_0 = left_on[0] + if isinstance(left_on_0, _known): + lo_dtype = left_on_0.dtype + else: + lo_dtype = ( + self.left._get_label_or_level_values(left_on_0).dtype + if left_on_0 in self.left.columns + else self.left.index.get_level_values(left_on_0) + ) + else: + lo_dtype = self.left.index.dtype + + if not self.right_index: + right_on_0 = right_on[0] + if isinstance(right_on_0, _known): + ro_dtype = right_on_0.dtype + else: + ro_dtype = ( + self.right._get_label_or_level_values(right_on_0).dtype + if right_on_0 in self.right.columns + else self.right.index.get_level_values(right_on_0) + ) + else: + ro_dtype = self.right.index.dtype + + if is_object_dtype(lo_dtype) or is_object_dtype(ro_dtype): + raise MergeError( + f"Incompatible merge dtype, {repr(ro_dtype)} and " + f"{repr(lo_dtype)}, both sides must have numeric dtype" + ) + + # add 'by' to our key-list so we can have it in the + # output as a key + if self.left_by is not None: + if not is_list_like(self.left_by): + self.left_by = [self.left_by] + if not is_list_like(self.right_by): + self.right_by = [self.right_by] + + if len(self.left_by) != len(self.right_by): + raise MergeError("left_by and right_by must be the same length") + + left_on = self.left_by + list(left_on) + right_on = self.right_by + list(right_on) + + return left_on, right_on + + def _maybe_require_matching_dtypes( + self, left_join_keys: list[ArrayLike], right_join_keys: list[ArrayLike] + ) -> None: + # TODO: why do we do this for AsOfMerge but not the others? + + def _check_dtype_match(left: ArrayLike, right: ArrayLike, i: int): + if left.dtype != right.dtype: + if isinstance(left.dtype, CategoricalDtype) and isinstance( + right.dtype, CategoricalDtype + ): + # The generic error message is confusing for categoricals. + # + # In this function, the join keys include both the original + # ones of the merge_asof() call, and also the keys passed + # to its by= argument. Unordered but equal categories + # are not supported for the former, but will fail + # later with a ValueError, so we don't *need* to check + # for them here. 
+ msg = ( + f"incompatible merge keys [{i}] {repr(left.dtype)} and " + f"{repr(right.dtype)}, both sides category, but not equal ones" + ) + else: + msg = ( + f"incompatible merge keys [{i}] {repr(left.dtype)} and " + f"{repr(right.dtype)}, must be the same type" + ) + raise MergeError(msg) + + # validate index types are the same + for i, (lk, rk) in enumerate(zip(left_join_keys, right_join_keys)): + _check_dtype_match(lk, rk, i) + + if self.left_index: + lt = self.left.index._values + else: + lt = left_join_keys[-1] + + if self.right_index: + rt = self.right.index._values + else: + rt = right_join_keys[-1] + + _check_dtype_match(lt, rt, 0) + + def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: + # validate tolerance; datetime.timedelta or Timedelta if we have a DTI + if self.tolerance is not None: + if self.left_index: + lt = self.left.index._values + else: + lt = left_join_keys[-1] + + msg = ( + f"incompatible tolerance {self.tolerance}, must be compat " + f"with type {repr(lt.dtype)}" + ) + + if needs_i8_conversion(lt.dtype): + if not isinstance(self.tolerance, datetime.timedelta): + raise MergeError(msg) + if self.tolerance < Timedelta(0): + raise MergeError("tolerance must be positive") + + elif is_integer_dtype(lt.dtype): + if not is_integer(self.tolerance): + raise MergeError(msg) + if self.tolerance < 0: + raise MergeError("tolerance must be positive") + + elif is_float_dtype(lt.dtype): + if not is_number(self.tolerance): + raise MergeError(msg) + # error: Unsupported operand types for > ("int" and "Number") + if self.tolerance < 0: # type: ignore[operator] + raise MergeError("tolerance must be positive") + + else: + raise MergeError("key must be integer, timestamp or float") + + def _convert_values_for_libjoin( + self, values: AnyArrayLike, side: str + ) -> np.ndarray: + # we require sortedness and non-null values in the join keys + if not Index(values).is_monotonic_increasing: + if isna(values).any(): + raise ValueError(f"Merge keys contain null values on {side} side") + raise ValueError(f"{side} keys must be sorted") + + if isinstance(values, ArrowExtensionArray): + values = values._maybe_convert_datelike_array() + + if needs_i8_conversion(values.dtype): + values = values.view("i8") + + elif isinstance(values, BaseMaskedArray): + # we've verified above that no nulls exist + values = values._data + elif isinstance(values, ExtensionArray): + values = values.to_numpy() + + # error: Incompatible return value type (got "Union[ExtensionArray, + # Any, ndarray[Any, Any], ndarray[Any, dtype[Any]], Index, Series]", + # expected "ndarray[Any, Any]") + return values # type: ignore[return-value] + + def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + """return the join indexers""" + + def flip(xs: list[ArrayLike]) -> np.ndarray: + """unlike np.transpose, this returns an array of tuples""" + + def injection(obj: ArrayLike): + if not isinstance(obj.dtype, ExtensionDtype): + # ndarray + return obj + obj = extract_array(obj) + if isinstance(obj, NDArrayBackedExtensionArray): + # fastpath for e.g. dt64tz, categorical + return obj._ndarray + # FIXME: returning obj._values_for_argsort() here doesn't + # break in any existing test cases, but i (@jbrockmendel) + # am pretty sure it should! + # e.g. + # arr = pd.array([0, pd.NA, 255], dtype="UInt8") + # will have values_for_argsort (before GH#45434) + # np.array([0, 255, 255], dtype=np.uint8) + # and the non-injectivity should make a difference somehow + # shouldn't it? 
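+ # generic ExtensionArray fallback: materialize as an ndarray so the
+ # structured-array zip below can build one tuple per row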
+ return np.asarray(obj) + + xs = [injection(x) for x in xs] + labels = list(string.ascii_lowercase[: len(xs)]) + dtypes = [x.dtype for x in xs] + labeled_dtypes = list(zip(labels, dtypes)) + return np.array(list(zip(*xs)), labeled_dtypes) + + # values to compare + left_values = ( + self.left.index._values if self.left_index else self.left_join_keys[-1] + ) + right_values = ( + self.right.index._values if self.right_index else self.right_join_keys[-1] + ) + + # _maybe_require_matching_dtypes already checked for dtype matching + assert left_values.dtype == right_values.dtype + + tolerance = self.tolerance + if tolerance is not None: + # TODO: can we reuse a tolerance-conversion function from + # e.g. TimedeltaIndex? + if needs_i8_conversion(left_values.dtype): + tolerance = Timedelta(tolerance) + # TODO: we have no test cases with PeriodDtype here; probably + # need to adjust tolerance for that case. + if left_values.dtype.kind in "mM": + # Make sure the i8 representation for tolerance + # matches that for left_values/right_values. + lvs = ensure_wrapped_if_datetimelike(left_values) + tolerance = tolerance.as_unit(lvs.unit) + + tolerance = tolerance._value + + # initial type conversion as needed + left_values = self._convert_values_for_libjoin(left_values, "left") + right_values = self._convert_values_for_libjoin(right_values, "right") + + # a "by" parameter requires special handling + if self.left_by is not None: + # remove 'on' parameter from values if one existed + if self.left_index and self.right_index: + left_by_values = self.left_join_keys + right_by_values = self.right_join_keys + else: + left_by_values = self.left_join_keys[0:-1] + right_by_values = self.right_join_keys[0:-1] + + # get tuple representation of values if more than one + if len(left_by_values) == 1: + lbv = left_by_values[0] + rbv = right_by_values[0] + + # TODO: conversions for EAs that can be no-copy. 
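+ # a single 'by' key can be handed to the Cython routine as a plain
+ # ndarray; multiple keys are combined into tuple rows by flip() in
+ # the else branch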
+ lbv = np.asarray(lbv) + rbv = np.asarray(rbv) + else: + # We get here with non-ndarrays in test_merge_by_col_tz_aware + # and test_merge_groupby_multiple_column_with_categorical_column + lbv = flip(left_by_values) + rbv = flip(right_by_values) + lbv = ensure_object(lbv) + rbv = ensure_object(rbv) + + # error: Incompatible types in assignment (expression has type + # "Union[ndarray[Any, dtype[Any]], ndarray[Any, dtype[object_]]]", + # variable has type "List[Union[Union[ExtensionArray, + # ndarray[Any, Any]], Index, Series]]") + right_by_values = rbv # type: ignore[assignment] + # error: Incompatible types in assignment (expression has type + # "Union[ndarray[Any, dtype[Any]], ndarray[Any, dtype[object_]]]", + # variable has type "List[Union[Union[ExtensionArray, + # ndarray[Any, Any]], Index, Series]]") + left_by_values = lbv # type: ignore[assignment] + + # choose appropriate function by type + func = _asof_by_function(self.direction) + return func( + left_values, + right_values, + left_by_values, + right_by_values, + self.allow_exact_matches, + tolerance, + ) + else: + # choose appropriate function by type + func = _asof_by_function(self.direction) + # TODO(cython3): + # Bug in beta1 preventing Cython from choosing + # right specialization when one fused memview is None + # Doesn't matter what type we choose + # (nothing happens anyways since it is None) + # GH 51640 + return func[f"{left_values.dtype}_t", object]( + left_values, + right_values, + None, + None, + self.allow_exact_matches, + tolerance, + False, + ) + + +def _get_multiindex_indexer( + join_keys: list[ArrayLike], index: MultiIndex, sort: bool +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + # left & right join labels and num. of levels at each location + mapped = ( + _factorize_keys(index.levels[n]._values, join_keys[n], sort=sort) + for n in range(index.nlevels) + ) + zipped = zip(*mapped) + rcodes, lcodes, shape = (list(x) for x in zipped) + if sort: + rcodes = list(map(np.take, rcodes, index.codes)) + else: + i8copy = lambda a: a.astype("i8", subok=False, copy=True) + rcodes = list(map(i8copy, index.codes)) + + # fix right labels if there were any nulls + for i, join_key in enumerate(join_keys): + mask = index.codes[i] == -1 + if mask.any(): + # check if there already was any nulls at this location + # if there was, it is factorized to `shape[i] - 1` + a = join_key[lcodes[i] == shape[i] - 1] + if a.size == 0 or not a[0] != a[0]: + shape[i] += 1 + + rcodes[i][mask] = shape[i] - 1 + + # get flat i8 join keys + lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort) + return lkey, rkey + + +def _get_empty_indexer() -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + """Return empty join indexers.""" + return ( + np.array([], dtype=np.intp), + np.array([], dtype=np.intp), + ) + + +def _get_no_sort_one_missing_indexer( + n: int, left_missing: bool +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + """ + Return join indexers where all of one side is selected without sorting + and none of the other side is selected. + + Parameters + ---------- + n : int + Length of indexers to create. + left_missing : bool + If True, the left indexer will contain only -1's. + If False, the right indexer will contain only -1's. 
+ + Returns + ------- + np.ndarray[np.intp] + Left indexer + np.ndarray[np.intp] + Right indexer + """ + idx = np.arange(n, dtype=np.intp) + idx_missing = np.full(shape=n, fill_value=-1, dtype=np.intp) + if left_missing: + return idx_missing, idx + return idx, idx_missing + + +def _left_join_on_index( + left_ax: Index, right_ax: Index, join_keys: list[ArrayLike], sort: bool = False +) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp]]: + if isinstance(right_ax, MultiIndex): + lkey, rkey = _get_multiindex_indexer(join_keys, right_ax, sort=sort) + else: + # error: Incompatible types in assignment (expression has type + # "Union[Union[ExtensionArray, ndarray[Any, Any]], Index, Series]", + # variable has type "ndarray[Any, dtype[signedinteger[Any]]]") + lkey = join_keys[0] # type: ignore[assignment] + # error: Incompatible types in assignment (expression has type "Index", + # variable has type "ndarray[Any, dtype[signedinteger[Any]]]") + rkey = right_ax._values # type: ignore[assignment] + + left_key, right_key, count = _factorize_keys(lkey, rkey, sort=sort) + left_indexer, right_indexer = libjoin.left_outer_join( + left_key, right_key, count, sort=sort + ) + + if sort or len(left_ax) != len(left_indexer): + # if asked to sort or there are 1-to-many matches + join_index = left_ax.take(left_indexer) + return join_index, left_indexer, right_indexer + + # left frame preserves order & length of its index + return left_ax, None, right_indexer + + +def _factorize_keys( + lk: ArrayLike, + rk: ArrayLike, + sort: bool = True, + how: MergeHow | Literal["asof"] = "inner", +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]: + """ + Encode left and right keys as enumerated types. + + This is used to get the join indexers to be used when merging DataFrames. + + Parameters + ---------- + lk : ndarray, ExtensionArray + Left key. + rk : ndarray, ExtensionArray + Right key. + sort : bool, defaults to True + If True, the encoding is done such that the unique elements in the + keys are sorted. + how : {‘left’, ‘right’, ‘outer’, ‘inner’}, default ‘inner’ + Type of merge. + + Returns + ------- + np.ndarray[np.intp] + Left (resp. right if called with `key='right'`) labels, as enumerated type. + np.ndarray[np.intp] + Right (resp. left if called with `key='right'`) labels, as enumerated type. + int + Number of unique elements in union of left and right labels. + + See Also + -------- + merge : Merge DataFrame or named Series objects + with a database-style join. + algorithms.factorize : Encode the object as an enumerated type + or categorical variable. + + Examples + -------- + >>> lk = np.array(["a", "c", "b"]) + >>> rk = np.array(["a", "c"]) + + Here, the unique values are `'a', 'b', 'c'`. With the default + `sort=True`, the encoding will be `{0: 'a', 1: 'b', 2: 'c'}`: + + >>> pd.core.reshape.merge._factorize_keys(lk, rk) + (array([0, 2, 1]), array([0, 2]), 3) + + With the `sort=False`, the encoding will correspond to the order + in which the unique elements first appear: `{0: 'a', 1: 'c', 2: 'b'}`: + + >>> pd.core.reshape.merge._factorize_keys(lk, rk, sort=False) + (array([0, 1, 2]), array([0, 1]), 3) + """ + # TODO: if either is a RangeIndex, we can likely factorize more efficiently? 
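+ # Note (added): the branches below normalize both keys to something the
+ # hashtable factorizer can handle: datetimelikes via their i8 values,
+ # matching categoricals via their codes, other extension arrays via
+ # _values_for_factorize; masked and numeric Arrow arrays are factorized
+ # together with their NA masks.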
+ + if ( + isinstance(lk.dtype, DatetimeTZDtype) and isinstance(rk.dtype, DatetimeTZDtype) + ) or (lib.is_np_dtype(lk.dtype, "M") and lib.is_np_dtype(rk.dtype, "M")): + # Extract the ndarray (UTC-localized) values + # Note: we dont need the dtypes to match, as these can still be compared + lk, rk = cast("DatetimeArray", lk)._ensure_matching_resos(rk) + lk = cast("DatetimeArray", lk)._ndarray + rk = cast("DatetimeArray", rk)._ndarray + + elif ( + isinstance(lk.dtype, CategoricalDtype) + and isinstance(rk.dtype, CategoricalDtype) + and lk.dtype == rk.dtype + ): + assert isinstance(lk, Categorical) + assert isinstance(rk, Categorical) + # Cast rk to encoding so we can compare codes with lk + + rk = lk._encode_with_my_categories(rk) + + lk = ensure_int64(lk.codes) + rk = ensure_int64(rk.codes) + + elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype: + if not isinstance(lk, BaseMaskedArray) and not ( + # exclude arrow dtypes that would get cast to object + isinstance(lk.dtype, ArrowDtype) + and is_numeric_dtype(lk.dtype.numpy_dtype) + ): + lk, _ = lk._values_for_factorize() + + # error: Item "ndarray" of "Union[Any, ndarray]" has no attribute + # "_values_for_factorize" + rk, _ = rk._values_for_factorize() # type: ignore[union-attr] + + if needs_i8_conversion(lk.dtype) and lk.dtype == rk.dtype: + # GH#23917 TODO: Needs tests for non-matching dtypes + # GH#23917 TODO: needs tests for case where lk is integer-dtype + # and rk is datetime-dtype + lk = np.asarray(lk, dtype=np.int64) + rk = np.asarray(rk, dtype=np.int64) + + klass, lk, rk = _convert_arrays_and_get_rizer_klass(lk, rk) + + rizer = klass(max(len(lk), len(rk))) + + if isinstance(lk, BaseMaskedArray): + assert isinstance(rk, BaseMaskedArray) + llab = rizer.factorize(lk._data, mask=lk._mask) + rlab = rizer.factorize(rk._data, mask=rk._mask) + elif isinstance(lk, ArrowExtensionArray): + assert isinstance(rk, ArrowExtensionArray) + # we can only get here with numeric dtypes + # TODO: Remove when we have a Factorizer for Arrow + llab = rizer.factorize( + lk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype), mask=lk.isna() + ) + rlab = rizer.factorize( + rk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype), mask=rk.isna() + ) + else: + # Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type + # "Union[ndarray[Any, dtype[signedinteger[_64Bit]]], + # ndarray[Any, dtype[object_]]]"; expected "ndarray[Any, dtype[object_]]" + llab = rizer.factorize(lk) # type: ignore[arg-type] + rlab = rizer.factorize(rk) # type: ignore[arg-type] + assert llab.dtype == np.dtype(np.intp), llab.dtype + assert rlab.dtype == np.dtype(np.intp), rlab.dtype + + count = rizer.get_count() + + if sort: + uniques = rizer.uniques.to_array() + llab, rlab = _sort_labels(uniques, llab, rlab) + + # NA group + lmask = llab == -1 + lany = lmask.any() + rmask = rlab == -1 + rany = rmask.any() + + if lany or rany: + if lany: + np.putmask(llab, lmask, count) + if rany: + np.putmask(rlab, rmask, count) + count += 1 + + if how == "right": + return rlab, llab, count + return llab, rlab, count + + +def _convert_arrays_and_get_rizer_klass( + lk: ArrayLike, rk: ArrayLike +) -> tuple[type[libhashtable.Factorizer], ArrayLike, ArrayLike]: + klass: type[libhashtable.Factorizer] + if is_numeric_dtype(lk.dtype): + if lk.dtype != rk.dtype: + dtype = find_common_type([lk.dtype, rk.dtype]) + if isinstance(dtype, ExtensionDtype): + cls = dtype.construct_array_type() + if not isinstance(lk, ExtensionArray): + lk = cls._from_sequence(lk, dtype=dtype, copy=False) + else: + lk = 
lk.astype(dtype) + + if not isinstance(rk, ExtensionArray): + rk = cls._from_sequence(rk, dtype=dtype, copy=False) + else: + rk = rk.astype(dtype) + else: + lk = lk.astype(dtype) + rk = rk.astype(dtype) + if isinstance(lk, BaseMaskedArray): + # Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]"; + # expected type "Type[object]" + klass = _factorizers[lk.dtype.type] # type: ignore[index] + elif isinstance(lk.dtype, ArrowDtype): + klass = _factorizers[lk.dtype.numpy_dtype.type] + else: + klass = _factorizers[lk.dtype.type] + + else: + klass = libhashtable.ObjectFactorizer + lk = ensure_object(lk) + rk = ensure_object(rk) + return klass, lk, rk + + +def _sort_labels( + uniques: np.ndarray, left: npt.NDArray[np.intp], right: npt.NDArray[np.intp] +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + llength = len(left) + labels = np.concatenate([left, right]) + + _, new_labels = algos.safe_sort(uniques, labels, use_na_sentinel=True) + new_left, new_right = new_labels[:llength], new_labels[llength:] + + return new_left, new_right + + +def _get_join_keys( + llab: list[npt.NDArray[np.int64 | np.intp]], + rlab: list[npt.NDArray[np.int64 | np.intp]], + shape: Shape, + sort: bool, +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: + # how many levels can be done without overflow + nlev = next( + lev + for lev in range(len(shape), 0, -1) + if not is_int64_overflow_possible(shape[:lev]) + ) + + # get keys for the first `nlev` levels + stride = np.prod(shape[1:nlev], dtype="i8") + lkey = stride * llab[0].astype("i8", subok=False, copy=False) + rkey = stride * rlab[0].astype("i8", subok=False, copy=False) + + for i in range(1, nlev): + with np.errstate(divide="ignore"): + stride //= shape[i] + lkey += llab[i] * stride + rkey += rlab[i] * stride + + if nlev == len(shape): # all done! + return lkey, rkey + + # densify current keys to avoid overflow + lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort) + + llab = [lkey] + llab[nlev:] + rlab = [rkey] + rlab[nlev:] + shape = (count,) + shape[nlev:] + + return _get_join_keys(llab, rlab, shape, sort) + + +def _should_fill(lname, rname) -> bool: + if not isinstance(lname, str) or not isinstance(rname, str): + return True + return lname == rname + + +def _any(x) -> bool: + return x is not None and com.any_not_none(*x) + + +def _validate_operand(obj: DataFrame | Series) -> DataFrame: + if isinstance(obj, ABCDataFrame): + return obj + elif isinstance(obj, ABCSeries): + if obj.name is None: + raise ValueError("Cannot merge a Series without a name") + return obj.to_frame() + else: + raise TypeError( + f"Can only merge Series or DataFrame objects, a {type(obj)} was passed" + ) + + +def _items_overlap_with_suffix( + left: Index, right: Index, suffixes: Suffixes +) -> tuple[Index, Index]: + """ + Suffixes type validation. + + If two indices overlap, add suffixes to overlapping entries. + + If corresponding suffix is empty, the entry is simply converted to string. + + """ + if not is_list_like(suffixes, allow_sets=False) or isinstance(suffixes, dict): + raise TypeError( + f"Passing 'suffixes' as a {type(suffixes)}, is not supported. " + "Provide 'suffixes' as a tuple instead." + ) + + to_rename = left.intersection(right) + if len(to_rename) == 0: + return left, right + + lsuffix, rsuffix = suffixes + + if not lsuffix and not rsuffix: + raise ValueError(f"columns overlap but no suffix specified: {to_rename}") + + def renamer(x, suffix: str | None): + """ + Rename the left and right indices. 
+ + If there is overlap, and suffix is not None, add + suffix, otherwise, leave it as-is. + + Parameters + ---------- + x : original column name + suffix : str or None + + Returns + ------- + x : renamed column name + """ + if x in to_rename and suffix is not None: + return f"{x}{suffix}" + return x + + lrenamer = partial(renamer, suffix=lsuffix) + rrenamer = partial(renamer, suffix=rsuffix) + + llabels = left._transform_index(lrenamer) + rlabels = right._transform_index(rrenamer) + + dups = [] + if not llabels.is_unique: + # Only warn when duplicates are caused because of suffixes, already duplicated + # columns in origin should not warn + dups = llabels[(llabels.duplicated()) & (~left.duplicated())].tolist() + if not rlabels.is_unique: + dups.extend(rlabels[(rlabels.duplicated()) & (~right.duplicated())].tolist()) + if dups: + raise MergeError( + f"Passing 'suffixes' which cause duplicate columns {set(dups)} is " + f"not allowed.", + ) + + return llabels, rlabels diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index 61112542fb9d8..abe25806d1a71 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -1,7 +1,9 @@ """ parquet compat """ from __future__ import annotations +import ast import io +import json import os from typing import ( TYPE_CHECKING, @@ -184,6 +186,11 @@ def write( table = self.api.Table.from_pandas(df, **from_pandas_kwargs) + df_metadata = {"df.attrs": json.dumps(df.attrs)} + existing_metadata = table.schema.metadata + merged_metadata = {**existing_metadata, **df_metadata} + table = table.replace_schema_metadata(merged_metadata) + path_or_handle, handles, filesystem = _get_path_or_handle( path, filesystem, @@ -263,6 +270,12 @@ def read( if manager == "array": result = result._as_manager("array", copy=False) + + if pa_table.schema.metadata and b"df.attrs" in pa_table.schema.metadata: + result.attrs = ast.literal_eval( + pa_table.schema.metadata[b"df.attrs"].decode("utf-8") + ) + return result finally: if handles is not None: diff --git a/pandas/tests/indexes/numeric/test_setops 2.py b/pandas/tests/indexes/numeric/test_setops 2.py new file mode 100644 index 0000000000000..2276b10db1fe3 --- /dev/null +++ b/pandas/tests/indexes/numeric/test_setops 2.py @@ -0,0 +1,154 @@ +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +import pandas._testing as tm +from pandas.core.indexes.api import ( + Index, + RangeIndex, +) + + +@pytest.fixture +def index_large(): + # large values used in TestUInt64Index where no compat needed with int64/float64 + large = [2**63, 2**63 + 10, 2**63 + 15, 2**63 + 20, 2**63 + 25] + return Index(large, dtype=np.uint64) + + +class TestSetOps: + @pytest.mark.parametrize("dtype", ["f8", "u8", "i8"]) + def test_union_non_numeric(self, dtype): + # corner case, non-numeric + index = Index(np.arange(5, dtype=dtype), dtype=dtype) + assert index.dtype == dtype + + other = Index([datetime.now() + timedelta(i) for i in range(4)], dtype=object) + result = index.union(other) + expected = Index(np.concatenate((index, other))) + tm.assert_index_equal(result, expected) + + result = other.union(index) + expected = Index(np.concatenate((other, index))) + tm.assert_index_equal(result, expected) + + def test_intersection(self): + index = Index(range(5), dtype=np.int64) + + other = Index([1, 2, 3, 4, 5]) + result = index.intersection(other) + expected = Index(np.sort(np.intersect1d(index.values, other.values))) + tm.assert_index_equal(result, expected) + + result = other.intersection(index) + expected = Index( + np.sort(np.asarray(np.intersect1d(index.values, other.values))) + ) + 
tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "uint64"]) + def test_int_float_union_dtype(self, dtype): + # https://github.com/pandas-dev/pandas/issues/26778 + # [u]int | float -> float + index = Index([0, 2, 3], dtype=dtype) + other = Index([0.5, 1.5], dtype=np.float64) + expected = Index([0.0, 0.5, 1.5, 2.0, 3.0], dtype=np.float64) + result = index.union(other) + tm.assert_index_equal(result, expected) + + result = other.union(index) + tm.assert_index_equal(result, expected) + + def test_range_float_union_dtype(self): + # https://github.com/pandas-dev/pandas/issues/26778 + index = RangeIndex(start=0, stop=3) + other = Index([0.5, 1.5], dtype=np.float64) + result = index.union(other) + expected = Index([0.0, 0.5, 1, 1.5, 2.0], dtype=np.float64) + tm.assert_index_equal(result, expected) + + result = other.union(index) + tm.assert_index_equal(result, expected) + + def test_float64_index_difference(self): + # https://github.com/pandas-dev/pandas/issues/35217 + float_index = Index([1.0, 2, 3]) + string_index = Index(["1", "2", "3"]) + + result = float_index.difference(string_index) + tm.assert_index_equal(result, float_index) + + result = string_index.difference(float_index) + tm.assert_index_equal(result, string_index) + + def test_intersection_uint64_outside_int64_range(self, index_large): + other = Index([2**63, 2**63 + 5, 2**63 + 10, 2**63 + 15, 2**63 + 20]) + result = index_large.intersection(other) + expected = Index(np.sort(np.intersect1d(index_large.values, other.values))) + tm.assert_index_equal(result, expected) + + result = other.intersection(index_large) + expected = Index( + np.sort(np.asarray(np.intersect1d(index_large.values, other.values))) + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "index2,keeps_name", + [ + (Index([4, 7, 6, 5, 3], name="index"), True), + (Index([4, 7, 6, 5, 3], name="other"), False), + ], + ) + def test_intersection_monotonic(self, index2, keeps_name, sort): + index1 = Index([5, 3, 2, 4, 1], name="index") + expected = Index([5, 3, 4]) + + if keeps_name: + expected.name = "index" + + result = index1.intersection(index2, sort=sort) + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + + def test_symmetric_difference(self, sort): + # smoke + index1 = Index([5, 2, 3, 4], name="index1") + index2 = Index([2, 3, 4, 1]) + result = index1.symmetric_difference(index2, sort=sort) + expected = Index([5, 1]) + assert tm.equalContents(result, expected) + assert result.name is None + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + + +class TestSetOpsSort: + @pytest.mark.parametrize("slice_", [slice(None), slice(0)]) + def test_union_sort_other_special(self, slice_): + # https://github.com/pandas-dev/pandas/issues/24959 + + idx = Index([1, 0, 2]) + # default, sort=None + other = idx[slice_] + tm.assert_index_equal(idx.union(other), idx) + tm.assert_index_equal(other.union(idx), idx) + + # sort=False + tm.assert_index_equal(idx.union(other, sort=False), idx) + + @pytest.mark.parametrize("slice_", [slice(None), slice(0)]) + def test_union_sort_special_true(self, slice_): + idx = Index([1, 0, 2]) + # default, sort=None + other = idx[slice_] + + result = idx.union(other, sort=True) + expected = Index([0, 1, 2]) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/indexes/ranges/test_indexing 2.py b/pandas/tests/indexes/ranges/test_indexing 2.py new file mode 100644 index 
0000000000000..84a78ad86c3d3 --- /dev/null +++ b/pandas/tests/indexes/ranges/test_indexing 2.py @@ -0,0 +1,95 @@ +import numpy as np +import pytest + +from pandas import ( + Index, + RangeIndex, +) +import pandas._testing as tm + + +class TestGetIndexer: + def test_get_indexer(self): + index = RangeIndex(start=0, stop=20, step=2) + target = RangeIndex(10) + indexer = index.get_indexer(target) + expected = np.array([0, -1, 1, -1, 2, -1, 3, -1, 4, -1], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected) + + def test_get_indexer_pad(self): + index = RangeIndex(start=0, stop=20, step=2) + target = RangeIndex(10) + indexer = index.get_indexer(target, method="pad") + expected = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected) + + def test_get_indexer_backfill(self): + index = RangeIndex(start=0, stop=20, step=2) + target = RangeIndex(10) + indexer = index.get_indexer(target, method="backfill") + expected = np.array([0, 1, 1, 2, 2, 3, 3, 4, 4, 5], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected) + + def test_get_indexer_limit(self): + # GH#28631 + idx = RangeIndex(4) + target = RangeIndex(6) + result = idx.get_indexer(target, method="pad", limit=1) + expected = np.array([0, 1, 2, 3, 3, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("stop", [0, -1, -2]) + def test_get_indexer_decreasing(self, stop): + # GH#28678 + index = RangeIndex(7, stop, -3) + result = index.get_indexer(range(9)) + expected = np.array([-1, 2, -1, -1, 1, -1, -1, 0, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +class TestTake: + def test_take_preserve_name(self): + index = RangeIndex(1, 5, name="foo") + taken = index.take([3, 0, 1]) + assert index.name == taken.name + + def test_take_fill_value(self): + # GH#12631 + idx = RangeIndex(1, 4, name="xxx") + result = idx.take(np.array([1, 0, -1])) + expected = Index([2, 1, 3], dtype=np.int64, name="xxx") + tm.assert_index_equal(result, expected) + + # fill_value + msg = "Unable to fill values because RangeIndex cannot contain NA" + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -1]), fill_value=True) + + # allow_fill=False + result = idx.take(np.array([1, 0, -1]), allow_fill=False, fill_value=True) + expected = Index([2, 1, 3], dtype=np.int64, name="xxx") + tm.assert_index_equal(result, expected) + + msg = "Unable to fill values because RangeIndex cannot contain NA" + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -2]), fill_value=True) + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -5]), fill_value=True) + + msg = "index -5 is out of bounds for (axis 0 with )?size 3" + with pytest.raises(IndexError, match=msg): + idx.take(np.array([1, -5])) + + +class TestWhere: + def test_where_putmask_range_cast(self): + # GH#43240 + idx = RangeIndex(0, 5, name="test") + + mask = np.array([True, True, False, False, False]) + result = idx.putmask(mask, 10) + expected = Index([10, 10, 2, 3, 4], dtype=np.int64, name="test") + tm.assert_index_equal(result, expected) + + result = idx.where(~mask, 10) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py index 0d8afbf220b0c..4dc96520b2dbc 100644 --- a/pandas/tests/io/test_parquet.py +++ b/pandas/tests/io/test_parquet.py @@ -1192,6 +1192,14 @@ def test_partition_on_supported(self, tmp_path, fp, df_full): actual_partition_cols = 
fastparquet.ParquetFile(str(tmp_path), False).cats assert len(actual_partition_cols) == 2 + def test_df_attrs_persistence(self, tmp_path): + path = tmp_path / "test_df_metadata.p" + df = pd.DataFrame(data={1: [1]}) + df.attrs = {"Test attribute": 1} + df.to_parquet(path) + new_df = read_parquet(path) + assert new_df.attrs == df.attrs + def test_error_on_using_partition_cols_and_partition_on( self, tmp_path, fp, df_full ): diff --git a/pandas/tests/tseries/offsets/test_year 2.py b/pandas/tests/tseries/offsets/test_year 2.py new file mode 100644 index 0000000000000..480c875c36e04 --- /dev/null +++ b/pandas/tests/tseries/offsets/test_year 2.py @@ -0,0 +1,334 @@ +""" +Tests for the following offsets: +- YearBegin +- YearEnd +""" +from __future__ import annotations + +from datetime import datetime + +import numpy as np +import pytest + +from pandas.compat import is_numpy_dev + +from pandas import Timestamp +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + YearBegin, + YearEnd, +) + + +class TestYearBegin: + def test_misspecified(self): + with pytest.raises(ValueError, match="Month must go from 1 to 12"): + YearBegin(month=13) + + offset_cases = [] + offset_cases.append( + ( + YearBegin(), + { + datetime(2008, 1, 1): datetime(2009, 1, 1), + datetime(2008, 6, 30): datetime(2009, 1, 1), + datetime(2008, 12, 31): datetime(2009, 1, 1), + datetime(2005, 12, 30): datetime(2006, 1, 1), + datetime(2005, 12, 31): datetime(2006, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 6, 30): datetime(2009, 1, 1), + datetime(2008, 12, 31): datetime(2009, 1, 1), + datetime(2005, 12, 30): datetime(2006, 1, 1), + datetime(2005, 12, 31): datetime(2006, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(3), + { + datetime(2008, 1, 1): datetime(2011, 1, 1), + datetime(2008, 6, 30): datetime(2011, 1, 1), + datetime(2008, 12, 31): datetime(2011, 1, 1), + datetime(2005, 12, 30): datetime(2008, 1, 1), + datetime(2005, 12, 31): datetime(2008, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-1), + { + datetime(2007, 1, 1): datetime(2006, 1, 1), + datetime(2007, 1, 15): datetime(2007, 1, 1), + datetime(2008, 6, 30): datetime(2008, 1, 1), + datetime(2008, 12, 31): datetime(2008, 1, 1), + datetime(2006, 12, 29): datetime(2006, 1, 1), + datetime(2006, 12, 30): datetime(2006, 1, 1), + datetime(2007, 1, 1): datetime(2006, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-2), + { + datetime(2007, 1, 1): datetime(2005, 1, 1), + datetime(2008, 6, 30): datetime(2007, 1, 1), + datetime(2008, 12, 31): datetime(2007, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(month=4), + { + datetime(2007, 4, 1): datetime(2008, 4, 1), + datetime(2007, 4, 15): datetime(2008, 4, 1), + datetime(2007, 3, 1): datetime(2007, 4, 1), + datetime(2007, 12, 15): datetime(2008, 4, 1), + datetime(2012, 1, 31): datetime(2012, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(0, month=4), + { + datetime(2007, 4, 1): datetime(2007, 4, 1), + datetime(2007, 3, 1): datetime(2007, 4, 1), + datetime(2007, 12, 15): datetime(2008, 4, 1), + datetime(2012, 1, 31): datetime(2012, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(4, month=4), + { + datetime(2007, 4, 1): datetime(2011, 4, 1), + datetime(2007, 4, 15): datetime(2011, 4, 1), + datetime(2007, 3, 1): datetime(2010, 4, 1), + datetime(2007, 12, 15): datetime(2011, 4, 1), + 
datetime(2012, 1, 31): datetime(2015, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-1, month=4), + { + datetime(2007, 4, 1): datetime(2006, 4, 1), + datetime(2007, 3, 1): datetime(2006, 4, 1), + datetime(2007, 12, 15): datetime(2007, 4, 1), + datetime(2012, 1, 31): datetime(2011, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-3, month=4), + { + datetime(2007, 4, 1): datetime(2004, 4, 1), + datetime(2007, 3, 1): datetime(2004, 4, 1), + datetime(2007, 12, 15): datetime(2005, 4, 1), + datetime(2012, 1, 31): datetime(2009, 4, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (YearBegin(), datetime(2007, 1, 3), False), + (YearBegin(), datetime(2008, 1, 1), True), + (YearBegin(), datetime(2006, 12, 31), False), + (YearBegin(), datetime(2006, 1, 2), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +class TestYearEnd: + def test_misspecified(self): + with pytest.raises(ValueError, match="Month must go from 1 to 12"): + YearEnd(month=13) + + offset_cases = [] + offset_cases.append( + ( + YearEnd(), + { + datetime(2008, 1, 1): datetime(2008, 12, 31), + datetime(2008, 6, 30): datetime(2008, 12, 31), + datetime(2008, 12, 31): datetime(2009, 12, 31), + datetime(2005, 12, 30): datetime(2005, 12, 31), + datetime(2005, 12, 31): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 12, 31), + datetime(2008, 6, 30): datetime(2008, 12, 31), + datetime(2008, 12, 31): datetime(2008, 12, 31), + datetime(2005, 12, 30): datetime(2005, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 31), + datetime(2008, 6, 30): datetime(2007, 12, 31), + datetime(2008, 12, 31): datetime(2007, 12, 31), + datetime(2006, 12, 29): datetime(2005, 12, 31), + datetime(2006, 12, 30): datetime(2005, 12, 31), + datetime(2007, 1, 1): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-2), + { + datetime(2007, 1, 1): datetime(2005, 12, 31), + datetime(2008, 6, 30): datetime(2006, 12, 31), + datetime(2008, 12, 31): datetime(2006, 12, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (YearEnd(), datetime(2007, 12, 31), True), + (YearEnd(), datetime(2008, 1, 1), False), + (YearEnd(), datetime(2006, 12, 31), True), + (YearEnd(), datetime(2006, 12, 29), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +class TestYearEndDiffMonth: + offset_cases = [] + offset_cases.append( + ( + YearEnd(month=3), + { + datetime(2008, 1, 1): datetime(2008, 3, 31), + datetime(2008, 2, 15): datetime(2008, 3, 31), + datetime(2008, 3, 31): datetime(2009, 3, 31), + datetime(2008, 3, 30): datetime(2008, 3, 31), + datetime(2005, 3, 31): datetime(2006, 3, 31), + datetime(2006, 7, 30): datetime(2007, 3, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(0, month=3), + { + datetime(2008, 1, 1): datetime(2008, 3, 31), + datetime(2008, 2, 28): datetime(2008, 3, 31), 
+ datetime(2008, 3, 31): datetime(2008, 3, 31), + datetime(2005, 3, 30): datetime(2005, 3, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-1, month=3), + { + datetime(2007, 1, 1): datetime(2006, 3, 31), + datetime(2008, 2, 28): datetime(2007, 3, 31), + datetime(2008, 3, 31): datetime(2007, 3, 31), + datetime(2006, 3, 29): datetime(2005, 3, 31), + datetime(2006, 3, 30): datetime(2005, 3, 31), + datetime(2007, 3, 1): datetime(2006, 3, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-2, month=3), + { + datetime(2007, 1, 1): datetime(2005, 3, 31), + datetime(2008, 6, 30): datetime(2007, 3, 31), + datetime(2008, 3, 31): datetime(2006, 3, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (YearEnd(month=3), datetime(2007, 3, 31), True), + (YearEnd(month=3), datetime(2008, 1, 1), False), + (YearEnd(month=3), datetime(2006, 3, 31), True), + (YearEnd(month=3), datetime(2006, 3, 29), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +@pytest.mark.xfail(is_numpy_dev, reason="result year is 1973, unclear why") +def test_add_out_of_pydatetime_range(): + # GH#50348 don't raise in Timestamp.replace + ts = Timestamp(np.datetime64("-20000-12-31")) + off = YearEnd() + + result = ts + off + expected = Timestamp(np.datetime64("-19999-12-31")) + assert result == expected
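The offset cases above encode the core YearBegin/YearEnd semantics: a positive n rolls forward to the next anchor date, n=0 snaps forward only when the date is not already on the anchor, a negative n rolls backward, and month= moves the anchor. A minimal sketch reproducing a few of the mappings from the tables above, assuming only the public pandas API:

from datetime import datetime

from pandas.tseries.offsets import YearBegin, YearEnd

# YearEnd() rolls forward to the next December 31; if already on it, it
# advances a full year (matches the YearEnd() cases above).
assert datetime(2008, 6, 30) + YearEnd() == datetime(2008, 12, 31)
assert datetime(2008, 12, 31) + YearEnd() == datetime(2009, 12, 31)

# n=0 snaps forward only when the date is not already an anchor.
assert datetime(2008, 12, 31) + YearEnd(0) == datetime(2008, 12, 31)

# Negative n rolls backward; month= shifts the anchor date.
assert datetime(2008, 6, 30) + YearEnd(-1) == datetime(2007, 12, 31)
assert datetime(2007, 3, 1) + YearBegin(month=4) == datetime(2007, 4, 1)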
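For context on the pandas/io/parquet.py hunks earlier in this diff: pyarrow keeps arbitrary key/value metadata on the Table schema, and the new write/read code piggybacks df.attrs onto it as a JSON string under the "df.attrs" key. A rough sketch of the same mechanism using pyarrow directly; the file name and attrs values are illustrative only, and json.loads is used here for symmetry with json.dumps rather than mirroring the diff exactly:

import json

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

df = pd.DataFrame({"a": [1, 2]})
df.attrs = {"source": "example"}  # illustrative attrs, not from the diff

table = pa.Table.from_pandas(df)
# from_pandas already stores a b"pandas" metadata entry; merge, don't replace.
merged = {**(table.schema.metadata or {}), "df.attrs": json.dumps(df.attrs)}
pq.write_table(table.replace_schema_metadata(merged), "attrs_roundtrip.parquet")

restored = pq.read_table("attrs_roundtrip.parquet")
metadata = restored.schema.metadata or {}
# Guard the lookup: files written without this key simply skip the restore.
if b"df.attrs" in metadata:
    assert json.loads(metadata[b"df.attrs"]) == df.attrs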