diff --git a/content/number-theory/Factor.h b/content/number-theory/Factor.h index a78192826..3558e6a77 100644 --- a/content/number-theory/Factor.h +++ b/content/number-theory/Factor.h @@ -30,12 +30,21 @@ * significantly. * * Subtle implementation notes: - * - we operate on residues in [1, n]; modmul can be proven to work for those * - prd starts off as 2 to handle the case n = 4; it's harmless for other n * since we're guaranteed that n > 2. (Pollard rho has problems with prime * powers in general, but all larger ones happen to work.) * - t starts off as 30 to make the first gcd check come earlier, as an * optimization for small numbers. + * - we vary f between restarts because the cycle finding algorithm does not + * find the first element in the cycle but rather one at distance k*|cycle| + * from the start, and that can result in continual failures if all cycles + * have the same size for all prime factors. E.g. fixing f(x) = x^2 + 1 would + * loop infinitely for n = 352523 * 352817, where all cycles have size 821. + * - we operate on residues in [i, n + i) which modmul is not designed to + * handle, but specifically modmul(x, x) still turns out to work for small + * enough i. (With reference to the proof in modmul-proof.tex, the argument + * for "S is in [-c, 2c)" goes through unchanged, while S < 2^63 now follows + * from S < 2c and S = x^2 (mod c) together implying S < c + i^2.) */ #pragma once @@ -43,8 +52,8 @@ #include "MillerRabin.h" ull pollard(ull n) { - auto f = [n](ull x) { return modmul(x, x, n) + 1; }; ull x = 0, y = 0, t = 30, prd = 2, i = 1, q; + auto f = [&](ull x) { return modmul(x, x, n) + i; }; while (t++ % 40 || __gcd(prd, n) == 1) { if (x == y) x = ++i, y = f(x); if ((q = modmul(prd, max(x,y) - min(x,y), n))) prd = q; diff --git a/stress-tests/number-theory/Factor.cpp b/stress-tests/number-theory/Factor.cpp index d5b982d38..8e764d593 100644 --- a/stress-tests/number-theory/Factor.cpp +++ b/stress-tests/number-theory/Factor.cpp @@ -32,5 +32,11 @@ int main() { auto res = factor(n); assertValid(n, res); } + rep(i,0,1e5) { + // max number that modmul can handle + ull n = 7268172458553106874 - i; + auto res = factor(n); + assertValid(n, res); + } cout<<"Tests passed!"< uni(1, lim); uniform_int_distribution uniSmall(0, lim / 10000); + int it = 0; for (int i = 0;; i++) { - if (expectSuccess && i >= ITERS) break; // if (i % 1'000'000 == 0) cerr << '.' << flush; ull c = i&1 ? lim - uniSmall(rng) : uni(rng); ull a = i&2 ? c - uniSmall(rng) : i&4 && !useDoubles ? (1ULL << 62) - uniSmall(rng) : uni(rng); ull b = i&8 ? c - uniSmall(rng) : uni(rng); if (a > c || b > c) continue; + if (expectSuccess && it++ >= ITERS) break; ull l = int128_modmul(a, b, c); ull r = useDoubles ? double_modmul(a, b, c) : modmul(a, b, c); if (l != r) { @@ -34,14 +35,43 @@ void test(ull lim, bool expectSuccess, bool useDoubles) { } } +void testSq(ull lim, bool expectSuccess, bool useDoubles) { + // Test that modmul works for squaring slightly beyond the stated bounds. + // Factor.h relies on this (and has a proof sketch in the doc comment). + mt19937_64 rng(1); + uniform_int_distribution uni(1, lim); + uniform_int_distribution uniSmall(0, lim / 10000); + uniform_int_distribution uniTiny(0, (int)(sqrt(lim) / 2)); + + for (int i = 0;; i++) { + if (expectSuccess && i >= ITERS) break; + // if (i % 1'000'000 == 0) cerr << '.' << flush; + ull c = i&1 ? lim - uniSmall(rng) : uni(rng); + ull a = expectSuccess ? c + uniTiny(rng) : c + uniSmall(rng); + ull l = int128_modmul(a, a, c); + ull r = useDoubles ? double_modmul(a, a, c) : modmul(a, a, c); + if (l != r) { + if (!expectSuccess) break; + cout << a << ' ' << c << endl; + cout << l << ' ' << r << endl; + abort(); + } + } +} + int main() { const ull limDoubles = 1ULL << 52; test(limDoubles, true, true); test((ull)(limDoubles * 1.02L), false, true); + testSq(limDoubles, true, true); + const ull lim = 7268172458553106874ULL; // floor((sqrt(177) - 7) / 16 * 2**64) test(lim, true, false); test((ull)(lim * 1.01L), false, false); - // test((ull)(lim * 1.001L), false); + // test((ull)(lim * 1.001L), false, false); + + testSq(lim, true, false); + testSq(lim, false, false); cout << "Tests passed!" << endl; }