Skip to content

Commit

Permalink
Merge pull request #2759 from spectre-ns/concatenate_access
Browse files Browse the repository at this point in the history
[Optimization] Updated `concatenate_access` and `stack_access` to remove allocations
  • Loading branch information
JohanMabille authored Jan 12, 2024
2 parents 1aa7099 + 4045fb2 commit f268c6d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 27 deletions.
98 changes: 72 additions & 26 deletions include/xtensor/xbuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,22 +494,47 @@ namespace xt
using size_type = std::size_t;
using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;

template <class S>
inline value_type access(const tuple_type& t, size_type axis, S index) const
template <class It>
inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
{
auto match = [&index, axis](auto& arr)
// trim off extra indices if provided to match behavior of containers
auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension();
size_t axis_dim = *(first + axis + dim_offset);
auto match = [&](auto& arr)
{
if (index[axis] >= arr.shape()[axis])
if (axis_dim >= arr.shape()[axis])
{
index[axis] -= arr.shape()[axis];
axis_dim -= arr.shape()[axis];
return false;
}
return true;
};

auto get = [&index](auto& arr)
auto get = [&](auto& arr)
{
return arr[index];
size_t offset = 0;
const size_t end = arr.dimension();
for (size_t i = 0; i < end; i++)
{
const auto& shape = arr.shape();
const size_t stride = std::accumulate(
shape.begin() + i + 1,
shape.end(),
1,
std::multiplies<size_t>()
);
if (i == axis)
{
offset += axis_dim * stride;
}
else
{
const auto len = (*(first + i + dim_offset));
offset += len * stride;
}
}
const auto element = arr.begin() + offset;
return *element;
};

size_type i = 0;
Expand All @@ -533,48 +558,68 @@ namespace xt
using size_type = std::size_t;
using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;

template <class S>
inline value_type access(const tuple_type& t, size_type axis, S index) const
template <class It>
inline value_type access(const tuple_type& t, size_type axis, It first, It) const
{
auto get_item = [&index](auto& arr)
auto get_item = [&](auto& arr)
{
return arr[index];
size_t offset = 0;
const size_t end = arr.dimension();
size_t after_axis = 0;
for (size_t i = 0; i < end; i++)
{
if (i == axis)
{
after_axis = 1;
}
const auto& shape = arr.shape();
const size_t stride = std::accumulate(
shape.begin() + i + 1,
shape.end(),
1,
std::multiplies<size_t>()
);
const auto len = (*(first + i + after_axis));
offset += len * stride;
}
const auto element = arr.begin() + offset;
return *element;
};
size_type i = index[axis];
index.erase(index.begin() + std::ptrdiff_t(axis));
size_type i = *(first + axis);
return apply<value_type>(i, get_item, t);
}
};

template <class... CT>
class vstack_access : private concatenate_access<CT...>,
private stack_access<CT...>
class vstack_access
{
public:

using tuple_type = std::tuple<CT...>;
using size_type = std::size_t;
using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;

using concatenate_base = concatenate_access<CT...>;
using stack_base = stack_access<CT...>;

template <class S>
inline value_type access(const tuple_type& t, size_type axis, S index) const
template <class It>
inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
{
if (std::get<0>(t).dimension() == 1)
{
return stack_base::access(t, axis, index);
return stack.access(t, axis, first, last);
}
else
{
return concatenate_base::access(t, axis, index);
return concatonate.access(t, axis, first, last);
}
}

private:

concatenate_access<CT...> concatonate;
stack_access<CT...> stack;
};

template <template <class...> class F, class... CT>
class concatenate_invoker : private F<CT...>
class concatenate_invoker
{
public:

Expand All @@ -592,18 +637,19 @@ namespace xt
inline value_type operator()(Args... args) const
{
// TODO: avoid memory allocation
return this->access(m_t, m_axis, xindex({static_cast<size_type>(args)...}));
xindex index({static_cast<size_type>(args)...});
return access_method.access(m_t, m_axis, index.begin(), index.end());
}

template <class It>
inline value_type element(It first, It last) const
{
// TODO: avoid memory allocation
return this->access(m_t, m_axis, xindex(first, last));
return access_method.access(m_t, m_axis, first, last);
}

private:

F<CT...> access_method;
tuple_type m_t;
size_type m_axis;
};
Expand Down
2 changes: 1 addition & 1 deletion test/test_xbuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ namespace xt
ASSERT_EQ(11, c(1, 1, 1, 2));
ASSERT_EQ(11, c(1, 1, 2, 2));

auto e = arange(1, 4);
xarray<double> e = arange(1, 4);
xarray<double> f = {2, 3, 4};
xarray<double> k = stack(xtuple(e, f));
xarray<double> l = stack(xtuple(e, f), 1);
Expand Down

0 comments on commit f268c6d

Please sign in to comment.