Hi Codeforces!
I have recently come up with a really neat and simple recursive algorithm for multiplying polynomials in $$$O(n \log n)$$$ time. It is so neat and simple that I think it might possibly revolutionize the way that fast polynomial multiplication is taught and coded. You don't need to know anything about FFT to understand and implement this algorithm.
I've split this blog up into two parts. The first part is intended for anyone to be able to read and understand. The second part is advanced and goes into a ton of interesting ideas and concepts related to this algorithm.
Prerequisite: Polynomial quotient and remainder, see Wiki article and Stackexchange example.
Task:
You are given two polynomials $$$P$$$ and $$$Q$$$, an integer $$$n$$$ and a non-zero complex number $$$c$$$, where $$$\deg P < n$$$ and $$$\deg Q < n$$$. Your task is to calculate the polynomial $$$P(x) \, Q(x) \% (x^n - c)$$$ in $$$O(n \log n)$$$ time. You may assume that $$$n$$$ is a power of two.
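As a small worked example of what $$$\% (x^n - c)$$$ means (the numbers are my own illustration, not part of the task): reducing modulo $$$x^n - c$$$ amounts to replacing every $$$x^n$$$ by $$$c$$$. With $$$n = 2$$$, $$$c = 3$$$, $$$P(x) = 1 + 2x$$$ and $$$Q(x) = 3 + 4x$$$:

$$$ \begin{aligned} P(x) \, Q(x) &= 3 + 10x + 8x^2, \\ P(x) \, Q(x) \% (x^2 - 3) &= 3 + 10x + 8 \cdot 3 = 27 + 10x. \end{aligned} $$$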
Solution:
We can create a divide and conquer algorithm for $$$P(x) \, Q(x) \% (x^n - c)$$$ based on the difference of squares formula. Assuming $$$n$$$ is even, then $$$(x^n - c) = (x^{n/2} - \sqrt{c}) (x^{n/2} + \sqrt{c})$$$. The idea behind the algorithm is to calculate $$$P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$$$ and $$$P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$$$ using 2 recursive calls, and then use that result to calculate $$$P(x) \, Q(x) \% (x^n - c)$$$.
So how do we actually calculate $$$P(x) \, Q(x) \% (x^n - c)$$$ using $$$P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$$$ and $$$P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$$$?
Well, we can use the following formula:
$$$ \begin{aligned} A(x) \% (x^n - c) = &\frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) \, + \\ &\frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})). \end{aligned} $$$
Proof of the formula
Note that
$$$ A(x) = \frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) A(x) + \frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) A(x). $$$
Let $$$Q^-(x)$$$ denote the quotient of $$$A(x)$$$ divided by $$$(x^{n/2} - \sqrt{c})$$$ and let $$$Q^+(x)$$$ denote the quotient of $$$A(x)$$$ divided by $$$(x^{n/2} + \sqrt{c})$$$. Then
$$$ \begin{aligned} (1 + \frac{x^{n/2}}{\sqrt{c}}) A(x) &= (1 + \frac{x^{n/2}}{\sqrt{c}}) ((A(x) \% (x^{n/2} - \sqrt{c})) + Q^-(x) (x^{n/2} - \sqrt{c})) \\ &= (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) + \frac{1}{\sqrt{c}} Q^-(x) (x^n - c) \end{aligned} $$$
and
$$$ \begin{aligned} (1 - \frac{x^{n/2}}{\sqrt{c}}) A(x) &= (1 - \frac{x^{n/2}}{\sqrt{c}}) ((A(x) \% (x^{n/2} + \sqrt{c})) + Q^+(x) (x^{n/2} + \sqrt{c})) \\ &= (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})) - \frac{1}{\sqrt{c}} Q^+(x) (x^n - c). \end{aligned} $$$
With this we have shown that
$$$ \begin{aligned} A(x) = &\frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) \, + \\ &\frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})) \, + \\ &\frac{1}{\sqrt{c}} \frac{Q^-(x) - Q^+(x)}{2} (x^n - c). \end{aligned} $$$
The first two terms together have degree less than $$$n$$$, so here $$$A(x)$$$ is expressed as remainder + quotient times $$$(x^n - c)$$$. So we have proven the formula.
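If you would rather check the formula numerically than read the proof, here is a minimal sketch. The helper names `poly_mod_binomial` and `combine` are my own (hypothetical), and the reduction is done naively by replacing $$$x^m$$$ with $$$r$$$.

```python
import cmath

def poly_mod_binomial(A, m, r):
    """Reduce the coefficient list A modulo (x^m - r), using x^m = r."""
    out = [0j] * m
    for k, a in enumerate(A):
        out[k % m] += a * r ** (k // m)
    return out

def combine(minus, plus, sqrtc):
    """Apply the formula: minus = A % (x^(n/2) - sqrt(c)), plus = A % (x^(n/2) + sqrt(c))."""
    low = [(m + p) / 2 for m, p in zip(minus, plus)]
    high = [(m - p) / (2 * sqrtc) for m, p in zip(minus, plus)]
    return low + high

A = [3, -1, 4, 1, -5, 9, 2, 6, 5, 3]  # arbitrary test polynomial
n, c = 4, 2 + 1j                      # arbitrary even n and non-zero c
s = cmath.sqrt(c)

lhs = poly_mod_binomial(A, n, c)
rhs = combine(poly_mod_binomial(A, n // 2, s),
              poly_mod_binomial(A, n // 2, -s), s)
assert all(abs(l - r) < 1e-9 for l, r in zip(lhs, rhs))
```

Note that $$$A(x) \% (x^{n/2} + \sqrt{c})$$$ is obtained by passing $$$r = -\sqrt{c}$$$, since $$$x^{n/2} + \sqrt{c} = x^{n/2} - (-\sqrt{c})$$$.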
This formula is very useful. If we substitute $$$A(x)$$$ by $$$P(x) Q(x)$$$, then the formula tells us how to calculate $$$P(x) \, Q(x) \% (x^n - c)$$$ using $$$P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$$$ and $$$P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$$$ in linear time. With this we have the recipe for implementing an $$$O(n \log n)$$$ divide and conquer algorithm:
Input:
- Integer $$$n$$$ (power of 2),
- Non-zero complex number $$$c$$$,
- Two polynomials $$$P(x) \% (x^n - c)$$$ and $$$Q(x) \% (x^n - c)$$$.
Output:
- The polynomial $$$P(x) \, Q(x) \% (x^n - c)$$$.
Algorithm:
Step 1. (Base case) If $$$n = 1$$$, then return $$$P(0) \cdot Q(0)$$$. Otherwise:
Step 2. Starting from $$$P(x) \% (x^n - c)$$$ and $$$Q(x) \% (x^n - c)$$$, in $$$O(n)$$$ time calculate
$$$ \begin{aligned} &P(x) \% (x^{n/2} - \sqrt{c}), \\ &Q(x) \% (x^{n/2} - \sqrt{c}), \\ &P(x) \% (x^{n/2} + \sqrt{c}), \\ &Q(x) \% (x^{n/2} + \sqrt{c}). \end{aligned} $$$
Step 3. Make two recursive calls to calculate $$$P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$$$ and $$$P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$$$.
Step 4. Using the formula, calculate $$$P(x) \, Q(x) \% (x^n - c)$$$ in $$$O(n)$$$ time. Return the result.
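To make Steps 2 and 4 concrete, here is how they look for $$$n = 4$$$ (a worked illustration; the coefficient names $$$p_i$$$, $$$m_i$$$ and $$$r_i$$$ are my own notation). For $$$P(x) = p_0 + p_1 x + p_2 x^2 + p_3 x^3$$$, Step 2 replaces $$$x^2$$$ by $$$\pm\sqrt{c}$$$:

$$$ \begin{aligned} P(x) \% (x^2 - \sqrt{c}) &= (p_0 + \sqrt{c} \, p_2) + (p_1 + \sqrt{c} \, p_3) \, x, \\ P(x) \% (x^2 + \sqrt{c}) &= (p_0 - \sqrt{c} \, p_2) + (p_1 - \sqrt{c} \, p_3) \, x. \end{aligned} $$$

If the two recursive calls return $$$m_0 + m_1 x$$$ (for $$$x^2 - \sqrt{c}$$$) and $$$r_0 + r_1 x$$$ (for $$$x^2 + \sqrt{c}$$$), then Step 4 gives

$$$ P(x) \, Q(x) \% (x^4 - c) = \frac{m_0 + r_0}{2} + \frac{m_1 + r_1}{2} \, x + \frac{m_0 - r_0}{2 \sqrt{c}} \, x^2 + \frac{m_1 - r_1}{2 \sqrt{c}} \, x^3. $$$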
Here is a Python implementation following this recipe:
Python solution to the task
"""
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time

Input:
    n: Integer, needs to be power of 2
    c: Non-zero complex floating point number
    P: A list of length n representing a polynomial P(x)
    Q: A list of length n representing a polynomial Q(x)
Output:
    A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)
"""
import cmath

def fast_polymult_mod(P, Q, n, c):
    assert len(P) == n and len(Q) == n

    # Base case
    if n == 1:
        return [P[0] * Q[0]]

    assert n % 2 == 0
    sqrtc = cmath.sqrt(c)

    # Calculate P_minus := P mod (x^(n/2) - sqrt(c))
    #           Q_minus := Q mod (x^(n/2) - sqrt(c))
    P_minus = [p1 + sqrtc * p2 for p1, p2 in zip(P[:n//2], P[n//2:])]
    Q_minus = [q1 + sqrtc * q2 for q1, q2 in zip(Q[:n//2], Q[n//2:])]

    # Calculate P_plus := P mod (x^(n/2) + sqrt(c))
    #           Q_plus := Q mod (x^(n/2) + sqrt(c))
    P_plus = [p1 - sqrtc * p2 for p1, p2 in zip(P[:n//2], P[n//2:])]
    Q_plus = [q1 - sqrtc * q2 for q1, q2 in zip(Q[:n//2], Q[n//2:])]

    # Recursively calculate PQ_minus := P * Q % (x^(n/2) - sqrt(c))
    #                       PQ_plus  := P * Q % (x^(n/2) + sqrt(c))
    PQ_minus = fast_polymult_mod(P_minus, Q_minus, n//2, sqrtc)
    PQ_plus = fast_polymult_mod(P_plus, Q_plus, n//2, -sqrtc)

    # Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus.
    # Low half: (minus + plus) / 2, high half: (minus - plus) / (2 sqrt(c))
    PQ = ([(m + p) / 2 for m, p in zip(PQ_minus, PQ_plus)] +
          [(m - p) / (2 * sqrtc) for m, p in zip(PQ_minus, PQ_plus)])

    return PQ
One final thing that I want to mention before going into the advanced section is that this algorithm can also be used to do fast unmodded polynomial multiplication, i.e. given polynomials $$$P(x)$$$ and $$$Q(x)$$$ calculate $$$P(x) \, Q(x)$$$. The trick is simply to pick $$$n$$$ large enough such that $$$P(x) \, Q(x) = P(x) \, Q(x) \% (x^n - c)$$$, and then use the exact same algorithm as before. $$$c$$$ can be arbitrarily picked (any non-zero complex number works).
Python implementation for general fast polynomial multiplication
"""
Calculates P(x) * Q(x)

Input:
    P: A list representing a polynomial P(x)
    Q: A list representing a polynomial Q(x)
Output:
    A list representing the polynomial P(x) * Q(x)
"""
def fast_polymult(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    n = 1
    while n < res_len:
        n *= 2

    # Pad with extra 0s to reach length n
    P = P + [0] * (n - n1)
    Q = Q + [0] * (n - n2)

    # Pick non-zero c arbitrarily =)
    c = 123.24

    # Calculate P*Q mod x^n - c
    PQ = fast_polymult_mod(P, Q, n, c)

    # Remove extra 0 padding and return
    return PQ[:res_len]
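Since the algorithm works with complex floating point numbers, the output will contain small rounding errors and tiny imaginary parts even when the inputs are integers. Below is a small sketch of how one might check results against a schoolbook product; `naive_polymult` and `round_to_ints` are hypothetical helper names of mine, and the `noisy` list is made-up data standing in for the output of a floating point routine such as `fast_polymult`.

```python
def naive_polymult(P, Q):
    """Schoolbook O(n^2) product, used only as a reference."""
    res = [0] * (len(P) + len(Q) - 1)
    for i, p in enumerate(P):
        for j, q in enumerate(Q):
            res[i + j] += p * q
    return res

def round_to_ints(coeffs):
    """Drop the (tiny) imaginary parts and round to the nearest integers."""
    return [round(complex(z).real) for z in coeffs]

# (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2
expected = naive_polymult([1, 2], [3, 4])

# Made-up stand-in for a floating point result with rounding noise
noisy = [3 + 1e-13j, 10 - 2e-12 + 0j, 8.000000000001 + 0j]
assert round_to_ints(noisy) == expected
```

Rounding like this is only safe for integer inputs and as long as the accumulated floating point error stays well below 0.5.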
(Advanced) Speeding up the algorithm
This section will be about tricks that can be used to speed up the algorithm. This will in total speed it up by a factor of between 2 and 4.
$$$n$$$ doesn't actually need to be a power of 2
We don't actually need the assumption that $$$n$$$ is a power of 2. If $$$n$$$ ever becomes odd during the recursion, then we have two choices: either fall back to an $$$O(n^2)$$$ algorithm, or fall back to the unmodded $$$O(n \log{n})$$$ polynomial multiplication algorithm.
Let us discuss the run time of falling back to the $$$O(n^2)$$$ algorithm when $$$n$$$ becomes odd. Assume that $$$n = a \cdot 2^b$$$, where $$$a$$$ is an odd integer and $$$b$$$ is an integer. Think of the recursive algorithm as having layers, one layer for each possible value of $$$n$$$. The first $$$b$$$ layers will all take $$$O(n)$$$ time each. In the $$$(b+1)$$$-th layer the value of $$$n$$$ is $$$a$$$. Using the $$$O(n^2)$$$ polynomial multiplication algorithm leads to this layer taking $$$O(n/a \cdot a^2) = O(n \cdot a)$$$ time. The final time complexity comes out to be $$$O((a + b) \, n)$$$.
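To get a feel for the numbers, here is a tiny sketch that splits $$$n$$$ into $$$a \cdot 2^b$$$ with $$$a$$$ odd (the function name `odd_part_and_exponent` is my own):

```python
def odd_part_and_exponent(n):
    """Return (a, b) with n = a * 2^b and a odd."""
    a, b = n, 0
    while a % 2 == 0:
        a //= 2
        b += 1
    return a, b

for n in (64, 48, 1000):
    a, b = odd_part_and_exponent(n)
    # (a + b) * n is the rough operation count from the analysis above
    print(n, a, b, (a + b) * n)
```

For example $$$n = 48 = 3 \cdot 2^4$$$ gives $$$a = 3$$$ and $$$b = 4$$$, while $$$n = 1000 = 125 \cdot 2^3$$$ has a large odd part, so the quadratic layer dominates.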
Python implementation that works for both odd and even $$$n$$$
"""
Calculates P(x) * Q(x) % (x^n - c) in O((a + b) * n) time, where n = a*2^b.

Input:
    n: Integer
    c: Non-zero complex floating point number
    P: A list of length n representing a polynomial P(x)
    Q: A list of length n representing a polynomial Q(x)
Output:
    A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)
"""
import cmath

def fast_polymult_mod2(P, Q, n, c):
    assert len(P) == n and len(Q) == n

    # Base case (n is odd)
    if n & 1:
        # Calculate the answer in O(n^2) time
        res1 = [0] * n  # terms x^(i+j) with i + j < n
        res2 = [0] * n  # terms x^(i+j) with i + j >= n, which wrap around
        for i in range(n):
            for j in range(n - i):
                res1[i + j] += P[i] * Q[j]
            for j in range(n - i, n):
                res2[i + j - n] += P[i] * Q[j]
        # x^n = c, so the wrapped terms are multiplied by c
        return [r1 + c * r2 for r1, r2 in zip(res1, res2)]

    assert n % 2 == 0
    sqrtc = cmath.sqrt(c)

    # Calculate P_minus := P mod (x^(n/2) - sqrt(c))
    #           Q_minus := Q mod (x^(n/2) - sqrt(c))
    P_minus = [p1 + sqrtc * p2 for p1, p2 in zip(P[:n//2], P[n//2:])]
    Q_minus = [q1 + sqrtc * q2 for q1, q2 in zip(Q[:n//2], Q[n//2:])]

    # Calculate P_plus := P mod (x^(n/2) + sqrt(c))
    #           Q_plus := Q mod (x^(n/2) + sqrt(c))
    P_plus = [p1 - sqrtc * p2 for p1, p2 in zip(P[:n//2], P[n//2:])]
    Q_plus = [q1 - sqrtc * q2 for q1, q2 in zip(Q[:n//2], Q[n//2:])]

    # Recursively calculate PQ_minus := P * Q % (x^(n/2) - sqrt(c))
    #                       PQ_plus  := P * Q % (x^(n/2) + sqrt(c))
    PQ_minus = fast_polymult_mod2(P_minus, Q_minus, n//2, sqrtc)
    PQ_plus = fast_polymult_mod2(P_plus, Q_plus, n//2, -sqrtc)

    # Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus.
    # Low half: (minus + plus) / 2, high half: (minus - plus) / (2 sqrt(c))
    PQ = ([(m + p) / 2 for m, p in zip(PQ_minus, PQ_plus)] +
          [(m - p) / (2 * sqrtc) for m, p in zip(PQ_minus, PQ_plus)])

    return PQ
The reason why this is super useful is that it allows us to speed up the fast unmodded polynomial multiplication algorithm. As long as we are fine with $$$a$$$ being less than, say, $$$10$$$, we might be able to choose a significantly smaller $$$n$$$ than if we were only allowed to choose powers of two. This trick has the potential of making the fast unmodded polynomial multiplication algorithm run twice as fast.
Python implementation for more efficient fast unmodded polynomial multiplication
"""
Calculates P(x) * Q(x)

Input:
    P: A list representing a polynomial P(x)
    Q: A list representing a polynomial Q(x)
Output:
    A list representing the polynomial P(x) * Q(x)
"""
def fast_polymult2(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big, of the form n = a * 2^b with a <= alim
    b = 0
    alim = 10
    while alim * 2**b < res_len:
        b += 1
    a = (res_len - 1) // 2**b + 1
    n = a * 2**b

    # Pad with extra 0s to reach length n
    P = P + [0] * (n - n1)
    Q = Q + [0] * (n - n2)

    # Pick non-zero c arbitrarily =)
    c = 123.24

    # Calculate P*Q mod x^n - c
    PQ = fast_polymult_mod2(P, Q, n, c)

    # Remove extra 0 padding and return
    return PQ[:res_len]
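To see how much smaller $$$n$$$ can get, here is a small sketch that compares the two ways of picking $$$n$$$. The function names are mine; the logic is the same as in `fast_polymult` and `fast_polymult2`.

```python
def pick_power_of_two(res_len):
    """Smallest power of two that is >= res_len."""
    n = 1
    while n < res_len:
        n *= 2
    return n

def pick_small_odd_part(res_len, alim=10):
    """Smallest n = a * 2^b >= res_len for the smallest b such that a <= alim."""
    b = 0
    while alim * 2**b < res_len:
        b += 1
    a = (res_len - 1) // 2**b + 1
    return a * 2**b

for res_len in (1000, 1025, 5000):
    print(res_len, pick_power_of_two(res_len), pick_small_odd_part(res_len))
```

For `res_len = 1025` the power of two version needs $$$n = 2048$$$, while the second version gets away with $$$n = 9 \cdot 2^7 = 1152$$$.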
Imaginary-cyclic convolution
Suppose that $$$P(x)$$$ and $$$Q(x)$$$ are two real polynomials, and that we want to calculate $$$P(x) \, Q(x)$$$. As discussed earlier, we can calculate the unmodded polynomial product by picking $$$n$$$ large enough such that $$$(P(x) \, Q(x)) \% (x^n - c) = P(x) \, Q(x)$$$ (here $$$c$$$ is any non-zero complex number), and then running the divide and conquer algorithm. But it turns out there is something smarter that we can do.
If we use $$$c = \text{i}$$$ (the imaginary unit) as the initial value of $$$c$$$, then this will allow us to pick an even smaller value for $$$n$$$. The reason for this is that if we get "overflow" from $$$n$$$ being too small, then that overflow will be placed into the imaginary part of the result $$$(P(x) \, Q(x)) \% (x^n - \text{i})$$$. More precisely, if $$$P(x) \, Q(x) = A(x) + x^n B(x)$$$ with $$$A$$$ and $$$B$$$ real polynomials of degree less than $$$n$$$, then $$$(P(x) \, Q(x)) \% (x^n - \text{i}) = A(x) + \text{i} \, B(x)$$$, so $$$A$$$ can be read off from the real parts and $$$B$$$ from the imaginary parts. This means that by using $$$c = \text{i}$$$ we are allowed to pick $$$n$$$ as half the size compared to if we weren't using $$$c=\text{i}$$$. So this trick speeds the fast unmodded polynomial multiplication algorithm up by about a factor of 2.
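Here is a minimal sketch of the idea on a tiny example. To keep it self-contained it uses a naive $$$O(n^2)$$$ product modulo $$$x^n - c$$$ (the helper `mult_mod_xn_minus_c` is my own, hypothetical name); in practice you would plug in the fast divide and conquer routine instead.

```python
def mult_mod_xn_minus_c(P, Q, n, c):
    """Naive O(n^2) computation of P(x) * Q(x) % (x^n - c)."""
    res = [0j] * n
    for i, p in enumerate(P):
        for j, q in enumerate(Q):
            k = i + j
            res[k % n] += p * q * c ** (k // n)
    return res

# Real polynomials: (1 + 2x + 3x^2)(4 + 5x + 6x^2) = 4 + 13x + 28x^2 + 27x^3 + 18x^4
P = [1, 2, 3]
Q = [4, 5, 6]
res_len = len(P) + len(Q) - 1

# n = 3 is too small for the product (5 coefficients), but 2n >= res_len is enough
n = 3
PQ_mod = mult_mod_xn_minus_c(P, Q, n, 1j)

# The real parts hold coefficients 0..n-1, the imaginary parts hold n..2n-1
low = [round(z.real) for z in PQ_mod]
high = [round(z.imag) for z in PQ_mod]
product = (low + high)[:res_len]
assert product == [4, 13, 28, 27, 18]
```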
Trick to go from $$$\% (x^n - c)$$$ to $$$\% (x^n - 1)$$$
There is a somewhat well-known technique called "reweighting" that allows us to switch between working with $$$\% (x^n - c)$$$ and working with $$$\% (x^n - 1)$$$. I've previously written a blog explaining this technique, see here.
So why would we be interested in switching from $$$\% (x^n - c)$$$ to $$$\% (x^n - 1)$$$? The reason is that by using $$$c=1$$$, we don't need to bother with multiplying or dividing by $$$c$$$ or $$$\sqrt{c}$$$ anywhere, since $$$c=\sqrt{c}=1$$$. Additionally, if $$$c=-1$$$ or $$$c=\text{i}$$$ or $$$c=-\text{i}$$$, then multiplying or dividing by $$$c$$$ can be done very efficiently. So whenever $$$c$$$ becomes something other than $$$1,-1,\text{i}$$$ or $$$-\text{i}$$$, it makes sense to use the reweight trick to switch back to $$$c=1$$$. Doing this could theoretically cut down the number of floating point operations in the algorithm by a lot. This is definitely something to consider if you want to create a heavily optimized polynomial multiplication algorithm.
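For readers who haven't seen reweighting, here is a minimal sketch of how I understand the idea: substitute $$$x = w y$$$ where $$$w^n = c$$$, so that $$$x^n - c = c \, (y^n - 1)$$$. Coefficient $$$k$$$ gets multiplied by $$$w^k$$$ before the multiplication and divided by $$$w^k$$$ afterwards. The function names are my own, and the cyclic product is done naively to keep the example short.

```python
import cmath

def naive_mult_mod(P, Q, n, c=1):
    """Naive O(n^2) computation of P(x) * Q(x) % (x^n - c)."""
    res = [0j] * n
    for i in range(n):
        for j in range(n):
            k = i + j
            res[k % n] += P[i] * Q[j] * (c if k >= n else 1)
    return res

def mult_via_reweighting(P, Q, n, c):
    """Compute P * Q % (x^n - c) using only a product modulo (y^n - 1)."""
    w = cmath.exp(cmath.log(c) / n)            # one n-th root of c
    Pw = [p * w**k for k, p in enumerate(P)]   # coefficients of P(w y)
    Qw = [q * w**k for k, q in enumerate(Q)]   # coefficients of Q(w y)
    Rw = naive_mult_mod(Pw, Qw, n)             # product modulo (y^n - 1)
    return [r / w**k for k, r in enumerate(Rw)]  # undo the substitution

P = [1, 2, 3, 4]
Q = [5, 6, 7, 8]
n, c = 4, 2 - 1j
direct = naive_mult_mod(P, Q, n, c)
reweighted = mult_via_reweighting(P, Q, n, c)
assert all(abs(d - r) < 1e-9 for d, r in zip(direct, reweighted))
```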
This algorithm is actually FFT in disguise. But it is also different compared to any other FFT algorithm that I've seen before (for example the Cooley–Tukey FFT algorithm).
Using this algorithm to calculate FFT
In the tail of the recursion (i.e. when $$$n$$$ reaches 1), you are calculating $$$P(x) \, Q(x) \% (x - c)$$$, for some non-zero complex number $$$c$$$. This is in fact the same thing as evaluating the polynomial $$$P(x) \, Q(x)$$$ at $$$x=c$$$. Furthermore, if you initially started with $$$c=1$$$, then the $$$c$$$ in the tail will be some $$$n$$$-th root of unity. If you analyze it more carefully, then you will see that each tail corresponds to a different $$$n$$$-th root of unity. So what the algorithm is actually doing is evaluating $$$P(x) \, Q(x)$$$ at all possible $$$n$$$-th roots of unity.
The $$$n$$$-th order FFT of a polynomial is defined as the polynomial evaluated at all $$$n$$$-th roots of unity. This means that the algorithm is in fact an FFT algorithm. However, if you want to use it to calculate FFT, then make sure you order the $$$n$$$-th roots of unity according to the standard order used for FFT algorithms. The standard order is $$$\exp{(\frac{2 \pi \text{i}}{n} 0)}, \exp{(\frac{2 \pi \text{i}}{n} 1)}, \ldots, \exp{(\frac{2 \pi \text{i}}{n} (n-1))}$$$. To get the ordering correct, you will probably need to do a "bit reversal" at the end.
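The claim about the tails can be checked numerically. The sketch below runs only the "reduce" half of the recursion on a single polynomial and records, for each tail, the value of $$$c$$$ together with the remaining constant. The names `leaves` and `horner` are my own, and I deliberately make no claim about the order of the leaves here, since that depends on which square root `cmath.sqrt` returns.

```python
import cmath

def leaves(P, n, c):
    """Return [(root, P % (x - root))] for every tail reached from x^n - c."""
    if n == 1:
        return [(c, P[0])]
    s = cmath.sqrt(c)
    h = n // 2
    minus = [a + s * b for a, b in zip(P[:h], P[h:])]  # P % (x^h - s)
    plus = [a - s * b for a, b in zip(P[:h], P[h:])]   # P % (x^h + s)
    return leaves(minus, h, s) + leaves(plus, h, -s)

def horner(P, x):
    """Evaluate the polynomial with coefficient list P at x."""
    acc = 0
    for coef in reversed(P):
        acc = acc * x + coef
    return acc

P = [5, -2, 7, 1, 0, 3, -4, 6]
n = len(P)
result = leaves(P, n, 1)

for root, value in result:
    assert abs(root**n - 1) < 1e-9              # every tail is an n-th root of unity
    assert abs(value - horner(P, root)) < 1e-9  # and the tail value is P(root)

# All n roots of unity appear exactly once
indices = {round(cmath.phase(root) * n / (2 * cmath.pi)) % n for root, _ in result}
assert indices == set(range(n))
```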
This algorithm is not the same algorithm as Cooley–Tukey
(Advanced) Connection between this algorithm and NTT
Just like how there is FFT and NTT, there are two variants of this algorithm too. One using complex floating point numbers, and the other using modulo a prime (or more generally modulo an odd composite number).
Using modulo integers instead of complex numbers
This algorithm requires three properties: it needs to be possible to divide by $$$2$$$, $$$\sqrt{c}$$$ needs to exist, and it needs to be possible to divide by $$$\sqrt{c}$$$. This means that it is possible to extend the algorithm to work modulo a prime (or modulo an odd composite number) instead of using complex numbers.
What if $$$\sqrt{c}$$$ doesn't exist?
One of the things I dislike about NTT is that for NTT to be defined, there needs to exist a primitive $$$n$$$-th root of unity. Usually problems involving NTT are designed so that this is never an issue. But if you want to use NTT where it hasn't been designed to magically work, then this is a really big issue. The NTT can become undefined!
Note that this algorithm does not exactly share the same drawback of being undefined. The reason for this is that if $$$\sqrt{c}$$$ doesn't exist, then the algorithm can simply choose to either switch over to an $$$O(n^2)$$$ polynomial multiplication algorithm, or fall back to an unmodded fast polynomial multiplication algorithm. The implication is that this algorithm can do fast polynomial multiplication even if it is given a relatively bad NTT prime. I just find this property to be really cool!
A good example of when NTT becomes undefined is the Library Checker (yosupo) problem convolution_mod_large. Here the NTT mod is 998244353. The tricky thing about the problem is that $$$n=2^{24}$$$. Since $$$998244353 = 119 \cdot 2^{23} + 1$$$ there won't exist any primitive $$$n$$$-th root of unity, so the NTT of length $$$n$$$ is undefined. However, the divide and conquer algorithm can easily solve this problem by falling back to the $$$O(n^2)$$$ algorithm.
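Here is a minimal sketch of what the modular variant with the $$$O(n^2)$$$ fallback could look like. To avoid needing a general modular square root, I make a simplifying assumption that is not part of the algorithm itself: $$$c$$$ is passed as an exponent $$$e$$$ with $$$c = g^e$$$, where $$$g = 3$$$ is a primitive root modulo 998244353. Then $$$\sqrt{c}$$$ exists exactly when $$$e$$$ is even. All function names are my own.

```python
MOD = 998244353        # = 119 * 2^23 + 1
G = 3                  # a primitive root modulo MOD
INV2 = (MOD + 1) // 2  # inverse of 2 modulo MOD

def naive_mult_mod(P, Q, n, c):
    """O(n^2) computation of P * Q % (x^n - c), coefficients modulo MOD."""
    res = [0] * n
    for i in range(n):
        for j in range(n):
            term = P[i] * Q[j]
            if i + j >= n:
                term *= c  # x^n = c
            res[(i + j) % n] = (res[(i + j) % n] + term) % MOD
    return res

def polymult_mod_p(P, Q, n, e):
    """P * Q % (x^n - G^e) modulo MOD, assuming 0 <= e < MOD - 1."""
    c = pow(G, e, MOD)
    # Fall back if n is odd, or if e is odd (then sqrt(c) does not exist)
    if n % 2 == 1 or e % 2 == 1:
        return naive_mult_mod(P, Q, n, c)
    h = n // 2
    s = pow(G, e // 2, MOD)  # sqrt(c); note that -s = G^(e/2 + (MOD-1)/2)
    P_minus = [(a + s * b) % MOD for a, b in zip(P[:h], P[h:])]
    Q_minus = [(a + s * b) % MOD for a, b in zip(Q[:h], Q[h:])]
    P_plus = [(a - s * b) % MOD for a, b in zip(P[:h], P[h:])]
    Q_plus = [(a - s * b) % MOD for a, b in zip(Q[:h], Q[h:])]
    PQ_minus = polymult_mod_p(P_minus, Q_minus, h, e // 2)
    PQ_plus = polymult_mod_p(P_plus, Q_plus, h, e // 2 + (MOD - 1) // 2)
    inv_2s = INV2 * pow(s, MOD - 2, MOD) % MOD  # 1 / (2 sqrt(c))
    low = [(m + p) * INV2 % MOD for m, p in zip(PQ_minus, PQ_plus)]
    high = [(m - p) * inv_2s % MOD for m, p in zip(PQ_minus, PQ_plus)]
    return low + high

P = [1, 2, 3, 4, 5, 6, 7, 8]
Q = [8, 7, 6, 5, 4, 3, 2, 1]
for e in (0, 1, 2, 12):
    assert polymult_mod_p(P, Q, 8, e) == naive_mult_mod(P, Q, 8, pow(G, e, MOD))
```

A general implementation would instead need a modular square root routine (for example Tonelli–Shanks) to decide whether $$$\sqrt{c}$$$ exists, and would probably switch to the quadratic algorithm only for small $$$n$$$.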