Skip to content

Commit

Permalink
Improve trailing zero removal; add new policy remove_compact (same as…
Browse files Browse the repository at this point in the history
… remove, but supposedly smaller code size)
  • Loading branch information
jk-jeon committed May 7, 2024
1 parent 0c38567 commit ff9d375
Showing 1 changed file with 168 additions and 131 deletions.
299 changes: 168 additions & 131 deletions include/dragonbox/dragonbox.h
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,42 @@ namespace jkj {
compressed_cache_holder<ieee754_binary64, Dummy>::pow5_table;
#endif

////////////////////////////////////////////////////////////////////////////////////////
// Forward declarations of user-specializable templates used in the main algorithm.
////////////////////////////////////////////////////////////////////////////////////////

// Remove trailing zeros from significand and add the number of removed zeros into
// exponent.
template <class TrailingZeroPolicy, class Format, class DecimalSignificand,
class DecimalExponentType>
struct remove_trailing_zeros_traits;

// Users can specialize this traits class to make Dragonbox work with their own formats.
// However, this requires detailed knowledge on how the algorithm works, so it is recommended to
// read through the paper.
template <class FormatTraits, class CacheEntryType, detail::stdr::size_t cache_bits_>
struct multiplication_traits;

// A collection of some common definitions to reduce boilerplate.
template <class FormatTraits, class CacheEntryType, detail::stdr::size_t cache_bits_>
struct multiplication_traits_base {
using format = typename FormatTraits::format;
static constexpr int significand_bits = format::significand_bits;
static constexpr int total_bits = format::total_bits;
using carrier_uint = typename FormatTraits::carrier_uint;
using cache_entry_type = CacheEntryType;
static constexpr int cache_bits = int(cache_bits_);

struct compute_mul_result {
carrier_uint integer_part;
bool is_integer;
};
struct compute_mul_parity_result {
bool parity;
bool is_integer;
};
};


////////////////////////////////////////////////////////////////////////////////////////
// Policies.
Expand Down Expand Up @@ -2335,121 +2371,37 @@ namespace jkj {
using trailing_zero_policy = remove_t;
static constexpr bool report_trailing_zeros = false;

// Remove trailing zeros from significand and add the number of removed zeros into
// exponent.
template <class Format, class DecimalSignificand, class DecimalExponentType>
JKJ_FORCEINLINE static JKJ_CONSTEXPR14
unsigned_decimal_fp<DecimalSignificand, DecimalExponentType, false>
on_trailing_zeros(DecimalSignificand significand,
DecimalExponentType exponent) noexcept {
remove_trailing_zeros_traits<
remove_t, Format, DecimalSignificand,
DecimalExponentType>::remove_trailing_zeros(significand, exponent);
return {significand, exponent};
}

assert(significand != 0);

JKJ_IF_CONSTEXPR(detail::stdr::is_same<Format, ieee754_binary32>::value) {
constexpr auto mod_inv_5 = UINT32_C(0xcccccccd);
constexpr auto mod_inv_25 = detail::stdr::uint_least32_t(
(mod_inv_5 * mod_inv_5) & UINT32_C(0xffffffff));

while (true) {
auto q = detail::bits::rotr<32>(
detail::stdr::uint_least32_t((significand * mod_inv_25) &
UINT32_C(0xffffffff)),
2);
if (q <= UINT32_C(0xffffffff) / 100) {
significand = q;
exponent += 2;
}
else {
break;
}
}
auto q = detail::bits::rotr<32>(
detail::stdr::uint_least32_t((significand * mod_inv_5) &
UINT32_C(0xffffffff)),
1);
if (q <= UINT32_C(0xffffffff) / 10) {
significand = q;
++exponent;
}
}
else {
#if JKJ_HAS_IF_CONSTEXPR
static_assert(detail::stdr::is_same<Format, ieee754_binary64>::value, "");
#endif
// Divide by 10^8 and reduce to 32-bits if divisible.
// Since significand <= (2^53 * 1000 - 1) / 1000 < 10^16, it is at most of
// 16 digits.

// This magic number is ceil(2^90 / 10^8).
constexpr auto magic_number = UINT64_C(12379400392853802749);
auto nm = detail::wuint::umul128(significand, magic_number);

// Is significand divisible by 10^8?
if ((nm.high() & ((detail::stdr::uint_least64_t(1) << (90 - 64)) - 1)) ==
0 &&
nm.low() < magic_number) {
// If yes, work with the quotient.
auto n32 = detail::stdr::uint_least32_t(nm.high() >> (90 - 64));

constexpr auto mod_inv_5 = UINT32_C(0xcccccccd);
constexpr auto mod_inv_25 = detail::stdr::uint_least32_t(
(mod_inv_5 * mod_inv_5) & UINT32_C(0xffffffff));

exponent += 8;
while (true) {
auto q = detail::bits::rotr<32>(
detail::stdr::uint_least32_t((n32 * mod_inv_25) &
UINT32_C(0xffffffff)),
2);
if (q <= UINT32_C(0xffffffff) / 100) {
n32 = q;
exponent += 2;
}
else {
break;
}
}
auto q = detail::bits::rotr<32>(
detail::stdr::uint_least32_t((n32 * mod_inv_5) &
UINT32_C(0xffffffff)),
1);
if (q <= UINT32_C(0xffffffff) / 10) {
n32 = q;
++exponent;
}
template <class Format, class DecimalSignificand, class DecimalExponentType>
static constexpr unsigned_decimal_fp<DecimalSignificand, DecimalExponentType, false>
no_trailing_zeros(DecimalSignificand significand,
DecimalExponentType exponent) noexcept {
return {significand, exponent};
}
} remove = {};

significand = n32;
}
else {
// If significand is not divisible by 10^8, work with n itself.
constexpr auto mod_inv_5 = UINT64_C(0xcccccccccccccccd);
constexpr auto mod_inv_25 = detail::stdr::uint_least64_t(
(mod_inv_5 * mod_inv_5) & UINT64_C(0xffffffffffffffff));

while (true) {
auto q = detail::bits::rotr<64>(
detail::stdr::uint_least64_t((significand * mod_inv_25) &
UINT64_C(0xffffffffffffffff)),
2);
if (q <= UINT64_C(0xffffffffffffffff) / 100) {
significand = q;
exponent += 2;
}
else {
break;
}
}
auto q = detail::bits::rotr<64>(
detail::stdr::uint_least64_t((significand * mod_inv_5) &
UINT64_C(0xffffffffffffffff)),
1);
if (q <= UINT64_C(0xffffffffffffffff) / 10) {
significand = q;
++exponent;
}
}
}
JKJ_INLINE_VARIABLE struct remove_compact_t {
using trailing_zero_policy = remove_compact_t;
static constexpr bool report_trailing_zeros = false;

template <class Format, class DecimalSignificand, class DecimalExponentType>
JKJ_FORCEINLINE static JKJ_CONSTEXPR14
unsigned_decimal_fp<DecimalSignificand, DecimalExponentType, false>
on_trailing_zeros(DecimalSignificand significand,
DecimalExponentType exponent) noexcept {
remove_trailing_zeros_traits<
remove_compact_t, Format, DecimalSignificand,
DecimalExponentType>::remove_trailing_zeros(significand, exponent);
return {significand, exponent};
}

Expand All @@ -2459,7 +2411,7 @@ namespace jkj {
DecimalExponentType exponent) noexcept {
return {significand, exponent};
}
} remove = {};
} remove_compact = {};

