Fast modular multiplication

Revision en6, by orz, 2021-11-10 21:54:33

Consider the following problem: given three integers $$$x$$$, $$$y$$$ and $$$m$$$, $$$0 \leqslant x,y < m < 2^{32}$$$, calculate $$$xy\bmod m$$$. The easy way is just to multiply these numbers and apply the modulo operation:

~~~~~
#include <cstdint>

uint32_t prod(const uint32_t x, const uint32_t y, const uint32_t m)
{
    return x * y % m;  // WRONG: x * y is computed modulo 2^32
}
~~~~~

As you might have guessed, this solution is wrong. The problem is that this procedure can overflow: the operation x * y is performed in the type uint32_t, so its intermediate result is not $$$xy$$$ but $$$xy\bmod2^{32}$$$. If we then take this value modulo $$$m$$$, the result may differ from the correct one:

$$$ \left(xy\bmod2^{32}\right)\bmod m\ne xy\bmod m. $$$

The way out is simple; you need to multiply in a larger type:

~~~~~
uint64_t prod_uint64(const uint64_t x, const uint64_t y, const uint64_t m)
{
    // x, y < 2^32, so xy < 2^64 fits into uint64_t
    return x * y % m;
}
~~~~~

Since $$$xy<2^{64}$$$, this product definitely does not overflow, and reducing it modulo $$$m$$$ gives the correct answer.
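To make the overflow concrete, here is a small compile-time check with illustrative values chosen here (not from the post), assuming a platform where unsigned int is 32 bits wide, so that arithmetic on uint32_t wraps modulo $$$2^{32}$$$. With $$$x=y=2^{31}$$$ and $$$m=3\cdot10^9$$$, the 32-bit product wraps to zero, while the 64-bit product gives the true remainder $$$2^{62}\bmod m=2427387904$$$.

~~~~~
#include <cstdint>

// Illustrative values; assumes unsigned int is 32 bits wide.
namespace overflow_example
{
constexpr uint32_t x = 1u << 31, y = 1u << 31, m = 3000000000u;

// In uint32_t, x * y = 2^62 wraps to 2^62 mod 2^32 = 0, so prod() returns 0.
constexpr uint32_t naive = x * y % m;

// In uint64_t, x * y = 2^62 fits, and 2^62 mod 3000000000 = 2427387904.
constexpr uint64_t wide = uint64_t(x) * y % m;

static_assert(naive == 0, "the 32-bit product wraps around to 0");
static_assert(wide == 2427387904u, "the 64-bit product gives the true remainder");
}
~~~~~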

The question is: what if $$$x$$$, $$$y$$$ and $$$m$$$ can be greater than $$$2^{32}$$$? I suggest the following.

  1. Binary multiplication. Just like binary exponentiation, there is binary multiplication: to calculate $$$xy$$$, compute $$$x\left\lfloor\frac y2\right\rfloor$$$, add this number to itself, and, if $$$y$$$ is odd, add $$$x$$$ once more. This takes $$$\mathcal O(\log y)$$$ operations, but all of them are additions and subtractions!

    ~~~~~
    // (x + y) mod m for x, y < m; the check ans < x catches wrap-around past 2^64
    uint64_t sum(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        uint64_t ans = x + y;
        if (ans < x || ans >= m)
            ans -= m;
        return ans;
    }

    // xy mod m = (2 * (x * floor(y / 2)) + x * (y mod 2)) mod m
    uint64_t prod_binary(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        if (y <= 1)
            return y ? x : 0;
        uint64_t ans = prod_binary(x, y >> 1, m);
        ans = sum(ans, ans, m);
        if (y & 1)
            ans = sum(ans, x, m);
        return ans;
    }
    ~~~~~
  2. Multiplication via int128. To multiply two 32-bit numbers, you need a 64-bit intermediate variable, so to multiply two 64-bit numbers, you need a 128-bit one! Modern 64-bit C++ compilers (except perhaps Microsoft® Visual C++®) have a special type __int128 that supports arithmetic on 128-bit numbers.

    ~~~~~
    uint64_t prod_uint128(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        // the 128-bit product cannot overflow; the remainder is < m, so it fits into uint64_t
        return (unsigned __int128)x * y % m;
    }
    ~~~~~
  3. Multiplication using a real (floating-point) type. What is $$$xy\bmod m$$$? It is actually $$$xy-cm$$$, where $$$c=\left\lfloor\frac{xy}m\right\rfloor$$$. So let's compute $$$c$$$ and obtain $$$xy\bmod m$$$ from it. Note that we do not need to find $$$c$$$ exactly. What happens if we accidentally compute, say, $$$c-4$$$? Then the remainder becomes $$$xy-(c-4)m=xy-cm+4m=xy\bmod m+4m$$$. At first glance, this is not what we need. But if $$$m$$$ is not too large and $$$xy\bmod m+4m$$$ does not overflow the 64-bit type, we can then simply take the remainder modulo $$$m$$$ once more and get the answer. Similarly, if $$$c$$$ is overestimated, the difference becomes negative, so it is convenient to work in a signed type. Note that $$$xy$$$ and $$$cm$$$ themselves do overflow 64 bits, but the subtraction is carried out modulo $$$2^{64}$$$, so the result is exact as long as the true difference lies in $$$\left[-2^{63},2^{63}\right)$$$.

    This translates into the following implementation:

    ~~~~~
    uint64_t prod_double(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        // c approximates floor(xy / m) and may be off by a few units
        uint64_t c = (double)x * y / m;
        // x * y and c * m wrap modulo 2^64, but their difference is exact
        // as long as the true value fits into int64_t
        int64_t ans = int64_t(x * y - c * m) % int64_t(m);
        if (ans < 0)
            ans += m;
        return ans;
    }
    ~~~~~

    ~~~~~
    uint64_t prod_long_double(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        // same as prod_double, but with a 64-bit mantissa if long double is 80-bit
        uint64_t c = (long double)x * y / m;
        int64_t ans = int64_t(x * y - c * m) % int64_t(m);
        if (ans < 0)
            ans += m;
        return ans;
    }
    ~~~~~

    The double type is accurate enough for this task if $$$x$$$, $$$y$$$ and $$$m$$$ are less than $$$2^{57}$$$. The long double type is enough for numbers less than $$$2^{63}$$$, but only if it is 80-bit, which is not the case on all compilers: for example, in Microsoft® Visual C++® long double is the same as double.

    Note that this method does not work for $$$m>2^{63}$$$: in that case ans cannot be stored in int64_t, since it may happen that $$$\mathtt{ans}\geqslant2^{63}$$$; the resulting overflow makes the (ans < 0) branch fire, and we get an incorrect answer.

    As we can see, Microsoft® Visual C++® lags behind other compilers in the tools available for multiplying large numbers modulo $$$m$$$, so if we want the function to work quickly on all compilers, we need a fresh idea. Fortunately, such an idea was invented in 1960 by Anatoly Karatsuba.
  4. Karatsuba multiplication. The idea was originally used to multiply long numbers quickly. Namely, let $$$x$$$ and $$$y$$$ be two non-negative integers less than $$$N^2$$$. Divide them by $$$N$$$ with remainder: $$$x=Nx_1+x_0$$$, $$$y=Ny_1+y_0$$$. Then the required product $$$xy$$$ can be found as $$$N^2x_1y_1+Nx_1y_0+Nx_0y_1+x_0y_0=N\cdot\bigl(N\cdot x_1y_1+\left(x_0+x_1\right)\left(y_0+y_1\right)-x_1y_1-x_0y_0\bigr)+x_0y_0$$$. As you can see, this transformation reduces the multiplication of $$$x$$$ and $$$y$$$ to: 1) $$$\mathcal O(1)$$$ additions and subtractions of numbers not exceeding $$$N^4$$$; 2) three multiplications of numbers not exceeding $$$2N$$$ (namely $$$x_0y_0$$$, $$$x_1y_1$$$ and $$$\left(x_0+x_1\right)\left(y_0+y_1\right)$$$); 3) two multiplications by $$$N$$$ of numbers not exceeding $$$2N^2$$$.

    Item 1) is almost always very simple. For long numbers, item 3) is simple too: you can take $$$N$$$ equal to a power of two, and then multiplication by $$$N$$$ is just a binary shift (operator << in C++). Therefore, in essence, Karatsuba reduced one multiplication of numbers less than $$$N^2$$$ to three multiplications of numbers less than $$$2N$$$. If these multiplications are also performed by the Karatsuba method, then, by the master theorem, the running time is $$$\Theta\left(\log^{\log_23}N\right)$$$ instead of the naive $$$\Theta\left(\log^2N\right)$$$.

    But we do not need recursion: once the lengths of $$$x$$$ and $$$y$$$ are halved, we can already use prod_uint64 or prod_double. The difficulty in our case is item 3): we must choose $$$N$$$ so that, firstly, it is less than $$$2^{32}$$$ (or at most slightly larger), and secondly, numbers of order $$$N^2$$$ can be quickly multiplied by it modulo $$$m$$$. Both requirements are met if we take $$$N=\mathrm{round}\left(\sqrt m\right)$$$: indeed, then for $$$m<2^{64}$$$ we have $$$N\leqslant2^{32}$$$ and $$$\left|m_0\right|=\left|m-N^2\right|\leqslant N\leqslant2^{32}$$$. Since $$$N^2=m-m_0\equiv-m_0\pmod m$$$, we get $$$xN=\left(x_1N+x_0\right)N=x_1N^2+x_0N\equiv x_0N-x_1m_0\pmod m$$$, and both multiplications here are performed on numbers of order $$$N$$$.

    The attentive reader will notice a serious problem here: we actually have to compute the square root of an integer. If you know how to apply Karatsuba multiplication to this problem while avoiding that (or how to find the square root quickly enough, or how to write faster or shorter code than mine), please write in the comments!

    Since computing the product $$$\left(x_0+x_1\right)\left(y_0+y_1\right)$$$ turned out to be extremely unpleasant (remember, prod_double does not work for $$$m>2^{63}$$$), I decided to simply compute $$$x_0y_1$$$ and $$$x_1y_0$$$ separately, so this is not Karatsuba's method in the true sense: we spend four multiplications of numbers of order $$$N$$$.

    ~~~~~
    #include <cmath>
    #include <cstdint>

    // sum() is defined in the binary multiplication item above.

    // (x - y) mod m for x, y < m; if the subtraction wrapped around, add m back
    uint64_t dif(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        uint64_t ans = x - y;
        if (ans > x)
            ans += m;
        return ans;
    }

    // r >= round(sqrt(m))  <=>  r >= 2^32 or r(r + 1) >= m
    bool check_ge_rounded_sqrt(const uint64_t m, const uint64_t r)
    {
        return ((r >= 1ull << 32) || r * (r + 1) >= m);
    }

    // r <= round(sqrt(m))  <=>  r == 0 or (r <= 2^32 and r(r - 1) < m)
    bool check_le_rounded_sqrt(const uint64_t m, const uint64_t r)
    {
        return (r == 0 || ((r <= 1ull << 32) && r * (r - 1) < m));
    }

    bool check_rounded_sqrt(const uint64_t m, const uint64_t r)
    {
        return check_ge_rounded_sqrt(m, r) && check_le_rounded_sqrt(m, r);
    }

    // round(sqrt(m)): a floating-point estimate corrected by exact integer checks
    uint64_t rounded_sqrt(const uint64_t m)
    {
        uint64_t r = floorl(.5 + sqrtl(m));
        if (!check_ge_rounded_sqrt(m, r))
            while (!check_ge_rounded_sqrt(m, ++r));
        else if (!check_le_rounded_sqrt(m, r))
            while (!check_le_rounded_sqrt(m, --r));
        return r;
    }

    // x * N mod m for x < m, using x * N = x1 * N^2 + x0 * N == x0 * N - x1 * m0 (mod m)
    uint64_t prod_karatsuba_aux(const uint64_t x, const uint64_t N, const int64_t m0, const uint64_t m)
    {
        uint64_t x1 = x / N;
        uint64_t x0N = (x - N * x1) * N;
        if (m0 >= 0)
            return dif(x0N, x1 * (uint64_t)m0, m);
        else
            return sum(x0N, x1 * (uint64_t)-m0, m);
    }

    uint64_t prod_karatsuba(const uint64_t x, const uint64_t y, const uint64_t m)
    {
        uint64_t N = rounded_sqrt(m);
        int64_t m0 = m - N * N;  // computed modulo 2^64; the conversion gives the true m - N^2
        uint64_t x1 = x / N;
        uint64_t x0 = x - N * x1;
        uint64_t y1 = y / N;
        uint64_t y0 = y - N * y1;
        uint64_t x0y0 = sum(x0 * y0, 0, m);
        uint64_t x0y1 = sum(x0 * y1, 0, m);
        uint64_t x1y0 = sum(x1 * y0, 0, m);
        uint64_t x1y1 = sum(x1 * y1, 0, m);
        // ((x1y1 * N + x0y1 + x1y0) * N + x0y0) mod m
        return sum(prod_karatsuba_aux(sum(prod_karatsuba_aux(x1y1, N, m0, m), sum(x0y1, x1y0, m), m), N, m0, m), x0y0, m);
    }
    ~~~~~
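As a small worked instance of the reduction $$$xN\equiv x_0N-x_1m_0\pmod m$$$ performed by prod_karatsuba_aux (the numbers are chosen here purely for illustration), take $$$m=97$$$. Then $$$N=\mathrm{round}\left(\sqrt{97}\right)=10$$$ and $$$m_0=m-N^2=-3$$$, which is the branch where the code calls sum. For $$$x=57$$$ we have $$$x_1=5$$$, $$$x_0=7$$$, and

$$$
xN=570=5\cdot97+85\equiv85\pmod{97},\qquad x_0N-x_1m_0=70-5\cdot(-3)=85.
$$$

Only $$$x_0N=70$$$ and $$$x_1\left|m_0\right|=15$$$ had to be computed, and in each of these products both factors are of order $$$N$$$.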

In fact, the only thing the Karatsuba method gives us here is this: if we find a large number $$$N$$$ by which we can quickly multiply modulo $$$m$$$, we can multiply any two numbers modulo $$$m$$$. If the modulus $$$m$$$ were fixed and there were many multiplication queries with this modulus, Karatsuba's method would be lightning fast, since its most expensive operation, the square root, would be computed only once. So I would like to take, for example, $$$N=2^{32}$$$ and do everything as in the previous item, but without the square root. Alas, I haven't figured out how to multiply by $$$2^{32}$$$ modulo $$$m$$$. One could write something like this:

~~~~~
uint64_t prod_double_small(const uint64_t x, const uint64_t y, const uint64_t m)
{
    uint64_t c = (double)x * y / m;
    uint64_t ans = (x * y - c * m) % m;
    return ans;
}
~~~~~

It computes the product modulo $$$m$$$, provided that uint64_t c = (double) x * y / m was computed exactly. But this cannot be guaranteed: $$$\frac{xy}m$$$ may well be $$$10^{-18}$$$ less than some integer, and the precision of double is not enough to notice it. Then c is rounded up, x * y - c * m wraps around to a huge unsigned number, and the answer is wrong. This is the problem that the prod_karatsuba_aux function bypasses. If you can bypass it more directly, you are welcome in the comments.
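The fixed-modulus scenario can be sketched as a small struct that computes $$$N=\mathrm{round}\left(\sqrt m\right)$$$ and $$$m_0=m-N^2$$$ once in its constructor, so that each query only does the four multiplications and two reductions of prod_karatsuba. This is a minimal sketch under stated assumptions, not the author's code: the name FixedModMul and its members are hypothetical, it requires $$$1\leqslant m<2^{64}$$$ and $$$x,y<m$$$, and it still uses $$$N=\mathrm{round}\left(\sqrt m\right)$$$ rather than $$$N=2^{32}$$$, so it only moves the square root out of the per-query path.

~~~~~
#include <cmath>
#include <cstdint>

// Hypothetical helper (not from the post): Karatsuba-style multiplication modulo a
// fixed m, 1 <= m < 2^64. N and m0 are computed once, so queries take no square root.
struct FixedModMul
{
    uint64_t m, N;
    int64_t m0;

    // (a + b) mod `mod` for a, b < mod, detecting wrap-around past 2^64.
    static uint64_t sum(const uint64_t a, const uint64_t b, const uint64_t mod)
    {
        uint64_t ans = a + b;
        if (ans < a || ans >= mod)
            ans -= mod;
        return ans;
    }

    // (a - b) mod `mod` for a, b < mod.
    static uint64_t dif(const uint64_t a, const uint64_t b, const uint64_t mod)
    {
        uint64_t ans = a - b;
        if (ans > a)
            ans += mod;
        return ans;
    }

    explicit FixedModMul(const uint64_t modulo) : m(modulo)
    {
        // Floating-point estimate, then exact correction:
        // r >= round(sqrt(m)) iff r(r + 1) >= m; r <= round(sqrt(m)) iff r(r - 1) < m.
        uint64_t r = (uint64_t)std::floor(0.5L + std::sqrt((long double)m));
        while (r < (1ull << 32) && r * (r + 1) < m)
            ++r;
        while (r > 0 && (r > (1ull << 32) || r * (r - 1) >= m))
            --r;
        N = r;
        // Computed modulo 2^64 (N * N wraps to 0 when N = 2^32); the cast
        // recovers the true, possibly negative, value m - N^2.
        m0 = (int64_t)(m - N * N);
    }

    // x * N mod m for x < m: x * N = x1 * N^2 + x0 * N == x0 * N - x1 * m0 (mod m).
    uint64_t mul_by_N(const uint64_t x) const
    {
        const uint64_t x1 = x / N;
        const uint64_t x0N = (x - N * x1) * N;
        if (m0 >= 0)
            return dif(x0N, x1 * (uint64_t)m0, m);
        return sum(x0N, x1 * (uint64_t)-m0, m);
    }

    // x * y mod m for x, y < m: ((x1y1 * N + x0y1 + x1y0) * N + x0y0) mod m.
    uint64_t operator()(const uint64_t x, const uint64_t y) const
    {
        const uint64_t x1 = x / N, x0 = x - N * x1;
        const uint64_t y1 = y / N, y0 = y - N * y1;
        const uint64_t x0y0 = sum(x0 * y0, 0, m);
        const uint64_t mid = sum(sum(x0 * y1, 0, m), sum(x1 * y0, 0, m), m);
        const uint64_t x1y1 = sum(x1 * y1, 0, m);
        return sum(mul_by_N(sum(mul_by_N(x1y1), mid, m)), x0y0, m);
    }
};
~~~~~

Usage: construct FixedModMul mul(m) once, then call mul(x, y) for every query with this modulus.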


Below are three tables for different compilers (glory to MikeMirzayanov: it is thanks to a somewhat unintended use of his Polygon that I was able to compile them). In each table, the rows correspond to different functions and the columns to the maximum allowed bit length of $$$x$$$, $$$y$$$ and $$$m$$$. CE means that the program does not compile with this compiler, and WA means that it may give an incorrect answer. Otherwise, the cell shows the running time on an Intel® Core™ i3-8100 CPU @ 3.60GHz. The measurement error is about one to two nanoseconds; for the slowest functions it can reach ten nanoseconds.

  1. Microsoft® Visual C++® 2010

    | Method | 32 bits | 57 bits | 63 bits | 64 bits |
    |---|---|---|---|---|
    | prod_uint64 | 7 ns | WA | WA | WA |
    | prod_binary | 477 ns | 847 ns | 889 ns | 870 ns |
    | prod_uint128 | CE | CE | CE | CE |
    | prod_double | 66 ns | 95 ns | WA | WA |
    | prod_long_double | 66 ns | 98 ns | WA | WA |
    | prod_karatsuba | 128 ns | 125 ns | 138 ns | 139 ns |

  2. GNU G++17

    | Method | 32 bits | 57 bits | 63 bits | 64 bits |
    |---|---|---|---|---|
    | prod_uint64 | 4 ns | WA | WA | WA |
    | prod_binary | 455 ns | 774 ns | 841 ns | 845 ns |
    | prod_uint128 | CE | CE | CE | CE |
    | prod_double | 26 ns | 36 ns | WA | WA |
    | prod_long_double | 29 ns | 20 ns | 19 ns | WA |
    | prod_karatsuba | 82 ns | 81 ns | 91 ns | 88 ns |

  3. GNU G++17 (64 bit)

    | Method | 32 bits | 57 bits | 63 bits | 64 bits |
    |---|---|---|---|---|
    | prod_uint64 | 8 ns | WA | WA | WA |
    | prod_binary | 313 ns | 550 ns | 604 ns | 630 ns |
    | prod_uint128 | 17 ns | 34 ns | 30 ns | 30 ns |
    | prod_double | 23 ns | 22 ns | WA | WA |
    | prod_long_double | 23 ns | 24 ns | 23 ns | WA |
    | prod_karatsuba | 65 ns | 65 ns | 69 ns | 66 ns |

Therefore, the basic recipe is as follows: if unsigned __int128 is available, use it; otherwise, if an 80-bit long double is available, use it as long as the numbers are below $$$2^{63}$$$; otherwise, if double is enough (numbers below $$$2^{57}$$$), use double; in all remaining cases, apply the Karatsuba method.
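The recipe can be packed into a single dispatching function. This is a minimal sketch, assuming GCC/Clang's predefined macro __SIZEOF_INT128__ to detect unsigned __int128 and LDBL_MANT_DIG from <cfloat> to detect an 80-bit long double (its mantissa has 64 bits, while a double-sized long double has 53). The name mulmod is hypothetical, and for brevity the last resort here is iterative binary multiplication rather than the Karatsuba method, so it is correct but slow in that branch.

~~~~~
#include <cfloat>
#include <cstdint>

#if defined(__SIZEOF_INT128__)

// unsigned __int128 is available: the 128-bit product never overflows.
uint64_t mulmod(const uint64_t x, const uint64_t y, const uint64_t m)
{
    return (unsigned __int128)x * y % m;
}

#else

// (x + y) mod m for x, y < m, same idea as sum() in the post.
static uint64_t sum_mod(const uint64_t x, const uint64_t y, const uint64_t m)
{
    uint64_t ans = x + y;
    if (ans < x || ans >= m)
        ans -= m;
    return ans;
}

// Stand-in for the Karatsuba method: iterative binary multiplication,
// O(log y) modular additions. Requires x, y < m.
static uint64_t prod_binary_fallback(uint64_t x, uint64_t y, const uint64_t m)
{
    uint64_t ans = 0;
    while (y)
    {
        if (y & 1)
            ans = sum_mod(ans, x, m);
        x = sum_mod(x, x, m);
        y >>= 1;
    }
    return ans;
}

// The real-type method from the post; Real is double or long double.
template <typename Real>
static uint64_t prod_real(const uint64_t x, const uint64_t y, const uint64_t m)
{
    uint64_t c = (Real)x * y / m;
    int64_t ans = int64_t(x * y - c * m) % int64_t(m);
    if (ans < 0)
        ans += m;
    return ans;
}

// x * y mod m for x, y < m.
uint64_t mulmod(const uint64_t x, const uint64_t y, const uint64_t m)
{
#if LDBL_MANT_DIG >= 64
    // 80-bit long double: enough for numbers below 2^63.
    if (m < (1ull << 63))
        return prod_real<long double>(x, y, m);
#else
    // long double is no better than double: double is enough below 2^57.
    if (m < (1ull << 57))
        return prod_real<double>(x, y, m);
#endif
    return prod_binary_fallback(x, y, m);
}

#endif
~~~~~

On GCC or Clang for a 64-bit target only the first branch is compiled; on Microsoft® Visual C++® the double branch and the fallback are used instead.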

If you wish, you can try to apply these ideas to the problems of the special contest.
