Skip to content

Commit

Permalink
add possibility to use std::stable_sort with xt::argsort (#2681)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibHlln authored Sep 4, 2023
1 parent a948772 commit 4c1c290
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 21 deletions.
94 changes: 73 additions & 21 deletions include/xtensor/xsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,30 +266,76 @@ namespace xt
);
}

/*****************************
* Implementation of argsort *
*****************************/

/**
* Sorting method.
* Predefined methods for performing indirect sorting.
* @see argsort(const xexpression<E>&, std::ptrdiff_t, sorting_method)
*/
enum class sorting_method
{
/**
* Faster method but with no guarantee on preservation of order of equal elements
* https://en.cppreference.com/w/cpp/algorithm/sort.
*/
quick,
/**
* Slower method but with guarantee on preservation of order of equal elements
* https://en.cppreference.com/w/cpp/algorithm/stable_sort.
*/
stable,
};

namespace detail
{
template <class ConstRandomIt, class RandomIt, class Compare>
inline void
argsort_iter(ConstRandomIt data_begin, ConstRandomIt data_end, RandomIt idx_begin, RandomIt idx_end, Compare comp)
template <class ConstRandomIt, class RandomIt, class Compare, class Method>
inline void argsort_iter(
ConstRandomIt data_begin,
ConstRandomIt data_end,
RandomIt idx_begin,
RandomIt idx_end,
Compare comp,
Method method
)
{
XTENSOR_ASSERT(std::distance(data_begin, data_end) >= 0);
XTENSOR_ASSERT(std::distance(idx_begin, idx_end) == std::distance(data_begin, data_end));
(void) idx_end; // TODO(C++17) [[maybe_unused]] only used in assertion.

std::iota(idx_begin, idx_end, 0);
std::sort(
idx_begin,
idx_end,
[&](const auto i, const auto j)
switch (method)
{
case (sorting_method::quick):
{
return comp(*(data_begin + i), *(data_begin + j));
std::sort(
idx_begin,
idx_end,
[&](const auto i, const auto j)
{
return comp(*(data_begin + i), *(data_begin + j));
}
);
}
);
case (sorting_method::stable):
{
std::stable_sort(
idx_begin,
idx_end,
[&](const auto i, const auto j)
{
return comp(*(data_begin + i), *(data_begin + j));
}
);
}
}
}

template <class ConstRandomIt, class RandomIt>
template <class ConstRandomIt, class RandomIt, class Method>
inline void
argsort_iter(ConstRandomIt data_begin, ConstRandomIt data_end, RandomIt idx_begin, RandomIt idx_end)
argsort_iter(ConstRandomIt data_begin, ConstRandomIt data_end, RandomIt idx_begin, RandomIt idx_end, Method method)
{
return argsort_iter(
std::move(data_begin),
Expand All @@ -299,7 +345,8 @@ namespace xt
[](const auto& x, const auto& y) -> bool
{
return x < y;
}
},
method
);
}

Expand Down Expand Up @@ -359,8 +406,8 @@ namespace xt
typename T::temporary_type>::type;
};

template <class E, class R = typename detail::linear_argsort_result_type<E>::type>
inline auto flatten_argsort_impl(const xexpression<E>& e)
template <class E, class R = typename detail::linear_argsort_result_type<E>::type, class Method>
inline auto flatten_argsort_impl(const xexpression<E>& e, Method method)
{
const auto& de = e.derived_cast();

Expand All @@ -372,16 +419,17 @@ namespace xt
result_type result;
result.resize({de.size()});

detail::argsort_iter(de.cbegin(), de.cend(), result.begin(), result.end());
detail::argsort_iter(de.cbegin(), de.cend(), result.begin(), result.end(), method);

return result;
}
}

template <class E>
inline auto argsort(const xexpression<E>& e, placeholders::xtuph /*t*/)
inline auto
argsort(const xexpression<E>& e, placeholders::xtuph /*t*/, sorting_method method = sorting_method::quick)
{
return detail::flatten_argsort_impl(e);
return detail::flatten_argsort_impl(e, method);
}

/**
Expand All @@ -393,11 +441,15 @@ namespace xt
* @ingroup xt_xsort
* @param e xexpression to argsort
* @param axis axis along which argsort is performed
* @param method sorting algorithm to use
*
* @return argsorted index array
*
* @see xt::sorting_method
*/
template <class E>
inline auto argsort(const xexpression<E>& e, std::ptrdiff_t axis = -1)
inline auto
argsort(const xexpression<E>& e, std::ptrdiff_t axis = -1, sorting_method method = sorting_method::quick)
{
using eval_type = typename detail::sort_eval_type<E>::type;
using result_type = typename detail::argsort_result_type<eval_type>::type;
Expand All @@ -408,12 +460,12 @@ namespace xt

if (de.dimension() == 1)
{
return detail::flatten_argsort_impl<E, result_type>(e);
return detail::flatten_argsort_impl<E, result_type>(e, method);
}

const auto argsort = [](auto res_begin, auto res_end, auto ev_begin, auto ev_end)
const auto argsort = [&method](auto res_begin, auto res_end, auto ev_begin, auto ev_end)
{
detail::argsort_iter(ev_begin, ev_end, res_begin, res_end);
detail::argsort_iter(ev_begin, ev_end, res_begin, res_end, method);
};

if (ax == detail::leading_axis(de))
Expand Down
46 changes: 46 additions & 0 deletions test/test_xsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,52 @@ namespace xt
}
}