JKJ_INLINE_VARIABLE struct report_t {
using trailing_zero_policy = report_t;
Expand Down Expand Up @@ -2986,33 +2938,118 @@ namespace jkj {
}

////////////////////////////////////////////////////////////////////////////////////////
// Provides multiplication methods used in the main algorithm.
// Specializations of user-specializable templates used in the main algorithm.
////////////////////////////////////////////////////////////////////////////////////////

// Users can specialize this traits class to make Dragonbox work with their own formats.
// However, this requires detailed knowledge on how the algorithm works, so it is recommended to
// read through the paper.
template <class FormatTraits, class CacheEntryType, detail::stdr::size_t cache_bits_>
struct multiplication_traits;
template <class DecimalExponentType>
struct remove_trailing_zeros_traits<policy::trailing_zero::remove_t, ieee754_binary32,
detail::stdr::uint_least32_t, DecimalExponentType> {
JKJ_FORCEINLINE static JKJ_CONSTEXPR14 void
remove_trailing_zeros(detail::stdr::uint_least32_t& significand,
DecimalExponentType& exponent) noexcept {
// See https://github.com/jk-jeon/rtz_benchmark.
// The idea of branchless search below is by reddit users r/pigeon768 and
// r/TheoreticalDumbass.

auto r = detail::bits::rotr<32>(
detail::stdr::uint_least32_t(significand * UINT32_C(184254097)), 4);
auto b = r < UINT32_C(429497);
auto s = std::size_t(b);
significand = b ? r : significand;

r = detail::bits::rotr<32>(
detail::stdr::uint_least32_t(significand * UINT32_C(42949673)), 2);
b = r < UINT32_C(42949673);
s = s * 2 + b;
significand = b ? r : significand;

r = detail::bits::rotr<32>(
detail::stdr::uint_least32_t(significand * UINT32_C(1288490189)), 1);
b = r < UINT32_C(429496730);
s = s * 2 + b;
significand = b ? r : significand;

exponent += s;
}
};

// A collection of some common definitions to reduce boilerplate.
template <class FormatTraits, class CacheEntryType, detail::stdr::size_t cache_bits_>
struct multiplication_traits_base {
using format = typename FormatTraits::format;
static constexpr int significand_bits = format::significand_bits;
static constexpr int total_bits = format::total_bits;
using carrier_uint = typename FormatTraits::carrier_uint;
using cache_entry_type = CacheEntryType;
static constexpr int cache_bits = int(cache_bits_);
template <class DecimalExponentType>
struct remove_trailing_zeros_traits<policy::trailing_zero::remove_t, ieee754_binary64,
detail::stdr::uint_least64_t, DecimalExponentType> {
JKJ_FORCEINLINE static JKJ_CONSTEXPR14 void
remove_trailing_zeros(detail::stdr::uint_least64_t& significand,
DecimalExponentType& exponent) noexcept {
// See https://github.com/jk-jeon/rtz_benchmark.
// The idea of branchless search below is by reddit users r/pigeon768 and
// r/TheoreticalDumbass.

auto r = detail::bits::rotr<64>(
detail::stdr::uint_least64_t(significand * UINT64_C(28999941890838049)), 8);
auto b = r < UINT64_C(184467440738);
auto s = std::size_t(b);
significand = b ? r : significand;

r = detail::bits::rotr<64>(
detail::stdr::uint_least64_t(significand * UINT64_C(182622766329724561)), 4);
b = r < UINT64_C(1844674407370956);
s = s * 2 + b;
significand = b ? r : significand;

r = detail::bits::rotr<64>(
detail::stdr::uint_least64_t(significand * UINT64_C(10330176681277348905)), 2);
b = r < UINT64_C(184467440737095517);
s = s * 2 + b;
significand = b ? r : significand;

r = detail::bits::rotr<64>(
detail::stdr::uint_least64_t(significand * UINT32_C(14757395258967641293)), 1);
b = r < UINT64_C(1844674407370955162);
s = s * 2 + b;
significand = b ? r : significand;

exponent += s;
}
};

struct compute_mul_result {
carrier_uint integer_part;
bool is_integer;
};
struct compute_mul_parity_result {
bool parity;
bool is_integer;
};
template <class DecimalExponentType>
struct remove_trailing_zeros_traits<policy::trailing_zero::remove_compact_t, ieee754_binary32,
detail::stdr::uint_least32_t, DecimalExponentType> {
JKJ_FORCEINLINE static JKJ_CONSTEXPR14 void
remove_trailing_zeros(detail::stdr::uint_least32_t& significand,
DecimalExponentType& exponent) noexcept {
// See https://github.com/jk-jeon/rtz_benchmark.
while (true) {
auto const r = detail::stdr::uint_least32_t(significand * UINT32_C(1288490189));
if (r < UINT32_C(429496731)) {
significand = detail::stdr::uint_least32_t(r >> 1);
exponent += 1;
}
else {
break;
}
}
}
};

template <class DecimalExponentType>
struct remove_trailing_zeros_traits<policy::trailing_zero::remove_compact_t, ieee754_binary64,
detail::stdr::uint_least64_t, DecimalExponentType> {
JKJ_FORCEINLINE static JKJ_CONSTEXPR14 void
remove_trailing_zeros(detail::stdr::uint_least64_t& significand,
DecimalExponentType& exponent) noexcept {
// See https://github.com/jk-jeon/rtz_benchmark.
while (true) {
auto const r =
detail::stdr::uint_least64_t(significand * UINT64_C(5534023222112865485));
if (r < UINT64_C(1844674407370955163)) {
significand = detail::stdr::uint_least64_t(r >> 1);
exponent += 1;
}
else {
break;
}
}
}
};

template <class ExponentInt>
Expand Down

0 comments on commit ff9d375

Please sign in to comment.