diff --git a/include/xtensor/xbuilder.hpp b/include/xtensor/xbuilder.hpp index 5f21cd057..d0fcad32d 100644 --- a/include/xtensor/xbuilder.hpp +++ b/include/xtensor/xbuilder.hpp @@ -494,22 +494,47 @@ namespace xt using size_type = std::size_t; using value_type = xtl::promote_type_t::value_type...>; - template - inline value_type access(const tuple_type& t, size_type axis, S index) const + template + inline value_type access(const tuple_type& t, size_type axis, It first, It last) const { - auto match = [&index, axis](auto& arr) + // trim off extra indices if provided to match behavior of containers + auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension(); + size_t axis_dim = *(first + axis + dim_offset); + auto match = [&](auto& arr) { - if (index[axis] >= arr.shape()[axis]) + if (axis_dim >= arr.shape()[axis]) { - index[axis] -= arr.shape()[axis]; + axis_dim -= arr.shape()[axis]; return false; } return true; }; - auto get = [&index](auto& arr) + auto get = [&](auto& arr) { - return arr[index]; + size_t offset = 0; + const size_t end = arr.dimension(); + for (size_t i = 0; i < end; i++) + { + const auto& shape = arr.shape(); + const size_t stride = std::accumulate( + shape.begin() + i + 1, + shape.end(), + 1, + std::multiplies() + ); + if (i == axis) + { + offset += axis_dim * stride; + } + else + { + const auto len = (*(first + i + dim_offset)); + offset += len * stride; + } + } + const auto element = arr.begin() + offset; + return *element; }; size_type i = 0; @@ -533,22 +558,40 @@ namespace xt using size_type = std::size_t; using value_type = xtl::promote_type_t::value_type...>; - template - inline value_type access(const tuple_type& t, size_type axis, S index) const + template + inline value_type access(const tuple_type& t, size_type axis, It first, It) const { - auto get_item = [&index](auto& arr) + auto get_item = [&](auto& arr) { - return arr[index]; + size_t offset = 0; + const size_t end = arr.dimension(); + size_t after_axis = 0; + for (size_t i = 0; i < end; i++) + { + if (i == axis) + { + after_axis = 1; + } + const auto& shape = arr.shape(); + const size_t stride = std::accumulate( + shape.begin() + i + 1, + shape.end(), + 1, + std::multiplies() + ); + const auto len = (*(first + i + after_axis)); + offset += len * stride; + } + const auto element = arr.begin() + offset; + return *element; }; - size_type i = index[axis]; - index.erase(index.begin() + std::ptrdiff_t(axis)); + size_type i = *(first + axis); return apply(i, get_item, t); } }; template - class vstack_access : private concatenate_access, - private stack_access + class vstack_access { public: @@ -556,25 +599,27 @@ namespace xt using size_type = std::size_t; using value_type = xtl::promote_type_t::value_type...>; - using concatenate_base = concatenate_access; - using stack_base = stack_access; - - template - inline value_type access(const tuple_type& t, size_type axis, S index) const + template + inline value_type access(const tuple_type& t, size_type axis, It first, It last) const { if (std::get<0>(t).dimension() == 1) { - return stack_base::access(t, axis, index); + return stack.access(t, axis, first, last); } else { - return concatenate_base::access(t, axis, index); + return concatonate.access(t, axis, first, last); } } + + private: + + concatenate_access concatonate; + stack_access stack; }; template