Forestmy17's blog

By Forestmy17, 6 hours ago, In English

Maybe you are still using Fermat's Little Theorem to find inverses modulo a prime number $$$p$$$.

If $$$p$$$ is prime and $$$x$$$ is not divisible by $$$p$$$, then by Fermat's Little Theorem:

$$$x^{p-1}\equiv 1\pmod p$$$

Multiplying both sides by $$$x^{-1}$$$ gives

$$$x^{p-2}\equiv x^{-1}\pmod p$$$
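For instance, with the small prime $$$p = 7$$$ and $$$x = 3$$$:

$$$3^{7-2} = 3^5 = 243 = 34 \cdot 7 + 5$$$, so $$$3^{-1} \equiv 5 \pmod 7$$$, and indeed $$$3 \cdot 5 = 15 = 2 \cdot 7 + 1 \equiv 1 \pmod 7$$$.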

So the usual way to find the modular inverse is to use binary exponentiation:

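```cpp
const int mod = 998244353;  // assumed example; any prime modulus works

int binpow(int a, int n) {
    int res = 1;
    while (n) {
        if (n & 1) res = 1LL * res * a % mod;  // current bit of n is set
        a = 1LL * a * a % mod;                 // square for the next bit
        n >>= 1;
    }
    return res;
}
int inv(int x) {  // requires x % mod != 0
    return binpow(x, mod - 2);
}
```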

This works perfectly fine. But for such a common operation, it is nice to have something shorter.

For example, we can write the modular inverse like this:

```cpp
// Requires mod to be prime and 1 <= x < mod.
int inv(int x) {
    return x == 1 ? 1 : mod - 1LL * (mod / x) * inv(mod % x) % mod;
}
```
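To check that this short version really matches the Fermat approach, one can compare both on many inputs. This is a minimal sketch that assumes the prime `mod = 998244353`. `inverses_agree` is a hypothetical helper name, and `inv` is the recursive version shown above.

```cpp
#include <cassert>

// Assumption: an example prime modulus below 2^31.
const int mod = 998244353;

// Fermat's Little Theorem: x^(mod-2) is the inverse of x.
int binpow(int a, int n) {
    int res = 1;
    while (n) {
        if (n & 1) res = 1LL * res * a % mod;
        a = 1LL * a * a % mod;
        n >>= 1;
    }
    return res;
}

// Recursive formula from this post; requires 1 <= x < mod.
int inv(int x) {
    return x == 1 ? 1 : mod - 1LL * (mod / x) * inv(mod % x) % mod;
}

// Hypothetical helper: checks that both methods agree on x = 1..limit
// and that the result really is an inverse of x.
bool inverses_agree(int limit) {
    assert(limit < mod);
    for (int x = 1; x <= limit; ++x) {
        int r = inv(x);
        if (r != binpow(x, mod - 2)) return false;
        if (1LL * x * r % mod != 1) return false;
    }
    return true;
}
```

For a prime modulus, `inverses_agree(100000)` should return true, since both functions compute the unique inverse of each x.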

Now let's prove why this formula works.

First, write mod using division with remainder (here $$$/$$$ denotes integer division and $$$\%$$$ the remainder, as in C++):

$$$mod = (mod / x) * x + mod \% x$$$

Therefore,

$$$mod \% x = mod - (mod / x) * x$$$

Since we are working modulo mod, the value mod is congruent to 0, so:

$$$mod \% x\equiv -(mod/x) * x \pmod{mod}$$$
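For example, with $$$mod = 7$$$ and $$$x = 4$$$ we have $$$7 = 1 * 4 + 3$$$, so

$$$7 \% 4 = 3 = 7 - 1 * 4 \equiv -1 * 4 \pmod 7$$$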

Since mod is prime and $$$1 < x < mod$$$, the remainder $$$mod \% x$$$ is nonzero, so it has an inverse: $$$(mod \% x) * inv(mod \% x) \equiv 1 \pmod{mod}$$$. Multiplying both sides by $$$inv(mod \% x)$$$ gives:

$$$1\equiv -(mod / x) * x * inv(mod \% x) \pmod{mod}$$$

Since mod is prime and x is not divisible by mod, division by x is valid modulo mod.

Dividing both sides by x, we get:

$$$x^{-1} \equiv -(mod / x) * inv(mod \% x) \pmod{mod}$$$

But $$$x^{-1}$$$ is exactly the modular inverse of x, so this gives us:

$$$inv(x) \equiv -(mod / x) * inv(mod \% x) \pmod{mod}$$$

Of course we want a non-negative value, so instead of returning the negative number directly, we add mod:

$$$inv(x) \equiv mod - (mod / x) * inv(mod \% x) \pmod{mod}$$$

So we have derived the recursive formula:

```cpp
int inv(int x) {
    return x == 1 ? 1 : mod - 1LL * (mod / x) * inv(mod % x) % mod;
}
```
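As a small worked example, take $$$mod = 7$$$ and $$$x = 4$$$. The recursion goes $$$4 \to 7 \% 4 = 3 \to 7 \% 3 = 1$$$:

$$$inv(3) = 7 - (7 / 3) * inv(1) \% 7 = 7 - 2 * 1 = 5$$$

$$$inv(4) = 7 - (7 / 4) * inv(3) \% 7 = 7 - 1 * 5 = 2$$$

Indeed, $$$3 * 5 = 15 \equiv 1 \pmod 7$$$ and $$$4 * 2 = 8 \equiv 1 \pmod 7$$$.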

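The recursion terminates: $$$mod \% x < x$$$, so the argument strictly decreases, and because mod is prime the remainder is never 0 for $$$1 < x < mod$$$. It therefore eventually reaches x = 1, where inv(1) = 1.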
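Since each step only multiplies by $$$mod - mod / x \equiv -(mod / x)$$$, the recurrence can also be unrolled into a loop that avoids recursion entirely. This is a minimal sketch that assumes a global prime `mod` (998244353 here). `inv_iterative` is a hypothetical name.

```cpp
#include <cassert>

// Assumption: mod is prime; 998244353 is only an example.
const int mod = 998244353;

// Iterative form of inv(x) = (mod - mod / x) * inv(mod % x):
// the inverse is the product of (mod - mod / x) over the chain
// x -> mod % x -> ... -> 1. Requires 1 <= x < mod.
int inv_iterative(int x) {
    assert(1 <= x && x < mod);
    long long res = 1;
    while (x > 1) {
        res = res * (mod - mod / x) % mod;  // factor for this step
        x = mod % x;                        // move to the next argument
    }
    return (int)res;
}
```

It computes the same values as the recursive `inv`, just in reverse multiplication order.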

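The implementation is much shorter than binary exponentiation, but unlike binary exponentiation this recursion is not known to run in $$$O(\log mod)$$$: the remainder $$$mod \% x$$$ is smaller than $$$x$$$ but not necessarily smaller than $$$x / 2$$$ (for example, $$$7 \% 4 = 3$$$). The exact worst case is an open question. The chain $$$x \to mod \% x$$$ is related to Pierce expansions, and the known bounds lie between $$$O(\log mod / \log \log mod)$$$ and a small power of mod (roughly $$$mod^{1/3}$$$). In practice it is very fast.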
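The recursion depth is just the length of the chain $$$x \to mod \% x \to \dots \to 1$$$, so it is easy to measure for a concrete modulus. This is a small sketch: `chain_length` is a hypothetical helper, and the modulus is passed as a parameter `p` that is assumed to be prime.

```cpp
#include <cassert>

// Hypothetical helper: counts the steps x -> p % x needed to reach 1,
// which equals the number of nested calls made by the recursive inv(x)
// when mod == p. Assumes p is prime and 1 <= x < p.
int chain_length(long long x, long long p) {
    assert(1 <= x && x < p);
    int steps = 0;
    while (x > 1) {
        x = p % x;  // same step as inv(x) -> inv(mod % x)
        ++steps;
    }
    return steps;
}
```

Taking the maximum of `chain_length(x, p)` over many values of x gives an empirical feel for the worst case with a particular modulus.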


»
6 hours ago

Really Helpful

»
6 hours ago

Is this idea your own, or did you learn it from somewhere else?

  • »
    »
    6 hours ago

    Actually, I saw one of the participants use this formula during a contest and then decided to figure out where it came from.

»
6 hours ago

Note that the time complexity is O(log mod), or more precisely O(log x) (x < mod), as x/2 > mod%x >= 1, so at every iteration step you are dividing the number by a factor of more than 2 (T(x) <= T(x/2) + O(1)).

»
5 hours ago

I prefer to do binary exponentiation with a while loop to avoid the recursion overhead.

»
4 hours ago
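```cpp
#include <cassert>
#include <utility>
using std::swap;

const int mod = 998244353;
// Returns x * y^-1 modulo mod using the extended Euclidean algorithm.
// Invariant: u * y0 == y and v * y0 == m (mod mod), y0 = original y.
int mod_div(int x, int y) {
  int m = mod, u = 1, v = 0;
  while (m) swap(u -= y / m * v, v), swap(y %= m, m);
  assert(y == 1);  // y^-1 exists only if gcd(y, mod) == 1
  return 1LL * x * (u + mod) % mod;
}
// int quotient = mod_div(x, y); // returns x * y^-1
```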
»
3 hours ago
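```cpp
#include <utility>
using std::swap;

const long long mod = 998244353;  // assumed example modulus

// Finds x such that a * x ≡ 1 (mod mod)
// Requires gcd(a, mod) = 1
long long inv(long long a) {
    long long b = mod, u = 1, v = 0;
    while (b) {
        long long t = a / b;
        a -= t * b; swap(a, b);  // Euclid step on (a, b)
        u -= t * v; swap(u, v);  // keeps u * a0 ≡ a, v * a0 ≡ b (mod mod)
    }
    return (u % mod + mod) % mod;
}
```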