TEST(xsort, argsort_stable)
{
xt::xarray<int> a = {
2247, 2044, 2037, 1825, 1699, 1600, 1501, 1432, 1440, 1388, 1299, 1259, 1211, 1177, 1124, 1121,
1399, 1179, 1102, 1055, 1017, 1001, 979, 953, 925, 927, 1899, 4782, 2601, 3050, 1998, 3478,
5762, 8745, 7777, 4086, 2968, 2456, 2138, 3199, 7187, 7165, 3942, 11588, 4643, 3618, 3184, 6052,
3723, 3356, 2645, 2330, 2103, 1928, 1890, 3624, 3647, 5821, 2949, 4161, 3855, 5200, 3162, 3896,
19818, 5228, 3711, 4874, 4868, 4267, 2978, 2748, 2500, 2276, 2107, 1969, 1815, 1714, 1649, 1561,
1491, 1428, 1347, 1288, 1245, 1207, 1163, 1189, 1558, 1313, 1186, 1147, 1143, 1099, 1036, 1008,
982, 953, 924, 899, 880, 858, 836, 818, 791, 762, 742, 720, 701, 685, 674, 658,
643, 625, 609, 602, 594, 594, 625, 616, 707, 759, 929, 755, 684, 737, 804, 680,
639, 620, 706, 2474, 1406, 2686, 2037, 1410, 1130, 992, 901, 847, 798, 766, 750, 2793,
1574, 1087, 960, 883, 823, 782, 755, 730, 705, 677, 744, 932, 770, 794, 804, 728,
721, 765, 826, 795, 784, 875, 711, 795, 2295, 1265, 1115, 936, 817, 742, 713, 680,
647, 647, 671, 651, 692, 605, 576, 612, 564, 536, 526, 516, 506, 494, 473, 459,
451, 450, 441, 428, 430, 453, 467, 444, 413, 390, 376, 373, 368, 360, 359, 351,
348, 345, 343, 340, 337, 341, 529, 476, 448, 397, 367, 354, 342, 332, 320, 324,
327, 332, 340, 359, 331, 382, 371, 356, 333, 318, 316, 318, 312, 308, 300, 298,
399, 448, 669, 534, 402, 384, 361, 347, 337, 342, 349, 339, 333, 325, 324, 326,
324, 322, 320, 318, 322, 327, 327, 333, 329, 333, 353, 436, 516, 412, 402, 398,
551, 449, 416, 484, 801, 537, 457, 483, 534, 477, 483, 582, 575, 557, 504, 477,
515, 643, 627, 627, 546, 515, 497, 577, 1439, 1198, 1591, 1815, 1448, 1117, 960, 847,
809, 788, 749, 896, 870, 801, 752};

xt::xarray<std::size_t> ex = {
239, 238, 237, 236, 234, 233, 235, 259, 222, 258, 257, 260, 223, 254, 256, 253, 255, 224, 261,
262, 264, 228, 221, 225, 232, 252, 263, 265, 212, 248, 251, 211, 226, 213, 220, 249, 210, 209,
247, 208, 250, 207, 266, 219, 231, 206, 227, 205, 246, 218, 204, 230, 203, 202, 229, 245, 201,
217, 271, 240, 244, 270, 269, 200, 274, 195, 196, 267, 194, 199, 216, 241, 273, 193, 192, 197,
278, 191, 198, 190, 215, 281, 287, 279, 282, 275, 189, 294, 286, 188, 288, 293, 187, 268, 186,
214, 243, 280, 185, 277, 292, 272, 285, 184, 284, 182, 295, 283, 116, 117, 115, 181, 114, 183,
119, 129, 113, 118, 290, 291, 128, 112, 289, 176, 177, 179, 111, 242, 178, 110, 153, 127, 175,
124, 109, 180, 108, 152, 130, 120, 166, 174, 107, 160, 159, 151, 125, 106, 173, 154, 306, 142,
310, 123, 150, 121, 105, 161, 141, 156, 149, 164, 305, 104, 157, 163, 167, 140, 276, 309, 126,
158, 304, 172, 103, 148, 162, 102, 139, 303, 101, 308, 165, 100, 147, 307, 99, 138, 98, 24,
25, 122, 155, 171, 23, 97, 146, 302, 22, 96, 137, 21, 95, 20, 94, 19, 145, 93, 18,
170, 301, 15, 14, 136, 92, 91, 86, 13, 17, 90, 87, 297, 85, 12, 84, 11, 169, 83,
10, 89, 82, 9, 16, 132, 135, 81, 7, 296, 8, 300, 80, 6, 88, 79, 144, 298, 5,
78, 4, 77, 76, 299, 3, 54, 26, 53, 75, 30, 2, 134, 1, 52, 74, 38, 0, 73,
168, 51, 37, 131, 72, 28, 50, 133, 71, 143, 58, 36, 70, 29, 62, 46, 39, 49, 31,
45, 55, 56, 66, 48, 60, 63, 42, 35, 59, 69, 44, 27, 68, 67, 61, 65, 32, 57,
47, 41, 40, 34, 33, 43, 64};

EXPECT_EQ(ex, xt::argsort(a, {0}, xt::sorting_method::stable));
}

TEST(xsort, sort_easy)
{
xarray<double> a = {{5, 3, 1}, {4, 4, 4}};
Expand Down

0 comments on commit 4c1c290

Please sign in to comment.