$\DeclareMathOperator{\ord}{ord}$↵
$\DeclareMathOperator{\mod}{mod}$↵
↵
↵
Редукция Монтгомери/Баррета и приложение к Теоретико-числовому Преобразованию Фурье↵
==================↵
↵
**Здесь** будет показано, как применить редукцию Монтгомери/Баррета к оптимизации Теоретико-числового Преобразования Фурье (NTT) без явных операций по модулю и созданию собственного класса модулярной арифметики, применимого в комбинаторных задачах. Будет приведена реализация NTT с минимальным числом операций по модулю, которая выигрывает у классического комплексного FFT не только в точности, но и в производительности на платформах без аппаратного ускорения. Мы сравним производительность различных оптимизаций NTT с комплексным FFT. Также будет показано, как добавить поддержку отрицательных коэффициентов в NTT.↵
↵
↵
1. Причины использования NTT вместо FFT↵
------------------↵
DFT (Дискретное Преобразование Фурье) широко используется в задаче умножения многочленов. DFT может быть вычисленокв любом поле, где определён корень n-й степени, т.е. элемент с порядком (показателем) n. Например, это может быть поле комплексных чисел (используемое в FFT) или кольцо вычетов по простому модулю (NTT). Известная реализация — Быстрое Преобразование Фурье — вычисляет DFT за $O(n\log{n})$ в любом поле (отличается только арифметика). Реализация FFT по модулю $p$ называется Теоретико-числовым Преобразованием Фурье (NTT). Просто FFT будем называть алгоритм в комплексном поле. Базовый алгоритм без оптимизаций рассмотрен во многих статьях, например, на <a href="https://cp-algorithms.com/algebra/fft.html">CP-algorithms</a>. NTT почти не используют, потому что его принятая реализация с модулярной арифметикой по большому модулю является очень медленной без аппаратного ускорения. Здесь будет показано, как реализовать быстрый NTT с неявными операциями по модулю.↵
↵
Конечно, есть удобный FFT с комплесным типом double, который используется в перемножении многочленов (и не только), в т.ч. для решения задач на Codeforces. Но FFT имеет очень плохую точность из-за ограниченной паимяти машинной реализации арифметики с плавающей точкой, и он довольно медленный (даже с нерекурсивной реализацией). Следовательно, если даны многочлены с неотрицательными большими коэффициентами для перемножения, есть возможность использования NTT: перемножить многочлены по модулю $p$ и восстановить реальные коэффициенты до $10^{18}$ (например, с помощью двух 32-битных модулей и Китайской Теоремы об Остатках для решения систем сравнений по модулям или просто используя большой модуль с кастом к 128-битному целочисленному типу). ↵
↵
**Главное преимущество** NTT над FFT — это точность: через NTT мы получаем абсолютно точные коэффициенты произведения в диапазоне $[0; 10^{18}]$, а использование FFT приводит к ошибке 10-100 при использовании double и 1-10 при long double (например, [submission:259759436] c double получает WA43 в задаче, сводимой к умножению многочленов, когда как [submission:259759880] с long double получает OK). ↵
↵
NTT не работает быстрее FFT "из коробки", т.к. стандартная реализация NTT использует много операций по модулю, так что на многих системах (особенно без встроенного ускорения через SIMD) оно работает медленно. Но оптимизации описанные ниже помогут нам избавиться от явных операций взятия остатка, и NTT сможет работать заметно быстрее FFT. Для этого мы можем использовать редукцию Монтгомери и редукцию Баррета. ↵
↵
↵
2. Редукция Монтгомери↵
------------------↵
Редукция Монтгомери — это подход к организации арифметики по фиксированному модулю без явного деления с остатком (%). Главная идея в том, что мы можем свести деление с остатком на $p$ к делению с остатком на $2^k$, которое может быть реализовано через битовые операции. Для ознакомления с самим алгоритмом и доказательством можно <a href="https://en.algorithmica.org/hpc/number-theory/montgomery/">в статье</a>. Редукция Монтгомери обязывает перед выполнением операций по модулю приведение к специальной форме пространства Монтгомери. Приведение к этой форме требует явного деления с остатком, но в NTT особый случай: все числа, с которыми мы будем оперировать по модулю в алгоритме — это только коэффициенты исходных многочленов, корень и некоторые обратные (n, корень). Мы можем сделать это за $O(n)$ перед вызовом NTT и после него (чтобы привести в нормальную форму). Следовательно, можно избавиться от модулярной арифметики внутри NTT. Единственный нюанс в том, что при использовании большого модуля ($~10^{18}$) мы вынуждены использовать __int128. Из-за изобилия операций с многобитными числами была даже выпущена <a href="https://networkbuilders.intel.com/docs/networkbuilders/intel-avx-512-fast-modular-multiplication-technique-technology-guide-1710916893.pdf">статья Intel</a> для ускорения редукции Монтгомери через AVX. ↵
↵
Но если мы хотим уменьшить число операций с многобитными числами без потери в размерах коэффициентах, можно вычислить NTT по двум модулям, используя только 64-битные числа (параллельно через 1 вызов) и решить систему сравнений по простым модулям через Китайскую Теорему об Остатках. Тогда __int128 будет использован только в КТО. ↵
↵
**P.S.** В будущем статья будет дополнятся полностью русскоязычными разборами модулярных редукций. ↵
↵
↵
3. Редукция Баррета↵
------------------↵
Другой способ избежать явное деление с остатком — это редукция Баррета. Она основана на приближении ответа через деление с округлением вниз с предподсчитанным множителем определённого вида. Конкретно, преобразуем выражение для остатка: ↵
$$ a \mod p = a - \lfloor {a \over p} \rfloor \cdot p $$ ↵
Выберем $m$, так что: ↵
$$ {1 \over p} = {m \over 2^k} \longleftrightarrow m = {2^k \overnp} $$ ↵
Теперь при делении будем подставлять $1 \over p$. Если взять достаточно большой множитель $2^k$ для приближения (порядка $p^2$) ответ для чисел нашего диапазона (до $10^{18}$) будет точным. С полным доказательством можно ознакомиться <a href="https://www.nayuki.io/page/barrett-reduction-algorithm">здесь</a>. Она даёт такой же функционал, что и редукция Монтгомери, но не требует приведение к специальной форме. Однако мы в каждой операции должны делать битовый сдвиг на больше, чем $p^2$ и делать умножения, поэтому при использовании модуля около $10^{18}$ необходимо использовать 256-битный целочисленный тип, поэтому операции дороже. ↵
↵
↵
4. Выбираем модули↵
------------------↵
Во первых, мы должны выбрать модуль такой модуль $p$, что $n | (p-1): n=2^k$, чтобы существовал корень n-й степени из 1 (порядок элемента делит порядок циклической группы по теореме Лагранжа). Т.к. по модулю $p$ существует первообразный корень $g$, мы можем взять $g^{p-1 \over n}$ в качестве корня n-й степени. Для конкретного $p$ эти корни могут быть легко найдены с помощью бинарного возведения в степень и алгоритмом поиска первообразного корня за $O(Ans \cdot \log^2(p))$, $Ans$ — размер первообразного корня в конечном поле — по предположению гипотезы всегда маленький. <a href="https://cp-algorithms.com/algebra/primitive-root.html">Здесь</a> имеется доказательство корректности алгоритма. Например, в нашей задаче для перемножения многочленов размером до $2^{24}$ с коэффициентами произведения до $2 \cdot 10^{18}$ можно взять $p=2524775926340780033$. Первообразный корень 3. ↵
↵
<spoiler summary="Primitive root searching algorithm">↵
```c++↵
// Primitive root modulo n↵
// (generator of cyclic group with n-1 elements)↵
int generator(int n) {↵
vector<int> fact;↵
int phi = euler_totient(n); // for prime equals n-1↵
int m = phi;↵
for (int d = 2; d*d <= m; ++d)↵
if (m%d == 0) {↵
fact.push_back(d);↵
while (m%d == 0)↵
m /= d;↵
}↵
if (m > 1)↵
fact.push_back(m);↵
for (int root = 2; root <= n; ++root) {↵
bool found = true;↵
for (auto d : fact)↵
if (bin_pow(root, phi / d, n) == 1) {↵
found = false;↵
break;↵
}↵
if (found)↵
return root;↵
}↵
return n == 1 ? 1 : -1;↵
}↵
↵
```↵
</spoiler>↵
↵
↵
↵
5. Реализации↵
------------------↵
Будем использовать нерекурсивную реализацию для FFT (NTT) с дополнительной структурой для арифметики с редукцией Монтгомери/Баррета. Удобно использовать беззнаковые типы при работе с редукциями из-за трюка с переполнением: оно не ведёт к неопределённому поведению, а возвращает ответ по модулю $2^{64}$. ↵
↵
↵
<spoiler summary="NTT with Montgomery multiplication">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /** n*m = 1 (mod p) => m = n**(p-2) (mod p) **/↵
uint64_t res = 1;↵
while (p) {↵
if (p & 1)↵
res = ((__uint128_t) res * n) % mod;↵
n = ((__uint128_t) n * n) % mod;↵
p >>= 1;↵
}↵
return res;↵
}↵
↵
↵
struct montgomery {↵
uint64_t n, nr;↵
↵
constexpr montgomery(uint64_t n) : n(n), nr(1) {↵
// log(2^64) = 6↵
for (int i = 0; i < 6; i++)↵
nr *= 2 - n * nr;↵
}↵
↵
[[nodiscard]]↵
uint64_t reduce(__uint128_t x) const {↵
uint64_t q = __uint128_t(x) * nr;↵
uint64_t m = ((__uint128_t) q * n) >> 64;↵
uint64_t res = (x >> 64) + n - m;↵
if (res >= n)↵
res -= n;↵
return res;↵
}↵
↵
[[nodiscard]]↵
uint64_t multiply(uint64_t x, uint64_t y) const {↵
return reduce((__uint128_t) x * y);↵
}↵
↵
[[nodiscard]]↵
uint64_t transform(uint64_t x) const {↵
return (__uint128_t(x) << 64) % n;↵
}↵
};↵
↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
↵
const uint64_t mod = 2524775926340780033, gen = 3;↵
//const uint64_t mod = 998244353, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev, montgomery& red, ↵
uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
int n = (int)a.size();↵
↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
uint64_t w = invert ? inv_root : root;↵
vector<uint64_t> W(n >> 1, red.transform(1));↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = red.multiply(W[i-1], w);↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
uint64_t t = red.multiply(a[j ^ (1 << i)], W[(j & ((1 << i) - 1)) * (n >> (i + 1))]);↵
a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], inv_n);↵
}↵
↵
↵
void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
montgomery red(mod);↵
for (auto& x : a)↵
x = red.transform(x);↵
for (auto& x : b)↵
x = red.transform(x);↵
↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
↵
uint64_t inv_n = red.transform(bin_pow(n, mod-2, mod));↵
uint64_t root = red.transform(bin_pow(gen, (mod-1)/n, mod));↵
uint64_t inv_root = red.transform(bin_pow(red.reduce(root), mod-2, mod));↵
auto rev = bit_sort(n);↵
↵
ntt(a, rev, red, inv_n, root, inv_root, false);↵
ntt(b, rev, red, inv_n, root, inv_root, false);↵
↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], b[i]);↵
ntt(a, rev, red, inv_n, root, inv_root, true);↵
↵
for (auto& x : a)↵
x = red.reduce(x);↵
}↵
↵
```↵
</spoiler>↵
↵
↵
Реализация с редукцией Баррета может быть использована только с модулем $<2^{32}$ без 256-битных чисел. Т.к. редукция Баррета а приори имеет более дорогие операции, здесь не будет приведена реализация с uint256. В коде можно использовать uint32 вместо uint64, но ради равенства условий в тестировании будут приведены одинаковые форматы и в реализации с Монтгомери и с Барретом. ↵
↵
<spoiler summary="NTT with Barret multiplication">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /** n*m = 1 (mod p) => m = n**(p-2) (mod p) **/↵
uint64_t res = 1;↵
while (p) {↵
if (p & 1)↵
res = ((__uint128_t) res * n) % mod;↵
n = ((__uint128_t) n * n) % mod;↵
p >>= 1;↵
}↵
return res;↵
}↵
↵
↵
struct barret {↵
uint64_t n, s;↵
__uint128_t f;↵
↵
constexpr barret(uint64_t _n) {↵
n = _n;↵
s = 64;↵
f = (__uint128_t(1) << s) / n;↵
}↵
↵
[[nodiscard]]↵
uint64_t reduce(__uint128_t x) const {↵
auto t = (uint64_t)(x - ((x * f) >> s) * n);↵
if (t < n)↵
return t;↵
return t - n;↵
}↵
↵
[[nodiscard]]↵
uint64_t multiply(uint64_t x, uint64_t y) const {↵
return reduce((__uint128_t) x * y);↵
}↵
};↵
↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
↵
//const uint64_t mod = 2524775926340780033, gen = 3;↵
const uint64_t mod = 998244353, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev, barret& red, ↵
uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
int n = (int)a.size();↵
↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
uint64_t w = invert ? inv_root : root;↵
vector<uint64_t> W(n >> 1, 1);↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = red.multiply(W[i-1], w);↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
uint64_t t = red.multiply(a[j ^ (1 << i)], W[(j & ((1 << i) - 1)) * (n >> (i + 1))]);↵
a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], inv_n);↵
}↵
↵
↵
void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
↵
barret red(mod);↵
uint64_t inv_n = bin_pow(n, mod-2, mod);↵
uint64_t root = bin_pow(gen, (mod-1)/n, mod);↵
uint64_t inv_root = bin_pow(root, mod-2, mod);↵
auto rev = bit_sort(n);↵
↵
ntt(a, rev, red, inv_n, root, inv_root, false);↵
ntt(b, rev, red, inv_n, root, inv_root, false);↵
↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], b[i]);↵
ntt(a, rev, red, inv_n, root, inv_root, true);↵
}↵
```↵
</spoiler>↵
↵
↵
↵
Также здесь будет приведена реализация FFT со своей структурой для комплексных чисел. Она может быть использована с double / long double. ↵
↵
<spoiler summary="FFT with custom cmpls">↵
```c++↵
struct _cmpl {↵
double a, b;↵
_cmpl(double a = 0, double b = 0) : a(a), b(b) {}↵
↵
const _cmpl operator + (const _cmpl &c) const↵
{ return _cmpl(a + c.a, b + c.b); }↵
↵
const _cmpl operator - (const _cmpl &c) const↵
{ return _cmpl(a - c.a, b - c.b); }↵
↵
const _cmpl operator * (const _cmpl &c) const↵
{ return _cmpl(a * c.a - b * c.b, a * c.b + b * c.a); }↵
};↵
↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
↵
void fft(vector<_cmpl>& a, vector<int>& rev, bool invert) {↵
int n = a.size(), h = -1;↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
double alpha = 2 * atan2(0.00, -1.00) / n * (invert ? -1 : 1);↵
_cmpl w1(cos(alpha), sin(alpha));↵
vector<_cmpl> W(n >> 1, 1);↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = W[i - 1] * w1;↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
_cmpl t = a[j ^ (1 << i)] * W[(j & ((1 << i) - 1)) * (n >> (i + 1))];↵
a[j ^ (1 << i)] = a[j] - t;↵
a[j] = a[j] + t;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = _cmpl(a[i].a / n, a[i].b / n);↵
}↵
↵
↵
void mul(vector<_cmpl>& a, vector<_cmpl>& b, vector<uint64_t>& res) {↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
auto rev = bit_sort(n);↵
fft(a, rev, false);↵
fft(b, rev, false);↵
for (int i = 0; i < n; i++)↵
a[i] = a[i] * b[i];↵
fft(a, rev, true);↵
res.resize(n);↵
for (int i = 0; i < n; i++)↵
res[i] = (uint64_t)(a[i].a + 0.1);↵
}↵
```↵
</spoiler>↵
↵
↵
NTT без оптимизаций: ↵
↵
<spoiler summary="NTT without opt">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /** n*m = 1 (mod p) => m = n**(p-2) (mod p) **/↵
uint64_t res = 1;↵
while (p) {↵
if (p & 1)↵
res = ((__uint128_t) res * n) % mod;↵
n = ((__uint128_t) n * n) % mod;↵
p >>= 1;↵
}↵
return res;↵
}↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
//const uint64_t mod = 998244353, gen = 3;↵
const uint64_t mod = 2524775926340780033, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev,↵
uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
int n = (int)a.size();↵
↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
uint64_t w = invert ? inv_root : root;↵
vector<uint64_t> W(n >> 1, 1);↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = (__uint128_t) W[i-1] * w % mod;↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
uint64_t t = (__uint128_t) a[j ^ (1 << i)] * W[(j & ((1 << i) - 1)) * (n >> (i + 1))] % mod;↵
a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = (__uint128_t) a[i] * inv_n % mod;↵
}↵
↵
↵
void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
↵
uint64_t inv_n = bin_pow(n, mod-2, mod);↵
uint64_t root = bin_pow(gen, (mod-1)/n, mod);↵
uint64_t inv_root = bin_pow(root, mod-2, mod);↵
auto rev = bit_sort(n);↵
↵
ntt(a, rev, inv_n, root, inv_root, false);↵
ntt(b, rev, inv_n, root, inv_root, false);↵
↵
for (int i = 0; i < n; i++)↵
a[i] = (__uint128_t) a[i] * b[i] % mod;↵
ntt(a, rev, inv_n, root, inv_root, true);↵
}↵
```↵
</spoiler>↵
↵
↵
↵
6. Производительность↵
------------------↵
Мы будет тестировать 5 алгоритмов: FFT с double и long double, NTT без и с редукциями Монтгомери и Баррета. ↵
↵
Во-первых, посмотрим на [problem:993E]. Важно отметить, что NTT с редукцией Баррета и FFT с double получили wa43 (но прошли тесты на других тестах с большими входными данными), потому что FFT с double имеет плохую точность, а NTT с редукцией Баррета по модулю $998244353$ не может хранить большой ответ (NTT с Барретом и без uint256 может быть использовано только для небольших коэффициентов). ↵
↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>[submission:261258380], 311ms, BAD PRECISION</th>↵
<th>[submission:261256931], 702ms, OK</th>↵
<th>[submission:261256487], 281ms, OK</th>↵
<th>[submission:261257532], 312ms, BAD LIMITS</th>↵
<th>[submission:261257532], 859ms, OK</th>↵
</tr>↵
</table>↵
↵
Теперь протестируем перемножение многочленов размером $10^6$ со случайными коэффициентами на запуске Codeforces: ↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>1407ms</th>↵
<th>2434ms</th>↵
<th>1374ms</th>↵
<th>1437ms</th>↵
<th>4122ms</th>↵
</tr>↵
<tr>↵
<th>1286ms</th>↵
<th>2426ms</th>↵
<th>1256ms</th>↵
<th>1675ms</th>↵
<th>3683ms</th>↵
</tr>↵
<tr>↵
<th>1052ms</th>↵
<th>2286ms</th>↵
<th>1153s</th>↵
<th>1320ms</th>↵
<th>3518ms</th>↵
</tr>↵
<tr>↵
<th>1079ms</th>↵
<th>2359ms</th>↵
<th>1309ms</th>↵
<th>1380ms</th>↵
<th>3961ms</th>↵
</tr>↵
<tr>↵
<th>1217ms</th>↵
<th>2791ms</th>↵
<th>1476ms</th>↵
<th>1780ms</th>↵
<th>3400ms</th>↵
</tr>↵
<tr>↵
<th>1241ms</th>↵
<th>2191ms</th>↵
<th>1125ms</th>↵
<th>1337ms</th>↵
<th>3512ms</th>↵
</tr>↵
</table>↵
↵
Теперь локальные тесты на Ryzen 5 5650u (AVX), 8Gb RAM 4266MHz, Debian 12, GCC 12.2.0. Параметры компиляции: `g++ -Wall -Wextra -Wconversion -static -DONLINE_JUDODGE -O2 -std=c++20 fftest.cc -o fftest` как на Codeforces. ↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>691ms</th>↵
<th>2339ms</th>↵
<th>352ms</th>↵
<th>401ms</th>↵
<th>398ms</th>↵
</tr>↵
<tr>↵
<th>700ms</th>↵
<th>2328ms</th>↵
<th>411ms</th>↵
<th>500ms</th>↵
<th>487ms</th>↵
</tr>↵
<tr>↵
<th>729ms</th>↵
<th>2036ms</th>↵
<th>359ms</th>↵
<th>446ms</th>↵
<th>396ms</th>↵
</tr>↵
<tr>↵
<th>698ms</th>↵
<th>2164ms</th>↵
<th>380ms</th>↵
<th>456ms</th>↵
<th>472ms</th>↵
</tr>↵
<tr>↵
<th>734ms</th>↵
<th>2284ms</th>↵
<th>407ms</th>↵
<th>462ms</th>↵
<th>440ms</th>↵
</tr>↵
<tr>↵
<th>737ms</th>↵
<th>2148ms</th>↵
<th>396ms</th>↵
<th>459ms</th>↵
<th>434ms</th>↵
</tr>↵
</tr>↵
</table>↵
↵
На сервере Codeforces одинаковая производительность у Montgomery + NTT и FFT + doubles, затем идёт NTT + Barret, затем FFT + long doubles, затем NTT без оптимизаций. На локальной системе с современным процессором с поддержкой SIMD лидером является Montgomery + NTT. Причём NTT на локальной системе работает намного быстрее **даже без оптимизаций**. Я точно не могу определить, какие технологии отвечают за быструю 128-битную модулярную арифметику, но, скорее всего, это часть SIMD. Вообще, NTT хорошо поддаётся аппаратному ускорению, напирмер, <a href="https://link.springer.com/chapter/10.1007/978-3-030-78713-4_6">в статье</a> описано ускорение NTT на FPGA. ↵
↵
Как мы видим, производительность NTT по большому модулю сильно упирается в платформу из-за наличия большого числа операций с высокобитными числами, а именно деления с остатком 128-битных чисел. Это операция сильно упирается в аппаратную реализацию, поэтому из-за неё на Codeforces NTT без оптимизаций оказался самым медленным, но на CPU с ускорением SIMD он выигрывает у FFT. Тем не менее, разница может быть нивелирована засчёт явной реализации **редукции**, которая использует только умножение 128-битных чисел. Про проблемы 128-битного деления изложено <a href="https://danlark.org/2020/06/14/128-bit-division/">в статье</a>. ↵
↵
В итоге NTT + Montgomery является довольно универсальным выбором на всех платформах. Единственный нюанс в том, что NTT изначально не поддерживает отрицательные коэффициенты (т.к. всё берётся по модулю), но это можно исправить. ↵
↵
↵
7. Отрицательные коэффициенты↵
------------------↵
Рассмотрим $A(x)$ и $B(x)$ с любыми целыми коэффициентами, и мы хотим получить коэффициенты произведения через NTT. Пусть $C(x)$ — такой многочлен, что любой коэффициент у $A(x)+C(x)$, $B(x)+C(x)$ и $A(x)+B(x)+C(x)$ является неотрицательным. ↵
Заметим, что ↵
$$ A \cdot B = (A+C) \cdot (B+C) - C \cdot (A+B+C) $$↵
Так как каждое слагаемое в сумме имеет неотрицательные коэффициенты, $A(x) \cdot B(x)$ можно найти через 2 вызова умножения многочленов через NTT. $C(x)$ можно найти жадно за $O(n)$.↵
↵
↵
8. Класс для модулярной арифметики↵
------------------↵
Удобно создать класс для арифметики в $\mathbb{Z}/p\mathbb{Z}$ и какой-то редукцией. Ниже приведён класс с редукцией Баррета с методами для НОДа, степени, инверсий и корней (класс можно расширить до арифметики в любом конечном поле). P.S. Если мы хотим получить максимальную производительность и работаем с специфичным множеством чисел, можно использовать редукцию Монтгомери, потому что Монтгомери быстрее. ↵
↵
↵
<spoiler summary="Modular arithmetic class">↵
```c++↵
class zpz {↵
public:↵
static void init(uint32_t m) {↵
mod = m;↵
shift = 2*(32 - __builtin_clz(m));↵
factor = (uint64_t(1) << shift) / mod;↵
gen = 0;↵
}↵
↵
static uint32_t ext_gcd(uint32_t a, uint32_t b, uint64_t& x, uint64_t& y) {↵
if (a < b)↵
return ext_gcd(b, a, y, x);↵
if (b == 0) {↵
x = 1;↵
y = 0;↵
return a;↵
}↵
uint64_t x1, y1;↵
uint32_t g = ext_gcd(b, a%b, x1, y1);↵
x = y1;↵
y = x1 - (a/b)*y1;↵
return g;↵
}↵
↵
static zpz pow(zpz a, uint32_t n) {↵
zpz res = 1;↵
while (n) {↵
if (n & 1)↵
res *= a;↵
a *= a;↵
n >>= 1;↵
}↵
return res;↵
}↵
↵
static zpz inv(zpz a) {↵
if (inverses.find(a()) == inverses.end()) {↵
uint64_t x, y;↵
ext_gcd(a(), mod, x, y);↵
inverses[a()] = reduce(x + mod);↵
}↵
return inverses[a()];↵
}↵
↵
static zpz root(zpz a, int n) {↵
if (gen == 0) {↵
vector<uint32_t> fact;↵
// int phi = euler_totient(n);↵
uint32_t phi = mod-1;↵
uint32_t m = phi;↵
for (uint32_t d = 2; d*d <= m; ++d)↵
if (m%d == 0) {↵
fact.push_back(d);↵
while (m%d == 0)↵
m /= d;↵
}↵
if (m > 1)↵
fact.push_back(m);↵
for (uint32_t rt = 2; rt < mod; ++rt) {↵
bool found = true;↵
for (auto d : fact)↵
if (pow(zpz(rt), phi / d) == 1) {↵
found = false;↵
break;↵
}↵
if (found) {↵
gen = rt;↵
break;↵
}↵
}↵
gen = mod == 1 ? 1 : throw exception();↵
}↵
↵
return pow(zpz(gen), (mod-1)/n);↵
}↵
↵
static void get_all_modular_inverses() {↵
inverses[1] = 1;↵
for (int k = 2; k < mod; ++k)↵
inverses[k] = -1LL * (mod / k) * inverses[mod % k] % mod + mod;↵
}↵
↵
↵
zpz(uint32_t x) : val(reduce(x)) {}↵
↵
uint32_t operator () () const { return val; }↵
↵
zpz& operator =(uint64_t x) { val = reduce(x); return *this; }↵
↵
zpz& operator =(const zpz& x) { val = x(); return *this; }↵
↵
zpz& operator +=(const zpz& x) { val = reduce(val + x()); return *this; }↵
zpz& operator -=(const zpz& x) { val = reduce(val - x() + mod); return *this; }↵
zpz& operator *=(const zpz& x) { val = reduce((uint64_t) val * x()); return *this; }↵
↵
zpz& operator +=(uint64_t x) { return *this += zpz(x); }↵
zpz& operator -=(uint64_t x) { return *this -= zpz(x); }↵
zpz& operator *=(uint64_t x) { return *this *= zpz(x); }↵
↵
zpz& operator /=(const zpz& x) { return *this *= inv(x); }↵
zpz& operator /=(uint64_t x) { return *this /= zpz(x); }↵
zpz operator /(uint64_t x) { zpz cur = *this; return cur /= x; }↵
↵
zpz& operator ++() { return *this += 1; }↵
zpz& operator --() { return *this -= 1; }↵
↵
zpz operator ++(int unused) { zpz z(*this); ++(*this); return z; }↵
zpz operator --(int unused) { zpz z(*this); --(*this); return z; }↵
↵
friend zpz operator +(zpz x, const zpz& y) { return x += y; }↵
friend zpz operator *(zpz x, const zpz& y) { return x *= y; }↵
friend zpz operator -(zpz x, const zpz& y) { return x -= y; }↵
friend zpz operator /(zpz x, const zpz& y) { return x /= y; }↵
↵
friend zpz operator +(zpz x, uint32_t y) { return x += y; }↵
friend zpz operator *(zpz x, uint32_t y) { return x *= y; }↵
friend zpz operator -(zpz x, uint32_t y) { return x -= y; }↵
friend zpz operator /(zpz x, uint32_t y) { return x /= y; }↵
↵
friend zpz operator +(uint32_t x, zpz y) { return y += x; }↵
friend zpz operator *(uint32_t x, zpz y) { return y *= x; }↵
↵
friend zpz operator -(uint32_t x, const zpz& y) { zpz z(x); return z -= y; }↵
friend zpz operator /(uint32_t x, const zpz& y) { zpz z(x); return z /= y; }↵
↵
bool operator <(const zpz& x) const { return val < x(); }↵
bool operator ==(const zpz& x) const { return val == x(); }↵
bool operator >(const zpz& x) const { return val > x(); }↵
bool operator !=(const zpz& x) const { return val != x(); }↵
bool operator <=(const zpz& x) const { return val <= x(); }↵
bool operator >=(const zpz& x) const { return val >= x(); }↵
↵
bool operator <(uint32_t x) const { return val < x; }↵
bool operator ==(uint32_t x) const { return val == x; }↵
bool operator >(uint32_t x) const { return val > x; }↵
bool operator !=(uint32_t x) const { return val != x; }↵
bool operator <=(uint32_t x) const { return val <= x; }↵
bool operator >=(uint32_t x) const { return val >= x; }↵
↵
friend istream& operator >> (istream& input, zpz& x)↵
{↵
uint32_t z;↵
input >> z,↵
x = zpz(z);↵
return input;↵
}↵
↵
friend ostream& operator << (ostream& output, const zpz& x)↵
{↵
return output << x();↵
}↵
↵
↵
private:↵
static uint32_t mod, shift;↵
static uint64_t factor;↵
↵
static gp_hash_table<uint32_t, uint32_t> inverses;↵
static uint32_t gen;↵
↵
[[nodiscard]]↵
static uint32_t reduce(uint64_t x) {↵
auto t = (uint32_t)(x - (((__uint128_t) x * factor) >> shift) * mod);↵
if (t < mod)↵
return t;↵
return t - mod;↵
}↵
↵
uint32_t val;↵
};↵
↵
uint32_t zpz::mod, zpz::shift;↵
uint64_t zpz::factor;↵
gp_hash_table<uint32_t, uint32_t> zpz::inverses;↵
uint32_t zpz::gen;↵
```↵
</spoiler>↵
↵
↵
9. Заключение↵
------------------↵
В результате мы сравнили различные редукции для модульной арифметики и выяснили, что редукция Монтгомери имеет наилучшую производительность, если мы заранее знаем набор целых чисел. Следовательно, NTT с редукцией Монтгомери можно использовать вместо FFT в задачах с умножением многочленов (даже для полиномов с отрицательными коэффициентами). Здесь приведена реализация NTT с редукцией Монтгомери, которая показывает лучшую производительность даже по сравнению с FFT с double, особенно с аппаратной оптимизацией современных x64 процессоров (SIMD?). Остается открытым вопрос, какие именно аппаратные оптимизации позволяют NTT работать быстрее на современных процессорах "из коробки".
$\DeclareMathOperator{\mod}{mod}$↵
↵
↵
Редукция Монтгомери/Баррета и приложение к Теоретико-числовому Преобразованию Фурье↵
==================↵
↵
**Здесь** будет показано, как применить редукцию Монтгомери/Баррета к оптимизации Теоретико-числового Преобразования Фурье (NTT) без явных операций по модулю и созданию собственного класса модулярной арифметики, применимого в комбинаторных задачах. Будет приведена реализация NTT с минимальным числом операций по модулю, которая выигрывает у классического комплексного FFT не только в точности, но и в производительности на платформах без аппаратного ускорения. Мы сравним производительность различных оптимизаций NTT с комплексным FFT. Также будет показано, как добавить поддержку отрицательных коэффициентов в NTT.↵
↵
↵
1. Причины использования NTT вместо FFT↵
------------------↵
DFT (Дискретное Преобразование Фурье) широко используется в задаче умножения многочленов. DFT может быть вычислено
↵
Конечно, есть удобный FFT с комплесным типом double, который используется в перемножении многочленов (и не только), в т.ч. для решения задач на Codeforces. Но FFT имеет очень плохую точность из-за ограниченной па
↵
**Главное преимущество** NTT над FFT — это точность: через NTT мы получаем абсолютно точные коэффициенты произведения в диапазоне $[0; 10^{18}]$, а использование FFT приводит к ошибке 10-100 при использовании double и 1-10 при long double (например, [submission:259759436] c double получает WA43 в задаче, сводимой к умножению многочленов, когда как [submission:259759880] с long double получает OK). ↵
↵
NTT не работает быстрее FFT "из коробки", т.к. стандартная реализация NTT использует много операций по модулю, так что на многих системах (особенно без встроенного ускорения через SIMD) оно работает медленно. Но оптимизации описанные ниже помогут нам избавиться от явных операций взятия остатка, и NTT сможет работать заметно быстрее FFT. Для этого мы можем использовать редукцию Монтгомери и редукцию Баррета. ↵
↵
↵
2. Редукция Монтгомери↵
------------------↵
Редукция Монтгомери — это подход к организации арифметики по фиксированному модулю без явного деления с остатком (%). Главная идея в том, что мы можем свести деление с остатком на $p$ к делению с остатком на $2^k$, которое может быть реализовано через битовые операции. Для ознакомления с самим алгоритмом и доказательством можно <a href="https://en.algorithmica.org/hpc/number-theory/montgomery/">в статье</a>. Редукция Монтгомери обязывает перед выполнением операций по модулю приведение к специальной форме пространства Монтгомери. Приведение к этой форме требует явного деления с остатком, но в NTT особый случай: все числа, с которыми мы будем оперировать по модулю в алгоритме — это только коэффициенты исходных многочленов, корень и некоторые обратные (n, корень). Мы можем сделать это за $O(n)$ перед вызовом NTT и после него (чтобы привести в нормальную форму). Следовательно, можно избавиться от модулярной арифметики внутри NTT. Единственный нюанс в том, что при использовании большого модуля ($~10^{18}$) мы вынуждены использовать __int128. Из-за изобилия операций с многобитными числами была даже выпущена <a href="https://networkbuilders.intel.com/docs/networkbuilders/intel-avx-512-fast-modular-multiplication-technique-technology-guide-1710916893.pdf">статья Intel</a> для ускорения редукции Монтгомери через AVX. ↵
↵
Но если мы хотим уменьшить число операций с многобитными числами без потери в размерах коэффициентах, можно вычислить NTT по двум модулям, используя только 64-битные числа (параллельно через 1 вызов) и решить систему сравнений по простым модулям через Китайскую Теорему об Остатках. Тогда __int128 будет использован только в КТО. ↵
↵
**P.S.** В будущем статья будет дополнятся полностью русскоязычными разборами модулярных редукций. ↵
↵
↵
3. Редукция Баррета↵
------------------↵
Другой способ избежать явное деление с остатком — это редукция Баррета. Она основана на приближении ответа через деление с округлением вниз с предподсчитанным множителем определённого вида. Конкретно, преобразуем выражение для остатка: ↵
$$ a \mod p = a - \lfloor {a \over p} \rfloor \cdot p $$ ↵
Выберем $m$, так что: ↵
$$ {1 \over p} = {m \over 2^k} \longleftrightarrow m = {2^k \over
Теперь при делении будем подставлять $1 \over p$. Если взять достаточно большой множитель $2^k$ для приближения (порядка $p^2$) ответ для чисел нашего диапазона (до $10^{18}$) будет точным. С полным доказательством можно ознакомиться <a href="https://www.nayuki.io/page/barrett-reduction-algorithm">здесь</a>. Она даёт такой же функционал, что и редукция Монтгомери, но не требует приведение к специальной форме. Однако мы в каждой операции должны делать битовый сдвиг на больше, чем $p^2$ и делать умножения, поэтому при использовании модуля около $10^{18}$ необходимо использовать 256-битный целочисленный тип, поэтому операции дороже. ↵
↵
↵
4. Выбираем модули↵
------------------↵
Во первых, мы должны выбрать модуль такой модуль $p$, что $n | (p-1): n=2^k$, чтобы существовал корень n-й степени из 1 (порядок элемента делит порядок циклической группы по теореме Лагранжа). Т.к. по модулю $p$ существует первообразный корень $g$, мы можем взять $g^{p-1 \over n}$ в качестве корня n-й степени. Для конкретного $p$ эти корни могут быть легко найдены с помощью бинарного возведения в степень и алгоритмом поиска первообразного корня за $O(Ans \cdot \log^2(p))$, $Ans$ — размер первообразного корня в конечном поле — по предположению гипотезы всегда маленький. <a href="https://cp-algorithms.com/algebra/primitive-root.html">Здесь</a> имеется доказательство корректности алгоритма. Например, в нашей задаче для перемножения многочленов размером до $2^{24}$ с коэффициентами произведения до $2 \cdot 10^{18}$ можно взять $p=2524775926340780033$. Первообразный корень 3. ↵
↵
<spoiler summary="Primitive root searching algorithm">↵
```c++↵
// Primitive root modulo n↵
// (generator of cyclic group with n-1 elements)↵
int generator(int n) {↵
vector<int> fact;↵
int phi = euler_totient(n); // for prime equals n-1↵
int m = phi;↵
for (int d = 2; d*d <= m; ++d)↵
if (m%d == 0) {↵
fact.push_back(d);↵
while (m%d == 0)↵
m /= d;↵
}↵
if (m > 1)↵
fact.push_back(m);↵
for (int root = 2; root <= n; ++root) {↵
bool found = true;↵
for (auto d : fact)↵
if (bin_pow(root, phi / d, n) == 1) {↵
found = false;↵
break;↵
}↵
if (found)↵
return root;↵
}↵
return n == 1 ? 1 : -1;↵
}↵
↵
```↵
</spoiler>↵
↵
↵
↵
5. Реализации↵
------------------↵
Будем использовать нерекурсивную реализацию для FFT (NTT) с дополнительной структурой для арифметики с редукцией Монтгомери/Баррета. Удобно использовать беззнаковые типы при работе с редукциями из-за трюка с переполнением: оно не ведёт к неопределённому поведению, а возвращает ответ по модулю $2^{64}$. ↵
↵
↵
<spoiler summary="NTT with Montgomery multiplication">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /** n*m = 1 (mod p) => m = n**(p-2) (mod p) **/↵
uint64_t res = 1;↵
while (p) {↵
if (p & 1)↵
res = ((__uint128_t) res * n) % mod;↵
n = ((__uint128_t) n * n) % mod;↵
p >>= 1;↵
}↵
return res;↵
}↵
↵
↵
struct montgomery {↵
uint64_t n, nr;↵
↵
constexpr montgomery(uint64_t n) : n(n), nr(1) {↵
// log(2^64) = 6↵
for (int i = 0; i < 6; i++)↵
nr *= 2 - n * nr;↵
}↵
↵
[[nodiscard]]↵
uint64_t reduce(__uint128_t x) const {↵
uint64_t q = __uint128_t(x) * nr;↵
uint64_t m = ((__uint128_t) q * n) >> 64;↵
uint64_t res = (x >> 64) + n - m;↵
if (res >= n)↵
res -= n;↵
return res;↵
}↵
↵
[[nodiscard]]↵
uint64_t multiply(uint64_t x, uint64_t y) const {↵
return reduce((__uint128_t) x * y);↵
}↵
↵
[[nodiscard]]↵
uint64_t transform(uint64_t x) const {↵
return (__uint128_t(x) << 64) % n;↵
}↵
};↵
↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
↵
const uint64_t mod = 2524775926340780033, gen = 3;↵
//const uint64_t mod = 998244353, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev, montgomery& red, ↵
uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
int n = (int)a.size();↵
↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
uint64_t w = invert ? inv_root : root;↵
vector<uint64_t> W(n >> 1, red.transform(1));↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = red.multiply(W[i-1], w);↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
uint64_t t = red.multiply(a[j ^ (1 << i)], W[(j & ((1 << i) - 1)) * (n >> (i + 1))]);↵
a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], inv_n);↵
}↵
↵
↵
void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
montgomery red(mod);↵
for (auto& x : a)↵
x = red.transform(x);↵
for (auto& x : b)↵
x = red.transform(x);↵
↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
↵
uint64_t inv_n = red.transform(bin_pow(n, mod-2, mod));↵
uint64_t root = red.transform(bin_pow(gen, (mod-1)/n, mod));↵
uint64_t inv_root = red.transform(bin_pow(red.reduce(root), mod-2, mod));↵
auto rev = bit_sort(n);↵
↵
ntt(a, rev, red, inv_n, root, inv_root, false);↵
ntt(b, rev, red, inv_n, root, inv_root, false);↵
↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], b[i]);↵
ntt(a, rev, red, inv_n, root, inv_root, true);↵
↵
for (auto& x : a)↵
x = red.reduce(x);↵
}↵
↵
```↵
</spoiler>↵
↵
↵
Реализация с редукцией Баррета может быть использована только с модулем $<2^{32}$ без 256-битных чисел. Т.к. редукция Баррета а
↵
<spoiler summary="NTT with Barret multiplication">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /** n*m = 1 (mod p) => m = n**(p-2) (mod p) **/↵
uint64_t res = 1;↵
while (p) {↵
if (p & 1)↵
res = ((__uint128_t) res * n) % mod;↵
n = ((__uint128_t) n * n) % mod;↵
p >>= 1;↵
}↵
return res;↵
}↵
↵
↵
struct barret {↵
uint64_t n, s;↵
__uint128_t f;↵
↵
constexpr barret(uint64_t _n) {↵
n = _n;↵
s = 64;↵
f = (__uint128_t(1) << s) / n;↵
}↵
↵
[[nodiscard]]↵
uint64_t reduce(__uint128_t x) const {↵
auto t = (uint64_t)(x - ((x * f) >> s) * n);↵
if (t < n)↵
return t;↵
return t - n;↵
}↵
↵
[[nodiscard]]↵
uint64_t multiply(uint64_t x, uint64_t y) const {↵
return reduce((__uint128_t) x * y);↵
}↵
};↵
↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
↵
//const uint64_t mod = 2524775926340780033, gen = 3;↵
const uint64_t mod = 998244353, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev, barret& red, ↵
uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
int n = (int)a.size();↵
↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
uint64_t w = invert ? inv_root : root;↵
vector<uint64_t> W(n >> 1, 1);↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = red.multiply(W[i-1], w);↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
uint64_t t = red.multiply(a[j ^ (1 << i)], W[(j & ((1 << i) - 1)) * (n >> (i + 1))]);↵
a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], inv_n);↵
}↵
↵
↵
void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
↵
barret red(mod);↵
uint64_t inv_n = bin_pow(n, mod-2, mod);↵
uint64_t root = bin_pow(gen, (mod-1)/n, mod);↵
uint64_t inv_root = bin_pow(root, mod-2, mod);↵
auto rev = bit_sort(n);↵
↵
ntt(a, rev, red, inv_n, root, inv_root, false);↵
ntt(b, rev, red, inv_n, root, inv_root, false);↵
↵
for (int i = 0; i < n; i++)↵
a[i] = red.multiply(a[i], b[i]);↵
ntt(a, rev, red, inv_n, root, inv_root, true);↵
}↵
```↵
</spoiler>↵
↵
↵
↵
Также здесь будет приведена реализация FFT со своей структурой для комплексных чисел. Она может быть использована с double / long double. ↵
↵
<spoiler summary="FFT with custom cmpls">↵
```c++↵
struct _cmpl {↵
double a, b;↵
_cmpl(double a = 0, double b = 0) : a(a), b(b) {}↵
↵
const _cmpl operator + (const _cmpl &c) const↵
{ return _cmpl(a + c.a, b + c.b); }↵
↵
const _cmpl operator - (const _cmpl &c) const↵
{ return _cmpl(a - c.a, b - c.b); }↵
↵
const _cmpl operator * (const _cmpl &c) const↵
{ return _cmpl(a * c.a - b * c.b, a * c.b + b * c.a); }↵
};↵
↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
↵
void fft(vector<_cmpl>& a, vector<int>& rev, bool invert) {↵
int n = a.size(), h = -1;↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
double alpha = 2 * atan2(0.00, -1.00) / n * (invert ? -1 : 1);↵
_cmpl w1(cos(alpha), sin(alpha));↵
vector<_cmpl> W(n >> 1, 1);↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = W[i - 1] * w1;↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
_cmpl t = a[j ^ (1 << i)] * W[(j & ((1 << i) - 1)) * (n >> (i + 1))];↵
a[j ^ (1 << i)] = a[j] - t;↵
a[j] = a[j] + t;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = _cmpl(a[i].a / n, a[i].b / n);↵
}↵
↵
↵
void mul(vector<_cmpl>& a, vector<_cmpl>& b, vector<uint64_t>& res) {↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
auto rev = bit_sort(n);↵
fft(a, rev, false);↵
fft(b, rev, false);↵
for (int i = 0; i < n; i++)↵
a[i] = a[i] * b[i];↵
fft(a, rev, true);↵
res.resize(n);↵
for (int i = 0; i < n; i++)↵
res[i] = (uint64_t)(a[i].a + 0.1);↵
}↵
```↵
</spoiler>↵
↵
↵
NTT без оптимизаций: ↵
↵
<spoiler summary="NTT without opt">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /** n*m = 1 (mod p) => m = n**(p-2) (mod p) **/↵
uint64_t res = 1;↵
while (p) {↵
if (p & 1)↵
res = ((__uint128_t) res * n) % mod;↵
n = ((__uint128_t) n * n) % mod;↵
p >>= 1;↵
}↵
return res;↵
}↵
↵
vector<int> bit_sort(int n) {↵
int h = -1;↵
vector<int> rev(n, 0);↵
int skip = __lg(n) - 1;↵
for (int i = 1; i < n; ++i) {↵
if (!(i & (i - 1)))↵
++h;↵
rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
}↵
return rev;↵
}↵
↵
//const uint64_t mod = 998244353, gen = 3;↵
const uint64_t mod = 2524775926340780033, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev,↵
uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
int n = (int)a.size();↵
↵
for (int i = 0; i < n; ++i)↵
if (i < rev[i])↵
swap(a[i], a[rev[i]]);↵
↵
uint64_t w = invert ? inv_root : root;↵
vector<uint64_t> W(n >> 1, 1);↵
for (int i = 1; i < (n >> 1); ++i)↵
W[i] = (__uint128_t) W[i-1] * w % mod;↵
↵
int lim = __lg(n);↵
for (int i = 0; i < lim; ++i)↵
for (int j = 0; j < n; ++j)↵
if (!(j & (1 << i))) {↵
uint64_t t = (__uint128_t) a[j ^ (1 << i)] * W[(j & ((1 << i) - 1)) * (n >> (i + 1))] % mod;↵
a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
}↵
↵
if (invert)↵
for (int i = 0; i < n; i++)↵
a[i] = (__uint128_t) a[i] * inv_n % mod;↵
}↵
↵
↵
void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
int n = 1;↵
while (n < a.size() || n < b.size())↵
n <<= 1;↵
n <<= 1;↵
a.resize(n);↵
b.resize(n);↵
↵
uint64_t inv_n = bin_pow(n, mod-2, mod);↵
uint64_t root = bin_pow(gen, (mod-1)/n, mod);↵
uint64_t inv_root = bin_pow(root, mod-2, mod);↵
auto rev = bit_sort(n);↵
↵
ntt(a, rev, inv_n, root, inv_root, false);↵
ntt(b, rev, inv_n, root, inv_root, false);↵
↵
for (int i = 0; i < n; i++)↵
a[i] = (__uint128_t) a[i] * b[i] % mod;↵
ntt(a, rev, inv_n, root, inv_root, true);↵
}↵
```↵
</spoiler>↵
↵
↵
↵
6. Производительность↵
------------------↵
Мы будет тестировать 5 алгоритмов: FFT с double и long double, NTT без и с редукциями Монтгомери и Баррета. ↵
↵
Во-первых, посмотрим на [problem:993E]. Важно отметить, что NTT с редукцией Баррета и FFT с double получили wa43 (но прошли тесты на других тестах с большими входными данными), потому что FFT с double имеет плохую точность, а NTT с редукцией Баррета по модулю $998244353$ не может хранить большой ответ (NTT с Барретом и без uint256 может быть использовано только для небольших коэффициентов). ↵
↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>[submission:261258380], 311ms, BAD PRECISION</th>↵
<th>[submission:261256931], 702ms, OK</th>↵
<th>[submission:261256487], 281ms, OK</th>↵
<th>[submission:261257532], 312ms, BAD LIMITS</th>↵
<th>[submission:261257532], 859ms, OK</th>↵
</tr>↵
</table>↵
↵
Теперь протестируем перемножение многочленов размером $10^6$ со случайными коэффициентами на запуске Codeforces: ↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>1407ms</th>↵
<th>2434ms</th>↵
<th>1374ms</th>↵
<th>1437ms</th>↵
<th>4122ms</th>↵
</tr>↵
<tr>↵
<th>1286ms</th>↵
<th>2426ms</th>↵
<th>1256ms</th>↵
<th>1675ms</th>↵
<th>3683ms</th>↵
</tr>↵
<tr>↵
<th>1052ms</th>↵
<th>2286ms</th>↵
<th>1153s</th>↵
<th>1320ms</th>↵
<th>3518ms</th>↵
</tr>↵
<tr>↵
<th>1079ms</th>↵
<th>2359ms</th>↵
<th>1309ms</th>↵
<th>1380ms</th>↵
<th>3961ms</th>↵
</tr>↵
<tr>↵
<th>1217ms</th>↵
<th>2791ms</th>↵
<th>1476ms</th>↵
<th>1780ms</th>↵
<th>3400ms</th>↵
</tr>↵
<tr>↵
<th>1241ms</th>↵
<th>2191ms</th>↵
<th>1125ms</th>↵
<th>1337ms</th>↵
<th>3512ms</th>↵
</tr>↵
</table>↵
↵
Теперь локальные тесты на Ryzen 5 5650u (AVX), 8Gb RAM 4266MHz, Debian 12, GCC 12.2.0. Параметры компиляции: `g++ -Wall -Wextra -Wconversion -static -DONLINE_JUDODGE -O2 -std=c++20 fftest.cc -o fftest` как на Codeforces. ↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>691ms</th>↵
<th>2339ms</th>↵
<th>352ms</th>↵
<th>401ms</th>↵
<th>398ms</th>↵
</tr>↵
<tr>↵
<th>700ms</th>↵
<th>2328ms</th>↵
<th>411ms</th>↵
<th>500ms</th>↵
<th>487ms</th>↵
</tr>↵
<tr>↵
<th>729ms</th>↵
<th>2036ms</th>↵
<th>359ms</th>↵
<th>446ms</th>↵
<th>396ms</th>↵
</tr>↵
<tr>↵
<th>698ms</th>↵
<th>2164ms</th>↵
<th>380ms</th>↵
<th>456ms</th>↵
<th>472ms</th>↵
</tr>↵
<tr>↵
<th>734ms</th>↵
<th>2284ms</th>↵
<th>407ms</th>↵
<th>462ms</th>↵
<th>440ms</th>↵
</tr>↵
<tr>↵
<th>737ms</th>↵
<th>2148ms</th>↵
<th>396ms</th>↵
<th>459ms</th>↵
<th>434ms</th>↵
</tr>↵
</tr>↵
</table>↵
↵
На сервере Codeforces одинаковая производительность у Montgomery + NTT и FFT + doubles, затем идёт NTT + Barret, затем FFT + long doubles, затем NTT без оптимизаций. На локальной системе с современным процессором с поддержкой SIMD лидером является Montgomery + NTT. Причём NTT на локальной системе работает намного быстрее **даже без оптимизаций**. Я точно не могу определить, какие технологии отвечают за быструю 128-битную модулярную арифметику, но, скорее всего, это часть SIMD. Вообще, NTT хорошо поддаётся аппаратному ускорению, напирмер, <a href="https://link.springer.com/chapter/10.1007/978-3-030-78713-4_6">в статье</a> описано ускорение NTT на FPGA. ↵
↵
Как мы видим, производительность NTT по большому модулю сильно упирается в платформу из-за наличия большого числа операций с высокобитными числами, а именно деления с остатком 128-битных чисел. Это операция сильно упирается в аппаратную реализацию, поэтому из-за неё на Codeforces NTT без оптимизаций оказался самым медленным, но на CPU с ускорением SIMD он выигрывает у FFT. Тем не менее, разница может быть нивелирована засчёт явной реализации **редукции**, которая использует только умножение 128-битных чисел. Про проблемы 128-битного деления изложено <a href="https://danlark.org/2020/06/14/128-bit-division/">в статье</a>. ↵
↵
В итоге NTT + Montgomery является довольно универсальным выбором на всех платформах. Единственный нюанс в том, что NTT изначально не поддерживает отрицательные коэффициенты (т.к. всё берётся по модулю), но это можно исправить. ↵
↵
↵
7. Отрицательные коэффициенты↵
------------------↵
Рассмотрим $A(x)$ и $B(x)$ с любыми целыми коэффициентами, и мы хотим получить коэффициенты произведения через NTT. Пусть $C(x)$ — такой многочлен, что любой коэффициент у $A(x)+C(x)$, $B(x)+C(x)$ и $A(x)+B(x)+C(x)$ является неотрицательным. ↵
Заметим, что ↵
$$ A \cdot B = (A+C) \cdot (B+C) - C \cdot (A+B+C) $$↵
Так как каждое слагаемое в сумме имеет неотрицательные коэффициенты, $A(x) \cdot B(x)$ можно найти через 2 вызова умножения многочленов через NTT. $C(x)$ можно найти жадно за $O(n)$.↵
↵
↵
8. Класс для модулярной арифметики↵
------------------↵
Удобно создать класс для арифметики в $\mathbb{Z}/p\mathbb{Z}$ и какой-то редукцией. Ниже приведён класс с редукцией Баррета с методами для НОДа, степени, инверсий и корней (класс можно расширить до арифметики в любом конечном поле). P.S. Если мы хотим получить максимальную производительность и работаем с специфичным множеством чисел, можно использовать редукцию Монтгомери, потому что Монтгомери быстрее. ↵
↵
↵
<spoiler summary="Modular arithmetic class">↵
```c++↵
class zpz {↵
public:↵
static void init(uint32_t m) {↵
mod = m;↵
shift = 2*(32 - __builtin_clz(m));↵
factor = (uint64_t(1) << shift) / mod;↵
gen = 0;↵
}↵
↵
static uint32_t ext_gcd(uint32_t a, uint32_t b, uint64_t& x, uint64_t& y) {↵
if (a < b)↵
return ext_gcd(b, a, y, x);↵
if (b == 0) {↵
x = 1;↵
y = 0;↵
return a;↵
}↵
uint64_t x1, y1;↵
uint32_t g = ext_gcd(b, a%b, x1, y1);↵
x = y1;↵
y = x1 - (a/b)*y1;↵
return g;↵
}↵
↵
static zpz pow(zpz a, uint32_t n) {↵
zpz res = 1;↵
while (n) {↵
if (n & 1)↵
res *= a;↵
a *= a;↵
n >>= 1;↵
}↵
return res;↵
}↵
↵
static zpz inv(zpz a) {↵
if (inverses.find(a()) == inverses.end()) {↵
uint64_t x, y;↵
ext_gcd(a(), mod, x, y);↵
inverses[a()] = reduce(x + mod);↵
}↵
return inverses[a()];↵
}↵
↵
static zpz root(zpz a, int n) {↵
if (gen == 0) {↵
vector<uint32_t> fact;↵
// int phi = euler_totient(n);↵
uint32_t phi = mod-1;↵
uint32_t m = phi;↵
for (uint32_t d = 2; d*d <= m; ++d)↵
if (m%d == 0) {↵
fact.push_back(d);↵
while (m%d == 0)↵
m /= d;↵
}↵
if (m > 1)↵
fact.push_back(m);↵
for (uint32_t rt = 2; rt < mod; ++rt) {↵
bool found = true;↵
for (auto d : fact)↵
if (pow(zpz(rt), phi / d) == 1) {↵
found = false;↵
break;↵
}↵
if (found) {↵
gen = rt;↵
break;↵
}↵
}↵
gen = mod == 1 ? 1 : throw exception();↵
}↵
↵
return pow(zpz(gen), (mod-1)/n);↵
}↵
↵
static void get_all_modular_inverses() {↵
inverses[1] = 1;↵
for (int k = 2; k < mod; ++k)↵
inverses[k] = -1LL * (mod / k) * inverses[mod % k] % mod + mod;↵
}↵
↵
↵
zpz(uint32_t x) : val(reduce(x)) {}↵
↵
uint32_t operator () () const { return val; }↵
↵
zpz& operator =(uint64_t x) { val = reduce(x); return *this; }↵
↵
zpz& operator =(const zpz& x) { val = x(); return *this; }↵
↵
zpz& operator +=(const zpz& x) { val = reduce(val + x()); return *this; }↵
zpz& operator -=(const zpz& x) { val = reduce(val - x() + mod); return *this; }↵
zpz& operator *=(const zpz& x) { val = reduce((uint64_t) val * x()); return *this; }↵
↵
zpz& operator +=(uint64_t x) { return *this += zpz(x); }↵
zpz& operator -=(uint64_t x) { return *this -= zpz(x); }↵
zpz& operator *=(uint64_t x) { return *this *= zpz(x); }↵
↵
zpz& operator /=(const zpz& x) { return *this *= inv(x); }↵
zpz& operator /=(uint64_t x) { return *this /= zpz(x); }↵
zpz operator /(uint64_t x) { zpz cur = *this; return cur /= x; }↵
↵
zpz& operator ++() { return *this += 1; }↵
zpz& operator --() { return *this -= 1; }↵
↵
zpz operator ++(int unused) { zpz z(*this); ++(*this); return z; }↵
zpz operator --(int unused) { zpz z(*this); --(*this); return z; }↵
↵
friend zpz operator +(zpz x, const zpz& y) { return x += y; }↵
friend zpz operator *(zpz x, const zpz& y) { return x *= y; }↵
friend zpz operator -(zpz x, const zpz& y) { return x -= y; }↵
friend zpz operator /(zpz x, const zpz& y) { return x /= y; }↵
↵
friend zpz operator +(zpz x, uint32_t y) { return x += y; }↵
friend zpz operator *(zpz x, uint32_t y) { return x *= y; }↵
friend zpz operator -(zpz x, uint32_t y) { return x -= y; }↵
friend zpz operator /(zpz x, uint32_t y) { return x /= y; }↵
↵
friend zpz operator +(uint32_t x, zpz y) { return y += x; }↵
friend zpz operator *(uint32_t x, zpz y) { return y *= x; }↵
↵
friend zpz operator -(uint32_t x, const zpz& y) { zpz z(x); return z -= y; }↵
friend zpz operator /(uint32_t x, const zpz& y) { zpz z(x); return z /= y; }↵
↵
bool operator <(const zpz& x) const { return val < x(); }↵
bool operator ==(const zpz& x) const { return val == x(); }↵
bool operator >(const zpz& x) const { return val > x(); }↵
bool operator !=(const zpz& x) const { return val != x(); }↵
bool operator <=(const zpz& x) const { return val <= x(); }↵
bool operator >=(const zpz& x) const { return val >= x(); }↵
↵
bool operator <(uint32_t x) const { return val < x; }↵
bool operator ==(uint32_t x) const { return val == x; }↵
bool operator >(uint32_t x) const { return val > x; }↵
bool operator !=(uint32_t x) const { return val != x; }↵
bool operator <=(uint32_t x) const { return val <= x; }↵
bool operator >=(uint32_t x) const { return val >= x; }↵
↵
friend istream& operator >> (istream& input, zpz& x)↵
{↵
uint32_t z;↵
input >> z,↵
x = zpz(z);↵
return input;↵
}↵
↵
friend ostream& operator << (ostream& output, const zpz& x)↵
{↵
return output << x();↵
}↵
↵
↵
private:↵
static uint32_t mod, shift;↵
static uint64_t factor;↵
↵
static gp_hash_table<uint32_t, uint32_t> inverses;↵
static uint32_t gen;↵
↵
[[nodiscard]]↵
static uint32_t reduce(uint64_t x) {↵
auto t = (uint32_t)(x - (((__uint128_t) x * factor) >> shift) * mod);↵
if (t < mod)↵
return t;↵
return t - mod;↵
}↵
↵
uint32_t val;↵
};↵
↵
uint32_t zpz::mod, zpz::shift;↵
uint64_t zpz::factor;↵
gp_hash_table<uint32_t, uint32_t> zpz::inverses;↵
uint32_t zpz::gen;↵
```↵
</spoiler>↵
↵
↵
9. Заключение↵
------------------↵
В результате мы сравнили различные редукции для модульной арифметики и выяснили, что редукция Монтгомери имеет наилучшую производительность, если мы заранее знаем набор целых чисел. Следовательно, NTT с редукцией Монтгомери можно использовать вместо FFT в задачах с умножением многочленов (даже для полиномов с отрицательными коэффициентами). Здесь приведена реализация NTT с редукцией Монтгомери, которая показывает лучшую производительность даже по сравнению с FFT с double, особенно с аппаратной оптимизацией современных x64 процессоров (SIMD?). Остается открытым вопрос, какие именно аппаратные оптимизации позволяют NTT работать быстрее на современных процессорах "из коробки".