def _bin_pow(a, b, m):
c = 1
while b > 0:
if b % 2 == 1:
c = c * a % m
a = a * a % m
b //= 2
return c
# x^2 = a (mod p)
def _tonelli_shanks(a, p):
assert a % p != 0
assert p > 2
q = p - 1
s = 0
while q % 2 == 0:
q //= 2
s += 1
assert s > 0
if _bin_pow(a, (p - 1) // 2, p) != 1:
return []
z = 2
while _bin_pow(z, (p - 1) // 2, p) != p - 1:
z += 1
assert z < p
c = _bin_pow(z, q, p)
r = _bin_pow(a, (q + 1) // 2, p)
t = _bin_pow(a, q, p)
while t != 1:
i = 1
while i < s and _bin_pow(t, 2 ** i, p) != 1:
i += 1
assert i < s
b = _bin_pow(c, 2 ** (s - i - 1), p)
r = r * b % p
t = t * b * b % p
c = b * b % p
s = i
return sorted([r, p - r])
# x^2 = a (mod p^k)
def _generalized_tonelli_shanks(a, p, k):
    """Solve x^2 = a (mod p^k) for odd p and a not divisible by p.

    Every root r of x^2 = a (mod p), taken from _tonelli_shanks, is lifted
    in closed form to r^(p^(k-1)) * a^((p^k - 2*p^(k-1) + 1) / 2) mod p^k.
    Returns the lifted roots in ascending order ([] if a is a non-residue).
    """
    assert a % p != 0
    assert p > 2
    assert k > 0
    modulus = p ** k
    lower_power = modulus // p
    # Correction factor shared by every lifted root.
    factor = _bin_pow(a, (modulus - 2 * modulus // p + 1) // 2, modulus)
    lifted = [
        _bin_pow(root, lower_power, modulus) * factor % modulus
        for root in _tonelli_shanks(a, p)
    ]
    return sorted(lifted)
# x^2 = a (mod 2^k)
def _discrete_sqrt_2(a, k):
assert k > 0
assert a > 0
if k == 1:
return [1]
if a % 8 != 1:
return []
ans = set()
for x in [1, 3]:
for i in range(3, k):
j = ((x * x - a) // (2 ** i) % 2 + 2) % 2
x = x + j * 2 ** (i - 1)
ans.add(x)
ans.add(2 ** k - x)
return sorted(list(ans))
# x^2 = 0 (mod p^k)
def _discrete_sqrt_zero(p, k):
assert p > 1
assert k > 0
ans = [0]
x = p ** ((k + 1) // 2)
for i in range(1, (p ** k) // x):
ans.append(x * i)
return ans
# x^2 = a (mod p^k)
def _discrete_sqrt_prime(a, p, k):
    """Return all x in [0, p^k) with x^2 = a (mod p^k), in ascending order.

    p is expected to be prime (discrete_sqrt passes factors produced by
    _factorize); only p > 1 is checked.  a may be any integer: it is reduced
    modulo p^k before the zero test and the p-adic valuation, both of which
    assume 0 <= a < p^k.  Returns [] when no root exists.
    """
    assert p > 1
    assert k > 0
    pk = p ** k
    a %= pk
    if a == 0:
        return _discrete_sqrt_zero(p, k)
    # Split a = p^s * q with q not divisible by p.
    q, s = a, 0
    while q % p == 0:
        q //= p
        s += 1
    if s == 0:
        if p == 2:
            return _discrete_sqrt_2(a, k)
        return _generalized_tonelli_shanks(a, p, k)
    # A square has even p-adic valuation.
    if s % 2 != 0:
        return []
    # Roots are p^(s/2) * y with y^2 = q (mod p^(k-s)); such y are only
    # fixed modulo a smaller power of p, so each base root p^(s/2) * x
    # repeats with period d (one power lower for p == 2, never below
    # p^ceil(k/2)).
    xq = _discrete_sqrt_prime(q, p, k)
    x0 = p ** (s // 2)
    d = p ** max((k + 1) // 2, k - s // 2 - (1 if p == 2 else 0))
    ans = {(x0 + i * d) * x % pk for i in range(pk // d) for x in xq}
    return sorted(ans)
def _ex_gcd(a, b):
x, xp = 1, 0
y, yp = 0, 1
while b != 0:
q = a // b
a, b = b, a - q * b
x, xp = xp, x - q * xp
y, yp = yp, y - q * yp
return (a, x, y)
def _factorize(n):
assert n > 0
ans = []
i = 2
while i * i <= n:
if n % i != 0:
i += 1
continue
ans.append(i)
n //= i
if n > 1:
ans.append(n)
return ans
# x^2 = a (mod m)
def discrete_sqrt(a, m):
    """Return all x in [0, m) with x^2 = a (mod m), in ascending order.

    m is split into prime powers p^k via _factorize; the roots modulo each
    p^k come from _discrete_sqrt_prime and are merged with the Chinese
    Remainder Theorem.  An empty list means a has no square root modulo m;
    for m == 1 the result is [0].
    """
    assert m > 0
    roots = [0]
    modulus = 1
    # Count multiplicities; dict order keeps the primes ascending.
    exponents = {}
    for prime in _factorize(m):
        exponents[prime] = exponents.get(prime, 0) + 1
    for p, k in exponents.items():
        pk = p ** k
        local_roots = _discrete_sqrt_prime(a % pk, p, k)
        g, inv, _ = _ex_gcd(modulus, pk)
        assert g == 1
        combined = modulus * pk
        # CRT: x = r0 + modulus * ((r1 - r0) * modulus^-1 mod pk).
        roots = [
            (r0 + (inv * (r1 - r0) % pk) * modulus) % combined
            for r0 in roots
            for r1 in local_roots
        ]
        modulus = combined
    return sorted(roots)