Skip to content

Commit

Permalink
Added 'edge' pad mode for xt::pad
Browse files Browse the repository at this point in the history
  • Loading branch information
snehalv2002 authored and tdegeus committed Jun 14, 2023
1 parent 0621c5f commit e75ba74
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
29 changes: 24 additions & 5 deletions include/xtensor/xpad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ namespace xt
symmetric,
reflect,
wrap,
periodic
periodic,
edge
};

namespace detail
Expand Down Expand Up @@ -123,19 +124,28 @@ namespace xt
{
XTENSOR_ASSERT(nb <= e.shape(axis));
svs[axis] = xt::range(e.shape(axis), nb + e.shape(axis));
xt::strided_view(out, svt) = xt::strided_view(out, svs);
}
else if (mode == pad_mode::symmetric)
{
XTENSOR_ASSERT(nb <= e.shape(axis));
svs[axis] = xt::range(2 * nb - 1, nb - 1, -1);
xt::strided_view(out, svt) = xt::strided_view(out, svs);
}
else if (mode == pad_mode::reflect)
{
XTENSOR_ASSERT(nb <= e.shape(axis) - 1);
svs[axis] = xt::range(2 * nb, nb, -1);
xt::strided_view(out, svt) = xt::strided_view(out, svs);
}
else if (mode == pad_mode::edge)
{
svs[axis] = xt::range(nb, nb + 1);
xt::strided_view(out, svt) = xt::broadcast(
xt::strided_view(out, svs),
xt::strided_view(out, svt).shape()
);
}

xt::strided_view(out, svt) = xt::strided_view(out, svs);
}

if (ne > static_cast<size_type>(0))
Expand All @@ -146,6 +156,7 @@ namespace xt
{
XTENSOR_ASSERT(ne <= e.shape(axis));
svs[axis] = xt::range(nb, nb + ne);
xt::strided_view(out, svt) = xt::strided_view(out, svs);
}
else if (mode == pad_mode::symmetric)
{
Expand All @@ -158,6 +169,7 @@ namespace xt
{
svs[axis] = xt::range(nb + e.shape(axis) - 1, nb + e.shape(axis) - ne - 1, -1);
}
xt::strided_view(out, svt) = xt::strided_view(out, svs);
}
else if (mode == pad_mode::reflect)
{
Expand All @@ -170,9 +182,16 @@ namespace xt
{
svs[axis] = xt::range(nb + e.shape(axis) - 2, nb + e.shape(axis) - ne - 2, -1);
}
xt::strided_view(out, svt) = xt::strided_view(out, svs);
}
else if (mode == pad_mode::edge)
{
svs[axis] = xt::range(out.shape(axis) - ne - 1, out.shape(axis) - ne);
xt::strided_view(out, svt) = xt::broadcast(
xt::strided_view(out, svs),
xt::strided_view(out, svt).shape()
);
}

xt::strided_view(out, svt) = xt::strided_view(out, svs);
}

svs[axis] = xt::all();
Expand Down
50 changes: 50 additions & 0 deletions test/test_xpad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,56 @@ namespace xt
EXPECT_EQ(b, c);
}

TEST(xpad, edge_a)
{
xt::xtensor<size_t, 2> a = {{0, 1, 2}, {3, 4, 5}};

xt::xtensor<size_t, 2> b = {
{0, 0, 0, 0, 1, 2, 2, 2, 2},
{0, 0, 0, 0, 1, 2, 2, 2, 2},
{0, 0, 0, 0, 1, 2, 2, 2, 2},
{3, 3, 3, 3, 4, 5, 5, 5, 5},
{3, 3, 3, 3, 4, 5, 5, 5, 5},
{3, 3, 3, 3, 4, 5, 5, 5, 5}};

xt::xtensor<size_t, 2> c = xt::pad(a, {{2, 2}, {3, 3}}, xt::pad_mode::edge);

EXPECT_EQ(b, c);
}

TEST(xpad, edge_b)
{
xt::xtensor<size_t, 2> a = {{0, 1, 2}, {3, 4, 5}};

xt::xtensor<size_t, 2> b = {{0, 1, 2}, {3, 4, 5}};

xt::xtensor<size_t, 2> c = xt::pad(a, 0, xt::pad_mode::edge);

EXPECT_EQ(b, c);
}

TEST(xpad, edge_c)
{
xt::xtensor<size_t, 2> a = {{0, 1, 2}, {3, 4, 5}};

xt::xtensor<size_t, 2> b = {{0, 1, 2, 2, 2}, {3, 4, 5, 5, 5}, {3, 4, 5, 5, 5}};

xt::xtensor<size_t, 2> c = xt::pad(a, {{0, 1}, {0, 2}}, xt::pad_mode::edge);

EXPECT_EQ(b, c);
}

TEST(xpad, edge_d)
{
xt::xtensor<size_t, 2> a = {{0, 1, 2}, {3, 4, 5}};

xt::xtensor<size_t, 2> b = {{0, 0, 0, 1, 2}, {0, 0, 0, 1, 2}, {3, 3, 3, 4, 5}};

xt::xtensor<size_t, 2> c = xt::pad(a, {{1, 0}, {2, 0}}, xt::pad_mode::edge);

EXPECT_EQ(b, c);
}

TEST(xpad, wrap_a)
{
xt::xtensor<size_t, 2> a = {{0, 1, 2}, {3, 4, 5}};
Expand Down

0 comments on commit e75ba74

Please sign in to comment.