四年前写过一篇介绍 FFT 的博客,但其中漏洞百出,应该坑了不少人,故现在写一份新的,希望对算法竞赛选手有帮助。我们先回顾复数的一些知识,再介绍多项式的点值表达,最后给出 FFT 的详细过程和代码实现。
复数的概念
我们每个人都很熟悉一元二次方程 $ax^2+bx+c = 0$. 它有个判别式 $\Delta = b ^ 2 - 4ac$,如果 $\Delta \geq 0$,方程就有两个互不相等的实根;如果 $\Delta = 0$,则方程有一个实根(或者说,两个相等的实根);如果 $\Delta < 0$,则方程没有实根。经过一番推导,我们可以给出方程的解:$$x _ 1 = \frac{-b + \sqrt{\Delta}}{2a}, \quad x _ 2 = \frac{-b - \sqrt{\Delta}}{2a}$$
显然,如果 $\Delta < 0$,上面的 $x _ 1, x_2$ 都是无意义的。例如 $x^2 + x + 1$ 的判别式 $\Delta = -3$,于是求不出 $x _ 1, x_ 2$.
我们把 $ x _ 1, x _ 2$ 相加、相乘,于是可以得到韦达定理:$$x _ 1 + x _ 2 = -\frac{b}{a}, \quad x _1 \cdot x _ 2 = \frac{c}{a}$$
韦达定理的式子里面并没有根号!拿着韦达定理去套 $x ^ 2 + x + 1$,我们会获得 $$x _ 1 + x_ 2 = -1,\quad x _1 \cdot x _ 2 = 1$$
但 $x _ 1, x _ 2$ 在实数域上是不存在的,它们怎么会有实数的和、积呢?现在有两条路:要么限定韦达定理不准用于 $\Delta > 0$ 的情况,要么承认 $x_ 1, x _ 2$ 都是存在的,只是不存在于实数域上。
前者是初中生干的事,我们选择后一种。引入一个数 $i$,它满足性质 $ i ^ 2 = -1$,称之为“虚数单位”。相应地,借助虚数单位,立刻可以把实数域 $R$ 扩展到复数域 $C$:$$C: \left\{ a + b\cdot i ~ | ~ a, b \in R\right\}$$
那么,对于 $x ^ 2 + x + 1$,就有了两个复根。把实部作为横轴坐标,虚部作为纵轴坐标,如图:
复数的运算
复数相对于实数,无非引入了虚数单位;其余的运算法则没有什么影响。显然
- $ (a + bi) + (c + di) = (a+c) + (b+d)i$
- $ (a+bi)(c+di) = (ac - bd) + (ad+bc)i$
另外,如果把坐标系改为以 $x$ 轴正方向为极轴的极坐标,那么复数有另一种表达方式——模长、辐角的二元组。模长就是极坐标系上的极径,辐角就是极角。
上图展示了两个复数在直角坐标系、极坐标系上的情形。其中,$x_1 = -\frac12 - \frac {\sqrt{3}}{2}i$ 的极坐标形式是 $(1, \frac{2\pi}{3})$;$x _ 2$ 则是 $(1, \frac{4\pi}{3})$.
经过繁杂的运算,不难验证,复数的乘法也可以表述为“模长相乘、辐角相加”。例如 $$x_1 \cdot x _2 = (1\cdot 1, \frac{2\pi}3 + \frac{4\pi}{3}) = (1, 2\pi) = (1, 0)$$
也就是 $x_ 1 \cdot x_2 = 1$,这与韦达定理算出 $x_1 \cdot x _2 = \frac{c}{a}=1$ 是吻合的。
欧拉公式
以下这个著名的公式,被称为欧拉公式:$$e ^ {i \theta} = \cos \theta + i\sin \theta$$
它是怎么来的?还记得 $\sin(x), \cos(x)$ 的泰勒展开吗?$$\begin{aligned}\sin(x) &= x - \frac{x ^ 3}{3!} + \frac{x ^ 5}{5!} - \frac{x ^ 7}{7!} + \frac{x ^ 9}{9!} \cdots \\ \cos(x) &= 1 - \frac{x ^ 2}{2!} + \frac{x ^ 4}{4!} - \frac{x ^ 6}{6!} + \frac{x ^ 8}{8!}\cdots\end{aligned}$$
好的,现在写出 $\cos \theta + i\sin \theta$ 的展开式:$$1 + i\theta - \frac{\theta ^ 2}{2!} - i\frac{\theta ^ 3}{3!} + \frac{\theta ^ 4}{4!} + i\frac{\theta ^ 5}{5!} \cdots$$
立刻发现这和 $e ^ x$ 的展开非常相似。我们有$$e ^ x = 1 + x + \frac{x ^ 2}{2!} + \frac{x ^ 3}{3!} + \frac{x ^ 4}{4!} + \frac{x ^ 5}{5!}\cdots$$
现在代入 $x = i\theta$,来看式子变成什么样吧:$$\begin{aligned}e ^ {i\theta} &= 1 + i\theta + \frac{i ^ 2 \theta ^ 2}{2!}+ \frac{i ^ 3 \theta ^ 3}{3!}+ \frac{i ^ 4 \theta ^ 4}{4!}+ \frac{i ^ 5 \theta ^ 5}{5!} \\ &=1 + i\theta - \frac{\theta ^ 2}{2!} - i\frac{\theta ^ 3}{3!} + \frac{\theta ^ 4}{4!} + i\frac{\theta ^ 5}{5!} \cdots \end{aligned}$$
注意到 $e ^ {i\theta}$ 的展开式与 $\cos \theta + i\sin \theta$ 的一模一样,于是我们有了欧拉公式 $$e ^ {i\theta} = \cos\theta + i\sin\theta$$
进一步地,欧拉公式还能给出一个推论:$e ^ {\pi i} = -1$.
复平面上的单位圆
直角坐标系上的单位圆,上面的点可以表示成 $(\cos \theta, \sin \theta)$,每一个点与一个角度一一对应。复平面上也有“单位圆”,也就是模长为 1 的全体复数的集,可以记为 $\cos \theta + i\sin\theta$,其中 $\theta$ 是辐角。立刻注意到,这个单位圆上的点 $\cos \theta + i\sin\theta$,也可以写成 $e ^ {i\theta}$. 例如我们上文举过的例子 $x_1, x _2$ 都在单位元上,有$$x _ 1 = -\frac12 - \frac {\sqrt{3}}{2}i = (1, \frac{2\pi}{3}) = e ^ {\frac{2\pi}{3}i}$$
接下来,我们在 FFT 中,要讨论的求值点都是单位圆上的点。
多项式的点值表达
一个 $n$ 次多项式,是形如 $A(x) = a_0 + a_1 x + a_2 x ^ 2 +\cdots + a_n x ^n$ 的函数。其中的 $n$ 称为这个多项式的阶;显然多项式是 $R \to R$ 的一个映射。
接下来我们要给出一个结论:
给定 $n+1$ 个互不相同的点,可以拟合出一条 $n$ 阶的多项式经过这些点。这个过程称为“插值”。
例如:两个点可以确定一条直线,三个点可以确定一条抛物线。
最朴素的插值算法(也是最适合手算的方法)是高斯消元。对于所有的点列出方程,最后解这个 $n+1$ 元一次方程组,即可得到这个多项式的全部 $n+1$ 个系数。复杂度是 $O(n ^ 3)$.
存在效率更高的算法,例如拉格朗日插值法是 $O(n ^ 2)$ 的时间复杂度。原理很简单,本博客不再赘述。现在我们利用拉格朗日插值法找到一条经过 $(1,3), (2,4), (3, -1), (4, 2)$ 的多项式:
多项式乘法与 FFT
朴素的多项式乘法,需要 $O(n ^ 2)$ 的复杂度。分块乘法有一点提升,大概是 $O(n ^ {1.59})$ 的复杂度。
现在我们手上有两个 $n$ 阶多项式 $A, B$,该如何快速求出它们的乘积呢?考虑下面的算法:
- 指定至少 $2n+1$ 个 x 轴坐标点 $x_i$,求出 $A, B$ 在这些点上的值。
- 记多项式 $C = A \cdot B$,不难发现 $C(x_i) = A(x _ i) \cdot B(x _ i)$.
- 于是我们把 $C( x _i)$ 的值全都求出来,由这些点插值得到 $C$ 的系数,就是 $A\cdot B$ 的系数了。
下面给出一个例子。计算 $A = 1+2x+3x^2$ 与 $B=3+x+x ^ 2$ 的乘积:
多项式的单点求值过程是 $O(n)$ 的,采用霍纳法则(或称秦九韶算法)实现。一共要求出 $n+1$ 个点的值,总代价是 $O(n ^ 2)$;点乘过程是 $O(n)$ 的;插值得到 $C$ 的过程是 $O(n^2)$的。总复杂度仍然是 $O(n ^ 2)$,相当于没有改进。
瓶颈是在求值、插值上。如果我们能提出一个快速的求值、插值算法,就能快速完成多项式乘法。这些 $2n+1$ 个求值点是可以由我们算法决定的。
快速傅里叶变换(FFT)的原理就是:我们选择一组非常特殊的求值点,然后利用各种性质简化运算,最终在 $O(n \log n)$ 的时间复杂度内完成求值和插值。
单位根
回到那个复平面上的圆。我们把这个圆 $n$ 等分,其中固定一个点在 $(1, 0)$ 处。于是这个圆上有了 $n$ 个等分点,称为 $n$ 次单位根。显然,$n$ 次单位根的 $n$ 次方就等于 1. 下面展示了 8 次、4 次单位根:
我们把 $n$ 次单位根的第 $k$ 个记为 $\omega_n ^ k$. 显然有 $$\omega_n ^ k = e ^ {\frac{k}{n}2\pi i}$$
用上面的式子,立刻可以得到几条性质。
- 消去引理:$\omega _ {dn} ^ {dk} = \omega_n ^ k$
- 折半引理:$2n$ 次单位根的平方的集合,就是 $n$ 次单位根的集合。
证明:对于 $(\omega_{2n} ^ {k}) ^ 2$,其中 $k < n$ 的情况有 $(\omega_{2n} ^ {k}) ^ 2 = \omega_{2n} ^ {2k} = \omega _n ^k$. 对于其他情况:
$$(\omega_{2n} ^ {k + n})^2 = \omega_{2n}^{2k+2n} = \omega _ {2n} ^ {2k} \cdot \omega_{2n}^{2n} = \omega_{2n} ^ {2k} = \omega_n ^ k$$
取 $n$ 次单位根为求值点,对多项式求值,这个过程称为“离散傅里叶变换(DFT)”。逆过程(插值)称为“离散傅里叶反变换(IDFT)”。FFT 算法是 DFT、IDFT 的快速实现,折半引理是 FFT 分治的基础。
分治求值
现在我们的目标是快速完成 DFT,也就是求出多项式 $A$ 在 $n$ 次单位根上的值。按传统方法肯定得 $O(n^2)$,但这一次求值点有特殊性质。我们考虑多项式$$A = a_0 + a_1x + a_2x ^2 + a_3 x ^ 3 + a_4 x ^ 4 +\cdots$$拆成两个多项式:$$\begin{aligned} A(x) &= A_ 0 (x^2) + x\cdot A _1 (x ^ 2) \\ A _ 0 &= a _0 + a _ 2 x + a _ 4 x ^ 2 + \cdots \\ A _ 1 &= a_1 + a _3 x + a _5 x ^ 2+\cdots \end{aligned}$$
注意到 $A_0, A _ 1$ 的阶都只有 $A$ 的一半,而且有另一个性质:虽然一共有 $n$ 个 $A(x)$ 需要求值,但 $x ^ 2$ 一共只有 $n /2$ 种取值,于是我们只需要对 $n/2$ 个数求 $A_0$ 的值、$n/2$ 个数求 $A_1$ 的值!而这些值可以直接交由递归计算。
具体来讲,我们先递归求出 $A_0 (\omega _{n/2} ^ k), A_1 (\omega _{n/2} ^ k)$ 的值,其中 $k < n/2$;再枚举 $n/2$ 以内的 $k$:
- 求 $A\left(\omega _ n ^ k\right)$ $$\begin{aligned}A(\omega _ n ^ k) &= A_0 \left((\omega _n ^k) ^ 2\right ) + \omega _ n ^ k \cdot A_1 \left((\omega _n ^k) ^ 2\right ) \\ &= A_0 \left(\omega _{n/2} ^k \right ) + \omega _ n ^ k \cdot A_1 \left(\omega _{n/2} ^k \right )\end{aligned}$$
- 求 $A\left(\omega _ n ^ {k + n/2}\right)$ $$\begin{aligned}A\left(\omega _ n ^ {k + n/2}\right) &= A_0 \left((\omega _n ^{k + n/2}) ^ 2\right ) + \omega _ n ^ {k + n/2} \cdot A_1 \left((\omega _n ^{k + n/2}) ^ 2\right ) \\ &= A_0 \left(\omega _{n/2} ^k \right ) + \omega _ n ^ {k + n/2} \cdot A_1 \left(\omega _{n/2} ^k \right ) \\ &= A_0 \left(\omega _{n/2} ^k \right ) - \omega _ n ^ {k} \cdot A_1 \left(\omega _{n/2} ^k \right )\end{aligned}$$
于是每轮循环可以求出两个 $A$ 值,我们只需要枚举 $n/2$ 个 $k$,就可以求出 $A$ 的 DFT!
来计算时间复杂度。对于阶为 $n$ 的多项式求 $n$ 个点的 DFT,把大问题拆分成子问题的代价是 $O(n)$;递归为两个子问题,每个子问题需要求 $(n/2)$ 阶多项式的 $(n/2)$ 个点的 DFT,问题规模是 $(n/2)$;合并子问题的解得到大问题的解,代价是 $O(n)$. 于是复杂度满足递推式:$$T(n) = 2T(n/2) + O(n)$$
解得 $T(n) = O(n \log n)$. 我们实现了快速的 DFT 算法。
逆变换
现在的问题是,如何在 $O(n\log n)$ 的时间复杂度内,完成插值。我们不难注意到,以上对于多项式 $A$ 求值的过程,实际上可以视为下面的矩阵乘法:
$$\begin{bmatrix}1 & 1 & 1 & \cdots & 1 \\1 & x_1 & x_1^2 & \cdots & x_1 ^ n\\1 & x_2 & x_2^2 & \cdots & x_2 ^ n \\ \vdots & \vdots &\vdots & \ddots & \vdots \\ 1 & x_{n-1} & x_{n-1}^2 & \cdots & x_{n-1} ^ n\end{bmatrix} \cdot \begin{bmatrix}a_0 \\ a _ 1 \\ a _ 2 \\ \vdots \\ a _ n\\ \end{bmatrix} = \begin{bmatrix}y_0 \\ y _ 1 \\ y _ 2 \\ \vdots \\ y _ n\\ \end{bmatrix}$$
其中 $x _ k$ 是 $\omega _ n ^ k$. 在 IDFT 中,我们的任务是:已知向量 $\boldsymbol{Y}$,需要推出向量 $\boldsymbol{A}$. 这只需要把左边矩阵的逆矩阵乘以 $\boldsymbol{Y}$ 就能得到。注意到左边的矩阵是一个范德蒙德矩阵,由于单位复数根的性质,逆矩阵形式非常漂亮,具体过程请看下面的文章,我实在懒得打字了orz:
体现到代码里面,只需要稍微修改 fft:
可见成功复原了我们原多项式 $1+2x+3x^2 + 4 x ^ 3 + 5x ^ 4 + 6 x ^ 5$ 的系数。
多项式乘法的代码实现
现在考虑把两个多项式 $A, B$ 相乘。我们的 FFT 只能处理 2 的整次幂的情况,所以需要把 $A, B$ 的阶都扩展到 2 的某个幂,高位用 0 填充。此外,由于 $C = A\cdot B$ 有 $\deg A + \deg B$ 的阶,我们还得再把求值点的个数扩充得比它多(否则就没法恢复 $C$ 的那么多系数了)。
多项式乘法的模板题是 洛谷P3803 【模板】多项式乘法(FFT)。由于洛谷没有 SageMath,代码稍微改动了一下,仅依赖 numpy.
import numpy as np
PI = 3.1415926535
def fft(A):
n = len(A)
if n==1:
return [A[0]]
A0 = fft(A[0::2])
A1 = fft(A[1::2])
res = np.zeros(n, dtype = np.complex)
omega = np.cos(2.0*PI/n) + np.sin(2.0*PI/n)*1.0j
x = 1
for k in range(n//2):
res[k] = A0[k] + x * A1[k]
res[k+n//2] = A0[k] - x * A1[k]
x *= omega
return res
def ifft(A):
n = len(A)
if n==1:
return [A[0]]
A0 = ifft(A[0::2])
A1 = ifft(A[1::2])
res = np.zeros(n, dtype = np.complex)
omega = np.cos(2.0*PI/n) - np.sin(2.0*PI/n)*1.0j
x = 1
for k in range(n//2):
res[k] = A0[k] + x * A1[k]
res[k+n//2] = A0[k] - x * A1[k]
x *= omega
return res
def times(A, B):
lenC = len(A) + len(B) - 1
pts = 1
while pts < lenC:
pts *= 2
X = np.zeros(pts, dtype=float)
X[0:len(A)] = A
Y = np.zeros(pts, dtype=float)
Y[0:len(B)] = B
Z = ifft(fft(X) * fft(Y))[:lenC] / pts
print(' '.join([str(int(x)) for x in Z.real.round()]))
_, _ = input().split()
A = [int(x) for x in input().split()]
B = [int(x) for x in input().split()]
times(A, B)