Text is also available on github pages (formatting may look a bit cleaner there)
Thanks to alexxela12345 for reviewing this blog
Hello everyone!
I want to share some of my thoughts and experiments on vectorizing NTT.
I prefer NTT to real-valued FFT because of the imprecision of the latter. I use avx2 because it is the most advanced vector extension supported by the majority of modern online judges (including Codeforces) as of 2024. I have no idea how many bugs (I have already found two) and how much UB my code contains, but at least it passes some tests.
Benchmark info
All the benchmarks are performed on my Ubuntu 22 Intel i5-1135g7 laptop. I execute cpupower frequency-set -d 3.0ghz -u 3.0ghz before benchmarks to (attempt to) fix CPU frequency for better accuracy. Code is available at github. You can try running all the benchmarks on your own machine with the run_all.sh script. Vertical dotted lines on plots mark sizes of L1, L2 and L3 caches in u32s.
The task
Given coefficients of two polynomials $$$A(x), B(x) \in \mathbb{F}_{\text{mod}}[x]$$$, compute the coefficients of $$$C(x) = A(x)B(x)$$$, where $$$\text{mod} = 998\,244\,353$$$ (or another sufficiently good number).
Since general polynomial multiplication is accomplished by computing $$$A(x)B(x) \bmod (x^n - 1)$$$ for a big enough power of two $$$n$$$, we will focus on computing $$$A(x)B(x) \bmod (x^n - 1)$$$ for a given $$$n$$$. Such an expression is also known as a cyclic convolution.
Step A, standard implementation
One of the most common implementations unrolls recursion into three nested for loops with the help of bit-reverse permutation. This is what I consider the baseline.
The core of it looks like this
void transform(int lg, u32* data) {
    for (int i = 0; i < (1 << lg); i++) {
        if (bit_rev[lg][i] < i) {
            std::swap(data[i], data[bit_rev[lg][i]]);
        }
    }
    for (int k = 0; k < lg; k++) {
        for (int i = 0; i < (1 << lg); i += (1 << (k + 1))) {
            for (int j = 0; j < (1 << k); j++) {
                butterfly_x2(data[i + j], data[i + (1 << k) + j], w[k][j]);
            }
        }
    }
}
Step A2, getting rid of bit-reversal
Applying the bit-reverse permutation is somewhat annoying because
- It is not useful work
- It is one of the worst memory access patterns (probably even worse than random)
- It is hard to speed up and vectorize
The simplest way to get rid of bit-reversal is just not doing it. Consider the array elements as formal variables: a permutation doesn't change their values, only their order in the array. We will perform the same operations on the same variables, but their positions in the array will be different.
void transform_forward(int lg, u32* data) {
    for (int k = lg - 1; k >= 0; k--) {
        for (int i = 0; i < (1 << lg); i += (1 << (k + 1))) {
            u32 wi = w[i >> (k + 1)];
            for (int j = 0; j < (1 << k); j++) {
                butterfly_x2(data[i + j], data[i + (1 << k) + j], wi);
            }
        }
    }
}
With such an approach, the output of the transform function will be bit-reversed. It means nothing for the pointwise product, but the inverse transform has to be adjusted. It is no longer possible to efficiently express the inverse transform using the forward one. But we can always invert any transform by inverting every performed operation in reverse order.
butterfly_x2 is just multiplication by an invertible 2x2 matrix:

$$$\begin{pmatrix} a \\ b \end{pmatrix} \mapsto \begin{pmatrix} 1 & w \\ 1 & -w \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} a + wb \\ a - wb \end{pmatrix}$$$
Operations inside the two innermost loops are independent, so we need to reverse the order of the outermost loop only. I will not mention inverse transform for the next steps because it will be very similar to forward.
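For concreteness, here is a scalar sketch of butterfly_x2 together with its inverse (helper names and the plain `%`-based arithmetic are mine; the repo's version differs):

```cpp
#include <cstdint>

using u32 = uint32_t;
using u64 = uint64_t;

const u32 MOD = 998244353;

u32 mul(u32 a, u32 b) { return u32(u64(a) * b % MOD); }

// Forward butterfly: multiplication by the matrix [[1, w], [1, -w]].
void butterfly_x2(u32& a, u32& b, u32 w) {
    u32 t = mul(b, w);
    u32 na = (a + t) % MOD;
    b = (a + MOD - t) % MOD;
    a = na;
}

// Inverse butterfly: multiplication by the inverse matrix
// (1/2) * [[1, 1], [1/w, -1/w]]. w_inv = w^{-1} mod MOD and
// inv2 = 2^{-1} mod MOD are assumed to be precomputed.
void butterfly_x2_inv(u32& a, u32& b, u32 w_inv, u32 inv2) {
    u32 s = (a + b) % MOD;
    u32 d = (a + MOD - b) % MOD;
    a = mul(s, inv2);
    b = mul(mul(d, inv2), w_inv);
}
```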
Step A3, optimizing initialization to just $$$O(\log n)$$$
Note that the new implementation loads the value of w[i] just $$$n - 1$$$ times, compared to $$$\frac{1}{2} n \log_2 n$$$ times in the standard implementation (we could swap the two innermost loops and get the same $$$n - 1$$$ times, but the memory access pattern would become awful).
It means that computing the value of w[i] on the fly (with one multiplication), instead of loading it from a precomputed array, will not result in a terrible performance decrease. So, to eliminate the need for an additional array of size $$$n$$$, we will compute the value of w[i + 1] from the value of w[i], a precomputed array of size $$$\log_2 n$$$ and one multiplication.
To achieve that, we need to know what the entries of array w are. Let $$$g$$$ be the primitive root we are using, and let $$$F(s)$$$ denote the set of indices of all nonzero bits in $$$s$$$ (counting from $$$0$$$). Then w[i] is equal to

$$$w_i = \prod_{j \in F(i)} g^{\frac{\text{mod} - 1}{2^{j + 1}}}$$$
It's not hard to see that the quotient w[i + 1] / w[i] depends only on the number of trailing ones in the binary representation of i. We can precompute these quotients for every number of trailing ones. Then on every iteration of the middle loop we will:
- Compute the number of trailing ones in i (with the help of the tzcnt instruction)
- Multiply the previous value of w by the corresponding array element.
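The update can be sketched as follows. This assumes the product formula for w[i] given above, so that the quotient for an i with t trailing ones is $$$\omega_{2^{t+1}} / (\omega_2 \omega_4 \cdots \omega_{2^t})$$$ with $$$\omega_{2^k} = g^{(\text{mod}-1)/2^k}$$$ (helper names are mine):

```cpp
#include <cstdint>
#include <vector>

using u32 = uint32_t;
using u64 = uint64_t;

const u32 MOD = 998244353, G = 3;  // G is a primitive root modulo MOD

u32 mul(u32 a, u32 b) { return u32(u64(a) * b % MOD); }

u32 power(u32 b, u32 e) {
    u32 r = 1;
    for (; e > 0; e >>= 1, b = mul(b, b))
        if (e & 1) r = mul(r, b);
    return r;
}

// ratio[t] = w[i + 1] / w[i] for any i with exactly t trailing ones:
// adding bit t contributes omega_{2^{t+1}}, clearing bits 0..t-1 removes
// omega_2 * omega_4 * ... * omega_{2^t}.
std::vector<u32> make_ratios(int lg) {
    std::vector<u32> ratio(lg);
    u32 acc_inv = 1;  // inverse of omega_2 * ... * omega_{2^t} accumulated so far
    for (int t = 0; t < lg; t++) {
        u32 omega = power(G, (MOD - 1) >> (t + 1));  // omega_{2^{t+1}}
        ratio[t] = mul(omega, acc_inv);
        acc_inv = mul(acc_inv, power(omega, MOD - 2));
    }
    return ratio;
}

// One multiplication to advance w[i] -> w[i + 1].
u32 next_twiddle(u32 w, u32 i, const std::vector<u32>& ratio) {
    int t = __builtin_ctz(~i);  // number of trailing ones of i
    return mul(w, ratio[t]);
}
```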
I think that having negligible initialization time and additional memory usage is cool enough to justify a mild performance decrease.
Initialization as-is works in $$$O(\log^2 n)$$$, but it is still negligible.
Step B, utilizing Montgomery reduction
So far we have relied on compiler-generated Barrett reduction (for a modulus known at compile time). To vectorize modular arithmetic we need to know how it works, so at this step we will implement all modular arithmetic manually. I will use Montgomery reduction, because its vectorized version performed better than the vectorized version of any other reduction algorithm I tried (though I haven't tried many).
But Montgomery reduction has a problem: it divides by $$$2^{32}$$$ on each reduction. There is a well-known solution to this problem called Montgomery space, but moving all array entries into and out of this space would take too much time.
So we will do something about it. There are four places where we use multiplication:
- In the butterfly_x2 function
- For scaling the result by $$$\frac{1}{n}$$$
- For pointwise multiplication
- For precomputing twiddle factors
I will call the map $$$\mathbb{F}_{\text{mod}} \times \mathbb{F}_{\text{mod}} \to \mathbb{F}_{\text{mod}},\ (a, b) \mapsto ab \cdot 2^{-32}$$$ Montgomery multiplication. I will say that a variable $$$x$$$ is in Montgomery space if the actual value stored is $$$x \cdot 2^{32}$$$.
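A minimal scalar sketch of this map (the struct layout is mine; the repo's version and its vectorized counterpart follow the same algebra):

```cpp
#include <cstdint>

using u32 = uint32_t;
using u64 = uint64_t;

// Scalar Montgomery multiplication: (a, b) -> a * b * 2^{-32} mod `mod`.
struct Montgomery {
    u32 mod, n_inv;  // n_inv = -mod^{-1} mod 2^32
    explicit Montgomery(u32 mod) : mod(mod) {
        u32 r = 1;  // Newton iteration: each step doubles the number of correct bits
        for (int i = 0; i < 5; i++) r *= 2 - r * mod;
        n_inv = 0u - r;  // r = mod^{-1} mod 2^32
    }
    u32 mul(u32 a, u32 b) const {
        u64 x = u64(a) * b;
        // adding (x * n_inv mod 2^32) * mod zeroes the low 32 bits of x,
        // so dropping them divides exactly by 2^32
        u32 res = u32((x + u64(u32(x) * n_inv) * mod) >> 32);
        return res < mod ? res : res - mod;
    }
};
```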
If we multiply a usual number by a number from Montgomery space, we get a usual number. We multiply by twiddle factors only in butterfly_x2, so we will compute twiddle factors in Montgomery space and leave array entries as usual numbers; the effect of butterfly_x2 won't change. This also solves the problem of twiddle factor precomputation.
We can scale the result using Montgomery multiplication, but we need to account for the Montgomery reduction factor. It can be done in $$$O(1)$$$ additional work, just by multiplying the scaling constant by some precomputed factor.
Pointwise multiplication is a bit tricky, but we can cheat a little. If we use Montgomery multiplication for the pointwise product $$$a \cdot b$$$ while keeping $$$a$$$ and $$$b$$$ as usual numbers, the result will be off by a factor of $$$2^{-32}$$$. But we can similarly cancel that factor during the scaling step with $$$O(1)$$$ additional work. Though this approach won't work for non-homogeneous polynomial expressions, like $$$2a - a^2b$$$ (this particular one is used in the computation of inverse power series).
Arithmetic usage optimization
Moduli are typically 30-bit wide, and we can abuse that. Instead of always keeping all numbers in $$$[0, \text{mod})$$$, we will allow them to be in $$$[0, 2 \cdot \text{mod})$$$ or $$$[0, 4 \cdot \text{mod})$$$ and apply shrink only when necessary (previously we would shrink immediately after every addition/subtraction). For 30-bit moduli, Montgomery reduction can reduce from $$$[0, 4 \cdot \text{mod}^2) \subset [0, 2^{32} \cdot \text{mod}) $$$ to $$$[0, 2 \cdot \text{mod})$$$. This allows us to reduce the number of shrinks (though we should be careful about it: it's very easy to introduce a bug while counting how many shrinks have to be applied).
u32 shrink(u32 val) const {
    return std::min(val, val - mod);
}
This particular implementation of shrink may not be the most efficient one for the scalar case, but I guess it is the most efficient one for the vector case, since both u32 subtraction and minimum are natively supported by avx2 (_mm256_sub_epi32 and _mm256_min_epu32 intrinsics).
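To illustrate how the extended ranges compose, a toy scalar sketch (add_lazy/sub_lazy are hypothetical helpers, not names from the repo):

```cpp
#include <algorithm>
#include <cstdint>

using u32 = uint32_t;

const u32 mod = 998244353;  // 30-bit modulus, so even 4 * mod fits in u32

// [0, 2*mod) -> [0, mod): if val < mod, the subtraction wraps around to a huge
// value and min() keeps val; otherwise it keeps val - mod.
u32 shrink(u32 val) { return std::min(val, val - mod); }

// Keep addition/subtraction results in [0, 2*mod) instead of reducing fully:
// both operands are shrunk into [0, mod), so the sum is below 2*mod and the
// shifted difference lies in (0, 2*mod).
u32 add_lazy(u32 a, u32 b) { return shrink(a) + shrink(b); }
u32 sub_lazy(u32 a, u32 b) { return shrink(a) + mod - shrink(b); }
```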
We may get rid of two layers' worth of multiplications by noticing that all twiddle factors on the topmost layer are ones, so multiplying by them is not required at all. The same holds for half of the twiddle factors on the next layer, a quarter on the layer after that, ..., and they add up to $$$1 + \frac{1}{2} + \frac{1}{4} + ... = 2 - \frac{2}{n} \approx 2$$$ layers.
Because the code is getting bloated by various optimizations, I will pack the innermost loop of the transform function into a templated function transform_aux to shorten the code.
Now we can use a non-constexpr modulus. But there is a peculiarity: I don't really know why, but if we don't save the Montgomery struct to a local variable (like const Montgomery mt = this->mt;), the compiler generates unnecessary load instructions for the Montgomery constants in hot loops.
Step C, vectorization
Vectorizing multiplication
Vectorizing multiplication is the most crucial step, since the majority of the performance improvement comes directly from it. Yet I won't be comprehensive and will just show the best I could achieve, without explaining how and why. I'll probably write a separate article later.
n_inv and mod are u32x8s filled with corresponding values from scalar Montgomery. mul_u32x8 computes the pointwise Montgomery product of input vectors.
u32x8 reduce(u64x4 x0246, u64x4 x1357) const {
    u64x4 x0246_ninv = _mm256_mul_epu32(x0246, n_inv);
    u64x4 x1357_ninv = _mm256_mul_epu32(x1357, n_inv);
    u64x4 x0246_res = _mm256_add_epi64(x0246, _mm256_mul_epu32(x0246_ninv, mod));
    u64x4 x1357_res = _mm256_add_epi64(x1357, _mm256_mul_epu32(x1357_ninv, mod));
    u32x8 res = _mm256_or_si256(_mm256_bsrli_epi128(x0246_res, 4), x1357_res);
    return res;
}

u32x8 mul_u32x8(u32x8 a, u32x8 b) const {
    u32x8 a_sh = _mm256_bsrli_epi128(a, 4);
    u32x8 b_sh = _mm256_bsrli_epi128(b, 4);
    u64x4 x0246 = _mm256_mul_epu32(a, b);
    u64x4 x1357 = _mm256_mul_epu32(a_sh, b_sh);
    return reduce(x0246, x1357);
}
(type casts are omitted)
It uses just 12 instructions, six of which are multiplications, with the longest dependency chain having a latency of 18 cycles, with 3 multiplications (latency 5) and 3 other instructions (latency 1) on it.
Note: it is also possible to implement mul_u32x8 with Barrett reduction, but all of my attempts were slower by at least 20-30%. It may be reasonable to use Barrett reduction for the pointwise product part when the Montgomery reduction factor can't be easily removed.
Back to our algorithm
I will call the outer loop iterations of the transform function with k >= 3 top layers, and all the other outer loop iterations (with k < 3) bottom layers.
Using our mul_u32x8 in the top layers is trivial because we operate on consecutive segments of data of at least 8 u32s (the register size). But for the bottom layers things get complicated: we now have to deal with in-register shuffles. So top and bottom layers are now split into separate loops.
Let's implement butterfly_x2 in a way friendly to vector operations. We perform the transform on values a, b stored in the vector [a, b].
- Multiply [a, b] by the vector [1, w] pointwise: -> [a, b * w]
- Create a copy of [a, b * w] where a and b * w are swapped: -> [b * w, a] (an in-register shuffle is used)
- Negate b * w in [a, b * w] (by negating everything and blending): -> [a, -b * w]
- Add the vectors [a, -b * w] and [b * w, a] pointwise: -> [a + b * w, a - b * w]
We use this approach to simultaneously perform four butterfly_x2 on 8 consecutive elements.
For k = 2 our u32x8 will look like [a1, a2, a3, a4, b1, b2, b3, b4].
For k = 1 our u32x8 will look like [a1, a2, b1, b2, a3, a4, b3, b4].
For k = 0 our u32x8 will look like [a1, b1, a2, b2, a3, b3, a4, b4].
(we are applying butterfly_x2 to pairs (a_i, b_i))
We won't separate loop iterations for k = 0, 1, 2, and will just apply all of them at once to each block of 8 consecutive u32s. With an approach like this, we can easily vectorize the recalculation of twiddle factors by packing all $$$1 + 2 + 4 = 7$$$ of them into a single u32x8 and updating them simultaneously with one mul_u32x8. The latter is quite important, since the amount of twiddle factor recalculation grows exponentially with layer depth, and by vectorizing it at the bottom layers, we vectorize $$$\approx \frac{1}{2} + \frac{1}{4} + \frac{1}{8} = 87.5\%$$$ of all recalculations.
Step D, radix4 butterfly
We made the computational part of our algorithm 5-10x faster, but I/O can't keep up, and our algorithm gets bottlenecked by memory bandwidth for large arrays ($$$n \ge 2^{20}$$$), even though the memory access pattern is linear. Now we are going to reduce this slowdown.
One possible solution is to perform two layers simultaneously (with a radix4 butterfly): this halves the number of memory scans and halves the rate at which we scan memory (since we now do twice as much computation per byte loaded).
We implement butterfly_x4 by simply stacking four butterfly_x2 onto each other.
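In scalar form, the stacking can be sketched like this (a sketch; the twiddle naming is mine: w is the twiddle of the current block, w1 and w2 are the twiddles of its two sub-blocks on the next layer):

```cpp
#include <cstdint>

using u32 = uint32_t;
using u64 = uint64_t;

const u32 MOD = 998244353;

u32 mul(u32 a, u32 b) { return u32(u64(a) * b % MOD); }

// Scalar radix-2 butterfly: (a, b) <- (a + w*b, a - w*b).
void butterfly_x2(u32& a, u32& b, u32 w) {
    u32 t = mul(b, w);
    u32 na = (a + t) % MOD;
    b = (a + MOD - t) % MOD;
    a = na;
}

// Radix-4 butterfly assembled from four radix-2 butterflies.
void butterfly_x4(u32& a, u32& b, u32& c, u32& d, u32 w, u32 w1, u32 w2) {
    butterfly_x2(a, c, w);   // first layer: pairs at distance 2
    butterfly_x2(b, d, w);
    butterfly_x2(a, b, w1);  // second layer: pairs at distance 1
    butterfly_x2(c, d, w2);
}
```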
Since the number of top layers may be odd or even, we need to add one conditional butterfly_x2 layer. I chose the topmost layer because it is the simplest to implement (all its twiddle factors are ones).
Another benefit of radix4 butterfly is the ability to (easily) vectorize the recalculation of twiddle factors. Now each butterfly_x4 requires three different twiddle factors, so we can pack them into a single u64x4 and update simultaneously (like we did in the bottom layers). If not for that, there would be almost no performance improvement (compared to radix2) for arrays fitting in L1 or L2 cache (at least on my machine).
Step E, optimizing bottom layers
Bottom layers are really slow: 3 bottom layers take as much time as 10 top layers (probably because we didn't vectorize them properly). (Back in November 2023) I had been wondering for quite a while how to make them faster when I found this blog by pajenegod. It clearly shows what exactly the code from step A2 is computing. Moreover, it suggests switching to $$$O(n^2)$$$ multiplication when we run out of square roots. But there is another reason to switch to the $$$O(n^2)$$$ algorithm: it can simply be faster than the $$$O(n \log n)$$$ algorithm for small values of $$$n$$$. And this is exactly the case.
Before the bottom layers, we have already computed $$$A(x) \bmod (x^8 - w_i)$$$ and $$$B(x) \bmod (x^8 - w_i)$$$ for every $$$i$$$ from $$$0$$$ to $$$\frac{n}{8} - 1$$$. But now, instead of going deeper (computing the values of $$$A(x), B(x)$$$ at the roots of $$$x^8 - w_i$$$, multiplying them pointwise and retrieving $$$A(x)B(x) \bmod (x^8 - w_i)$$$ from its values at those roots), we will multiply $$$A(x) \bmod (x^8 - w_i)$$$ by $$$B(x) \bmod (x^8 - w_i)$$$ modulo $$$(x^8 - w_i)$$$ straight away.
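The product modulo $$$(x^8 - w_i)$$$ is a small schoolbook multiplication where overflowing exponents wrap with an extra factor of $$$w_i$$$. A scalar sketch (the function name is mine; the repo's version is vectorized):

```cpp
#include <array>
#include <cstdint>

using u32 = uint32_t;
using u64 = uint64_t;

const u32 MOD = 998244353;

u32 mul(u32 a, u32 b) { return u32(u64(a) * b % MOD); }

// O(8^2) product of residues A(x), B(x) modulo (x^8 - w): a term at x^k with
// k >= 8 wraps to x^{k-8} with an extra factor of w, since x^8 = w.
std::array<u32, 8> mul_mod_x8(const std::array<u32, 8>& a,
                              const std::array<u32, 8>& b, u32 w) {
    std::array<u32, 8> c{};
    for (int i = 0; i < 8; i++)
        for (int j = 0; j < 8; j++) {
            int k = i + j;
            u32 term = mul(a[i], b[j]);
            if (k >= 8) term = mul(term, w);
            c[k & 7] = (c[k & 7] + term) % MOD;
        }
    return c;
}
```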
Because we no longer need twiddle factors for bottom layers, we can perform convolution on larger arrays. For $$$\text{mod} = 998\,244\,353 = 2^{23} \cdot 7 \cdot 17 + 1$$$, the limit was $$$n = 2^{23}$$$, but now we can use $$$n = 2^{26}$$$. And for $$$\text{mod} = 469762049 = 2^{26} \cdot 7 + 1$$$ (this one has the largest power of two among all 30-bit primes), we can use $$$n = 2^{30}$$$ (though the code as-is would break, because int is not enough for indices that large).
But switching to $$$O(n^2)$$$ multiplication at bottom layers has downsides. If we need to perform heavy computation with the output of NTT, we will have to perform operations $$$\bmod (x^8 - w_i)$$$ instead of doing them just pointwise.
Step F, recursive computation order
Now we are going to improve cache optimality even further. After finishing the topmost radix4 layer, we get four independent parts, so we can perform computation for each part recursively. With recursive order, only several topmost layers will be affected by memory bandwidth slowdown.
We could use explicit recursion; it wouldn't introduce much overhead because we only need it for the several topmost layers. We could unroll the recursion into a for loop by storing the stack state in a bitmask. But there is another, simpler and more efficient way to make the computation (fully) recursive without much overhead. I first saw the idea in this submission.
We will examine the radix2 case first, generalizing the approach to radix4 will be very simple. Let's visualize computational order as a tree, where nodes are recursive calls.
During each recursive call, exactly one aux_transform is performed, and it is performed on half of the subarray of the parent node. The previous order executed aux_transform first by depth in the tree, then from left to right. Now we want to do it in recursive order.
We may notice a pattern: if we ascend from a leaf node using 0-edges only, we traverse a consecutive segment of recursive calls (these paths are marked with color). So let's iterate over the leaf nodes, ascend until we meet a 1-edge, and perform the corresponding transforms in the nodes of the traversed path (from top to bottom). The length of the path can be calculated with the help of the tzcnt instruction: it is exactly the number of trailing zeros in the binary representation of the leaf index. The order of aux_transform calls is still consecutive on each layer, so twiddle factors will be updated correctly. The leftmost path is a special case, since all aux_transform calls on it have twiddle factors equal to one.
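The traversal order alone (without the arithmetic) can be sketched like this; it reproduces plain pre-order recursion, which the test below checks against an explicitly recursive reference:

```cpp
#include <utility>
#include <vector>

// Recursion-free "recursive" order for the radix-2 case. For each leaf, left to
// right, we ascend along 0-edges; the path length is the number of trailing
// zeros of the leaf index (tzcnt), and transforms on the path are performed
// from top to bottom. Each emitted pair is (layer depth, block index).
std::vector<std::pair<int, int>> recursive_order(int lg) {
    std::vector<std::pair<int, int>> order;
    for (int leaf = 0; leaf < (1 << lg); leaf++) {
        // the leftmost path (leaf 0) is special: it goes all the way to the root
        int len = leaf == 0 ? lg : __builtin_ctz(leaf);
        for (int h = len; h >= 1; h--)
            order.push_back({lg - h, leaf >> h});
    }
    return order;
}
```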
The inverse transform can be made fully recursive similarly, but we ascend along 1-edges instead of 0-edges.
Note: I don't really know if radix8 is an option, there might be issues with it, such as lack of logical registers (avx2 has only 16), enormous (machine) code size (~220 instructions per butterfly, don't really know if it matters) and the (counterintuitive) fact that it (and higher radix transforms) might be less I/O optimal because of the way cache associativity works.
Final result
Now our convolution is 9-10 times faster than the original one. There are still things to improve, but doing so is rather complicated and won't give much of an improvement.
We can submit it to this problem and get very close to the fastest submission. (we need to steal a fast I/O template from the fastest submission for a fair comparison)
The execution time measured by the system includes reading the input and printing the output, and even with custom fast I/O that takes several times longer than the convolution itself. So, to measure the actual computation time more accurately, one needs to measure it oneself and print the result to stderr (luckily, the judge shows stderr on every test).
Our submission uses ~7.0ms for actual computation (of cyclic convolution of size $$$2^{20}$$$). The author of the fastest submission (as of 17 Aug 2024) also printed actual computation time to stderr, his submission uses ~6.5ms. And this submission (by the same author) uses just ~6.05ms, though it doesn't have fast I/O and runs in more than 100ms in total.
Thank you for reading!