天天看點

FFT&NTT 多項式乘法FFT&NTT 多項式乘法

FFT&NTT 多項式乘法

文章目錄

  • FFT&NTT 多項式乘法
    • 前言
    • 前置知識
      • 多項式的表示
      • 機關根
    • 離散傅裡葉變換(DFT)
    • 快速傅裡葉變換(FFT)
    • 離散傅裡葉逆變換(IDFT)
    • 快速傅裡葉逆變換
    • FTT實作
      • 優化
    • NTT
    • 多項式乘法封裝
    • 任意模數多項式乘法

前言

FFT,快速傅裡葉變換;NTT,快速數論變換,其實是一個東西在不同的域上的不同表現形式。本部落格隻是簡單地總結一下,提一些其它部落格沒有注意的地方。

推薦學習資料:

OI Wiki-FFT

OI Wiki-NTT

傅裡葉變換(FFT)學習筆記——command block (極度推薦)

NTT與多項式全家桶——command block

前置知識

多項式的表示

  1. 系數表示。要表示一個度為 n n n 的多項式,隻要 n + 1 n+1 n+1 個數表達 x i ( 0 ≤ i ≤ n ) x^i(0\le i\le n) xi(0≤i≤n) 項的系數即可。
  2. 點值表示。隻要 n + 1 n+1 n+1 個橫坐标不同的點,也可以表示這個多項式。這是因為代入 n + 1 n+1 n+1 個點,可以得到 n + 1 n+1 n+1 個方程,把 n + 1 n+1 n+1 個系數看成未知數,就變成了 國小二年級學過的 多元一次方程組啦。值得注意的是,這裡的點的橫縱坐标不必為實數,比如我們 FFT 用到的就是橫縱坐标都為複數的點。

如何快速計算乘法?如果是系數表示,我們需要 O ( n 2 ) \mathcal O(n^2) O(n2) 的複雜度。但是點值在這方面異常優秀,隻要 O ( n ) \mathcal O(n) O(n) 即可。

機關根

我們把複平面上機關圓 n n n 等分(以 ( 1 , 0 ) (1,0) (1,0) 作為等分的第一個點),會得到 n n n 個點。把這 n n n 個點對應的複數叫做 n n n 次機關根。記作 w n j w_n^j wnj​,其中 0 ≤ j < n 0\le j<n 0≤j<n。 w n j w_n^j wnj​ 的模為 1,輻角為 j 2 π \dfrac{j}{2\pi} 2πj​。于是有

w n j = exp ⁡ ( i 2 π j n ) = cos ⁡ 2 π j n + i sin ⁡ 2 π j n w_n^j=\exp(i\frac{2\pi j}{n})=\cos\dfrac {2\pi j}{n}+i\sin\dfrac {2\pi j}n wnj​=exp(in2πj​)=cosn2πj​+isinn2πj​

機關根有優美的性質:

  1. w n j = w n j + k n , k ∈ Z w_n^j=w_n^{j+kn},k\in\Z wnj​=wnj+kn​,k∈Z
  2. w n j = w 2 n 2 j w_n^j=w_{2n}^{2j} wnj​=w2n2j​
  3. w 2 n j + n = − w 2 n j w_{2n}^{j+n}=-w_{2n}^j w2nj+n​=−w2nj​

這些性質是我們利用 FFT 快速計算的基石。

離散傅裡葉變換(DFT)

離散傅裡葉變換,就是将系數表示變為點值表示(即“求值”)。這個大家都會, O ( n 2 ) \mathcal O(n^2) O(n2) 暴力代入啊。可惜太慢了。

快速傅裡葉變換(FFT)

快速傅裡葉變換利用 機關根 的性質,分治 的方法,在 O ( n log ⁡ n ) \mathcal O(n\log n) O(nlogn) 的時間内将一個度數為 n n n 的多項式由系數表示變為點值表示。它是 DFT的更新版。

比如一個度數為 n − 1 n-1 n−1 (這裡假設 n n n 是2的整數次幂)的多項式

f ( x ) = a 0 + a 1 x + a 2 x 2 + ⋯ + a n − 1 x n − 1 f(x)=a_0+a_1x+a_2x^2+\cdots+a_{n-1}x^{n-1} f(x)=a0​+a1​x+a2​x2+⋯+an−1​xn−1

我們分一下奇偶。

f ( x ) = ( a 0 + a 2 x 2 + ⋯ + a n − 2 x n − 2 ) + ( a 1 x + a 3 x 3 + ⋯ + a n − 1 x n − 1 ) f(x)=(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+\cdots+a_{n-1}x^{n-1}) f(x)=(a0​+a2​x2+⋯+an−2​xn−2)+(a1​x+a3​x3+⋯+an−1​xn−1)

= ( a 0 + a 2 x 2 + ⋯ + a n − 2 x n − 2 ) + x ( a 1 + a 3 x 2 + ⋯ + a n − 1 x n − 2 ) =(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})+x(a_1+a_3x^2+\cdots+a_{n-1}x^{n-2}) =(a0​+a2​x2+⋯+an−2​xn−2)+x(a1​+a3​x2+⋯+an−1​xn−2)

我們記 f 1 ( x ) = a 0 + a 2 x + ⋯ + a n − 2 x n / 2 − 1 f_1(x)=a_0+a_2x+\cdots+a_{n-2}x^{n/2-1} f1​(x)=a0​+a2​x+⋯+an−2​xn/2−1, f 2 ( x ) = a 1 + a 3 x + ⋯ + a n − 1 x n / 2 − 1 f_2(x)=a_1+a_3x+\cdots+a_{n-1}x^{n/2-1} f2​(x)=a1​+a3​x+⋯+an−1​xn/2−1

則有 f ( x ) = f 1 ( x 2 ) + x f 2 ( x 2 ) f(x)=f_1(x^2)+xf_2(x^2) f(x)=f1​(x2)+xf2​(x2)

為了快速計算,我們帶入機關根。

  1. 先帶入個 w n k ( k < n / 2 ) w_n^k(k<n/2) wnk​(k<n/2)

有 f ( w n k ) = f 1 ( w n / 2 k ) + w n k f 2 ( w n / 2 k ) f(w_n^k)=f_1(w_{n/2}^{k})+w_n^kf_2(w_{n/2}^k) f(wnk​)=f1​(wn/2k​)+wnk​f2​(wn/2k​)

  1. 再帶入個 w n k + n / 2 = − w n k ( k < n / 2 ) w_n^{k+n/2}=-w_n^k(k<n/2) wnk+n/2​=−wnk​(k<n/2)

有 f ( w n k + n / 2 ) = f 1 ( w n / 2 k ) − w n k f 2 ( w n / 2 k ) f(w_n^{k+n/2})=f_1(w_{n/2}^k)-w_n^kf_2(w^k_{n/2}) f(wnk+n/2​)=f1​(wn/2k​)−wnk​f2​(wn/2k​)

我們發現,兩次代入有驚人的相似性。

于是可以分治計算了。每一層 n n n 的規模規模減半,顯然複雜度是 O ( n log ⁡ n ) O(n\log n) O(nlogn)。

代碼就不找了,随便一篇部落格都有。

離散傅裡葉逆變換(IDFT)

離散傅裡葉逆變換,就是将點值表示變為系數表示(即“插值”)。怎麼做?高斯消元是 O ( n 3 ) O(n^3) O(n3) 的。似乎不怎麼好做。

快速傅裡葉逆變換

還是請回我們的機關根吧,看看怎麼做。。。

結論:把點值( f ( w n k ) f(w_n^k) f(wnk​))當成系數,将 DFT中乘的那個 w n k w_n^k wnk​ 換成 − w n k -w_n^k −wnk​ 求點值,最後再都除以 n n n ,就是原多項式的系數啦。

證明去找别的部落格啦。

FTT實作

這裡加上了 “蝴蝶效應” 變成了疊代寫法,可以大幅度減小常數。

模闆:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
typedef double db;
char In[1 << 20], *ss = In, *tt = In;
#define getchar() (ss == tt && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) ? EOF : *ss++)
ll read() {
	ll x = 0, f = 1; char ch = getchar();
	for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
	for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + int(ch - '0');
	return x * f;
}
const int MAXN = 5e6 + 5;
const db Pi = acos(-1.0);
struct cp {db x, y;}f[MAXN], g[MAXN];
cp operator + (const cp& a, const cp& b) {return (cp){a.x+b.x, a.y+b.y};}
cp operator - (const cp& a, const cp& b) {return (cp){a.x-b.x, a.y-b.y};}
cp operator * (const cp& a, const cp& b) {return (cp){a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x};}
int n, m, d, id[MAXN];
void fft(cp* f, int fl) {
	for(int i = 0; i < d; i++) if(i < id[i]) swap(f[i], f[id[i]]);
	for(int l = 2, hl = 1; l <= d; l <<= 1, hl <<= 1) {
        //這是在枚舉哪一層,這裡的 l 就是推柿子時的 n
		cp w0 = (cp){cos(2*Pi / l), fl * sin(2*Pi / l)};
		for(int i = 0; i < d; i += l) {//i是每次疊代的段頭
			cp w = (cp){1, 0};
			for(int j = i; j < i+hl; j++, w = w * w0) {//j則是控制推柿子時的 k
				cp tt = w * f[j+hl];
				f[j+hl] = f[j] - tt;
				f[j] = f[j] + tt;
			}
		}
	}
	if(fl == -1) {//idft還得除以個 d(懶得寫數乘,就直接這樣寫了)
		for(int i = 0; i < d; i++) f[i].x /= d, f[i].y /= d;
	}
}
int main() {
	n = read(), m = read();
	for(int i = 0; i <= n; i++) f[i].x = read();
	for(int i = 0; i <= m; i++) g[i].x = read();
	for(d = 1; d <= n+m; d <<= 1);
	for(int i = 0; i <= d; i++) 
		id[i] = (id[i >> 1] >> 1) | ((i & 1) ? (d >> 1) : 0);
	fft(f, 1); fft(g, 1);
	for(int i = 0; i < d; i++) f[i] = f[i] * g[i];
	fft(f, -1);
	for(int i = 0; i <= n+m; i++) printf("%d ", int(f[i].x + 0.5));
	return 0;
}
           

請全文背誦

注意事項:

  1. 數組空間請注意,要開到 n + m n+m n+m 的至少兩倍。
  2. 請注意精度誤差,如果 f f f 和 g g g 的數量級差很多,不妨先數乘到同一數量級再做。

優化

我們可以利用一下 f , g f,g f,g 系數的虛部為零的特點,“三次變兩次”。

我們構造一個系數為複數的多項式 h ( x ) = f ( x ) + i g ( x ) h(x)=f(x)+ig(x) h(x)=f(x)+ig(x),

那麼 h 2 ( x ) = f 2 ( x ) − g 2 ( x ) + i ⋅ ( 2 f ( x ) g ( x ) ) h^2(x)=f^2(x)-g^2(x)+i\cdot(2f(x)g(x)) h2(x)=f2(x)−g2(x)+i⋅(2f(x)g(x))

于是我們隻要構造 h h h ,讓其平方即可。隻要一次DFT和一次IDFT。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
typedef double db;
char In[1 << 20], *ss = In, *tt = In;
#define getchar() (ss == tt && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) ? EOF : *ss++)
ll read() {
	ll x = 0, f = 1; char ch = getchar();
	for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
	for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + int(ch - '0');
	return x * f;
}
const int MAXN = 5e6 + 5;
const db Pi = acos(-1.0);
struct cp{db x, y;}f[MAXN];
cp operator + (const cp& a, const cp& b) {return (cp){a.x+b.x, a.y+b.y};}
cp operator - (const cp& a, const cp& b) {return (cp){a.x-b.x, a.y-b.y};}
cp operator * (const cp& a, const cp& b) {return (cp){a.x*b.x - a.y*b.y, a.x*b.y + a.y * b.x};}
int n, m, d, id[MAXN];
void fft(cp* f, int fl) {
	for(int i = 0; i < d; i++) if(i < id[i]) swap(f[i], f[id[i]]);
	for(int l = 2, hl = 1; l <= d; l <<= 1, hl <<= 1) {
		cp w0 = (cp){cos(2*Pi / l), fl * sin(2*Pi / l)};
		for(int i = 0; i < d; i += l) {
			cp w = (cp){1, 0};
			for(int j = i; j < i + hl; j++, w = w * w0) {
				cp tt = w * f[j+hl];
				f[j+hl] = f[j] - tt;
				f[j] = f[j] + tt;
			}
		}
	}
	if(fl == -1) {
		for(int i = 0; i < d; i++) f[i].x /= d, f[i].y /= d;
	}
}
int main() {
	n = read(), m = read();
	for(int i = 0; i <= n; i++) f[i].x = read();
	for(int i = 0; i <= m; i++) f[i].y = read();
	for(d = 1; d <= n+m; d <<= 1);
	for(int i = 0; i < d; i++) id[i] = (id[i >> 1] >> 1) | ((i & 1) ? (d >> 1) : 0);
	fft(f, 1);
	for(int i = 0; i < d; i++) f[i] = f[i] * f[i];
	fft(f, -1);
	for(int i = 0; i < d; i++) f[i].y /= 2;
	for(int i = 0; i <= n+m; i++) printf("%d ", int(f[i].y + 0.5));
	return 0;
}

           

NTT

我們之前都是在複數域内搞東西,但如果在模意義下,系數可能較大,題目要求取模。這時 FFT 就無用武之地了。幸運的是,我們有完美的替代品:NTT。

這裡需要用到 原根。

我們可以把 g j ( p − 1 ) / n g^{j(p-1)/n} gj(p−1)/n 當成 n n n 次機關根 w n j w_n^j wnj​。

  1. w n j = w n j + k n , k ∈ Z w_n^j=w_n^{j+kn},k\in\Z wnj​=wnj+kn​,k∈Z
  2. w n j = w 2 n 2 j w_n^j=w_{2n}^{2j} wnj​=w2n2j​
  3. w 2 n j + n = − w 2 n j w_{2n}^{j+n}=-w_{2n}^j w2nj+n​=−w2nj​

我們仍然有這些性質成立。

于是我們可以把它直接看成是機關根,FFT就變成NTT了。所有運算在模意義下完成。

常見的模數和它的原根

主要記住:

998244353,原根是 3

1004535809,原根是 3

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
char In[1 << 20], *ss = In, *tt = In;
#define getchar() (ss == tt && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) ? EOF : *ss++)
ll read() {
	ll x = 0, f = 1; char ch = getchar();
	for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
	for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + int(ch - '0');
	return x * f;
}
const int MAXN = 5e6 + 5;
const int P = 998244353, G = 3, invG = 332748118;
ll pls(ll a, ll b) {return a + b < P ? a + b : a + b - P;}
ll mns(ll a, ll b) {return a < b ? a + P - b : a - b;}
ll mul(ll a, ll b) {return a * b % P;}
int n, m, d, id[MAXN];
ll f[MAXN], g[MAXN];
ll qpow(ll a, ll n) {
	ll ret = 1;
	for(; n; n >>= 1, a = mul(a, a)) 
		if(n & 1) ret = mul(ret, a);
	return ret;
}
void NTT(ll* f, int n, int fl) {
	for(int i = 0; i < n; i++) if(i < id[i]) swap(f[i], f[id[i]]);
	for(int l = 2, hl = 1; l <= n; l <<= 1, hl <<= 1) {
		ll g0 = qpow(fl == 1 ? G : invG, (P-1) / l);
		for(int i = 0; i < n; i += l) {
			ll gn = 1;
			for(int j = i; j < i + hl; j++, gn = mul(gn, g0)) {
				ll tt = mul(f[j+hl], gn);
				f[j+hl] = mns(f[j], tt);
				f[j] = pls(f[j], tt);
			}
		}
	}
	if(fl == -1) {
		ll invn = qpow(n, P-2);
		for(int i = 0; i < n; i++) f[i] = mul(f[i], invn);
	}
}
int main() {
	n = read(); m = read();
	for(int i = 0; i <= n; i++) f[i] = read();
	for(int i = 0; i <= m; i++) g[i] = read();
	for(d = 1; d <= n+m; d <<= 1);
	for(int i = 0; i < d; i++) id[i] = (id[i >> 1] >> 1) | ((i & 1) ? (d >> 1) : 0);
	NTT(f, d, 1); NTT(g, d, 1);
	for(int i = 0; i < d; i++) f[i] = mul(f[i], g[i]);
	NTT(f, d, -1);
	for(int i = 0; i <= n+m; i++) printf("%lld ", f[i]);
	printf("\n");
	return 0;
}
           

多項式乘法封裝

#define clr(f, s, e) memset(f+(s), 0x00, sizeof(int) * ((e) - (s)))
#define cpy(f, g, n) memcpy(g, f, sizeof(int) * (n))
const int MAXN = (1 << 18) + 1, bas = 1 << 18, P = 998244353, G = 3, invG = 332748118;
int pls(int a, int b) {return a + b < P ? a + b : a + b - P;}
int mns(int a, int b) {return a < b ? a + P - b : a - b;}
int mul(int a, int b) {return 1ll * a * b % P;}
int qpow(int a, int n) {int ret = 1; for(; n; n >>= 1, a = mul(a, a)) if(n & 1) ret = mul(ret, a); return ret;}
int tf, tr[MAXN], _g[2][MAXN], inv[MAXN];
void init() {
	inv[1] = 1; for(int i = 2; i < MAXN; i++) inv[i] = mul(P - P / i, inv[P % i]);
	for(int i = 0; i < bas; i++) {
		_g[1][i] = qpow(G, (P-1) / bas * i);
		_g[0][i] = qpow(invG, (P-1) / bas * i);
	}
}
int getlim(int n) {
	int lim = 1; for(; lim < n + n; lim <<= 1);
	return lim;
}
void tpre(int lim) {
	if(lim == tf) return ;
	tf = lim; for(int i = 0; i < lim; i++) tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
}
void NTT(int* f, int lim, int fl) {
	tpre(lim); for(int i = 0; i < lim; i++) if(i < tr[i]) swap(f[i], f[tr[i]]);
	for(int l = 2, k = 1; l <= lim; l <<= 1, k <<= 1)
		for(int i = 0; i < lim; i += l)
			for(int j = i; j < i+k; j++) {
				ll tt = mul(f[j+k], _g[fl][(j-i) * (bas / l)]);
				f[j+k] = mns(f[j], tt);
				f[j] = pls(f[j], tt);
			}
	if(!fl)
		for(int i = 0; i < lim; i++) f[i] = mul(f[i], inv[lim]);
}
void Mul(int* f, int* g, int* h, int n) {
	static int a[MAXN], b[MAXN];
	int lim = getlim(n);
	cpy(f, a, n); clr(a, n, lim);
	cpy(g, b, n); clr(b, n, lim);
	NTT(a, lim, 1); NTT(b, lim, 1);
	for(int i = 0; i < lim; i++) h[i] = mul(a[i], b[i]);
	NTT(h, lim, 0); clr(h, n, lim);
}

           

任意模數多項式乘法

P4245 【模闆】任意模數多項式乘法

給 2 個多項式 F ( x ) , G ( x ) F(x),G(x) F(x),G(x),求 F ( x ) G ( x ) F(x)G(x) F(x)G(x)。系數對 p p p 取模,不保證 p p p 是 NTT 模數。

也就是MTT,使用 4 次 FFT 完成任意模數的多項式乘法。

設 K = 2 15 K=2^{15} K=215,我們把多項式每項系數分為兩部分(高低位)。

F ( x ) = K ⋅ F 1 ( x ) + F 0 ( x ) G ( x ) = K ⋅ G 1 ( x ) + G 0 ( x ) ∴ F ( x ) G ( x ) = K 2 ⋅ F 1 ( x ) G 1 ( x ) + K ⋅ [ F 1 ( x ) G 0 ( x ) + F 0 ( x ) G 1 ( x ) ] + F 0 ( x ) G 0 ( x ) F(x)=K\cdot F_1(x)+F_0(x) \\ G(x)=K\cdot G_1(x)+G_0(x) \\ \therefore F(x)G(x)=K^2\cdot F_1(x)G_1(x)+K\cdot [F_1(x)G_0(x)+F_0(x)G_1(x)]+F_0(x)G_0(x) F(x)=K⋅F1​(x)+F0​(x)G(x)=K⋅G1​(x)+G0​(x)∴F(x)G(x)=K2⋅F1​(x)G1​(x)+K⋅[F1​(x)G0​(x)+F0​(x)G1​(x)]+F0​(x)G0​(x)

如何快速得到這四個多項式的點值表示?

構造

P ( x ) = F 0 ( x ) + i G 0 ( x ) Q ( x ) = F 0 ( x ) − i G 0 ( x ) P(x)=F_0(x)+iG_0(x) \\ Q(x)=F_0(x)-iG_0(x) P(x)=F0​(x)+iG0​(x)Q(x)=F0​(x)−iG0​(x)

我們驚奇地發現:

D F T ( P ) [ j ] = P ( w n j ) = F 0 ( w n j ) + i G 0 ( w n j ) = ∑ k = 0 n − 1 F 0 [ k ] w n k j + i ∑ k = 0 n − 1 G 0 [ k ] w n k j = ∑ k = 0 n − 1 ( F 0 [ k ] + i G 0 [ k ] ) ( cos ⁡ ( 2 π k j n ) + i sin ⁡ ( 2 π k j n ) ) = ∑ k = 0 n − 1 ( F 0 [ k ] cos ⁡ ( 2 π k j n ) − G 0 [ k ] sin ⁡ ( 2 π k j n ) ) + i ∑ k = 0 n − 1 ( F 0 [ k ] sin ⁡ ( 2 π k j n ) + G 0 [ k ] sin ⁡ ( 2 π k j n ) ) \mathrm{DFT}(P)[j]=P(w_n^j)=F_0(w_n^j)+iG_0(w_n^j) \\ =\sum_{k=0}^{n-1}F_0[k]w_n^{kj}+i\sum_{k=0}^{n-1}G_0[k]w_n^{kj} \\ =\sum_{k=0}^{n-1}(F_0[k]+iG_0[k])(\cos(\dfrac {2\pi kj}{n})+i\sin(\dfrac{2\pi kj}n)) \\ =\sum_{k=0}^{n-1}(F_0[k]\cos(\dfrac{2\pi kj}n)-G_0[k]\sin(\dfrac {2\pi kj}n))+\\ i\sum_{k=0}^{n-1}(F_0[k]\sin(\dfrac {2\pi kj}n)+G_0[k]\sin(\dfrac {2\pi kj}n)) DFT(P)[j]=P(wnj​)=F0​(wnj​)+iG0​(wnj​)=k=0∑n−1​F0​[k]wnkj​+ik=0∑n−1​G0​[k]wnkj​=k=0∑n−1​(F0​[k]+iG0​[k])(cos(n2πkj​)+isin(n2πkj​))=k=0∑n−1​(F0​[k]cos(n2πkj​)−G0​[k]sin(n2πkj​))+ik=0∑n−1​(F0​[k]sin(n2πkj​)+G0​[k]sin(n2πkj​))

同理

D F T ( Q ) [ n − j ] = P ( w n − j ) = F 0 ( w n − j ) − i G 0 ( w n − j ) = ∑ k = 0 n − 1 F 0 [ k ] w n − k j − i ∑ k = 0 n − 1 G 0 [ k ] w n − k j = ∑ k = 0 n − 1 ( F 0 [ k ] − i G 0 [ k ] ) ( cos ⁡ ( 2 π k j n ) − i sin ⁡ ( 2 π k j n ) ) = ∑ k = 0 n − 1 ( F 0 [ k ] cos ⁡ ( 2 π k j n ) − G 0 [ k ] sin ⁡ ( 2 π k j n ) ) + i ∑ k = 0 n − 1 ( F 0 [ k ] sin ⁡ ( 2 π k j n ) + G 0 [ k ] sin ⁡ ( 2 π k j n ) ) \mathrm{DFT}(Q)[n-j]=P(w_n^{-j})=F_0(w_n^{-j})-iG_0(w_n^{-j}) \\ =\sum_{k=0}^{n-1}F_0[k]w_n^{-kj}-i\sum_{k=0}^{n-1}G_0[k]w_n^{-kj} \\ =\sum_{k=0}^{n-1}(F_0[k]-iG_0[k])(\cos(\dfrac {2\pi kj}{n})-i\sin(\dfrac{2\pi kj}n)) \\ =\sum_{k=0}^{n-1}(F_0[k]\cos(\dfrac{2\pi kj}n)-G_0[k]\sin(\dfrac {2\pi kj}n))+\\ i\sum_{k=0}^{n-1}(F_0[k]\sin(\dfrac {2\pi kj}n)+G_0[k]\sin(\dfrac {2\pi kj}n)) DFT(Q)[n−j]=P(wn−j​)=F0​(wn−j​)−iG0​(wn−j​)=k=0∑n−1​F0​[k]wn−kj​−ik=0∑n−1​G0​[k]wn−kj​=k=0∑n−1​(F0​[k]−iG0​[k])(cos(n2πkj​)−isin(n2πkj​))=k=0∑n−1​(F0​[k]cos(n2πkj​)−G0​[k]sin(n2πkj​))+ik=0∑n−1​(F0​[k]sin(n2πkj​)+G0​[k]sin(n2πkj​))

故 P P P 的第 j j j 項點值與 Q Q Q 的第 n − j n-j n−j 項點值共轭。

于是我們可以使用 1 次 FFT 得到 P ( x ) P(x) P(x) 和 Q ( x ) Q(x) Q(x) 的點值,再解方程就可得到 F 0 ( x ) F_0(x) F0​(x) 和 G 0 ( x ) G_0(x) G0​(x) 的點值。

同樣地可得到 F 1 ( x ) , G 1 ( x ) F_1(x),G_1(x) F1​(x),G1​(x),使用了 2 次FFT。

然後考慮怎麼求解 回系數。

構造

P ( x ) = F 1 ( x ) G 1 ( x ) + i ( F 1 ( x ) G 0 ( x ) + F 0 ( x ) G 1 ( x ) ) Q ( x ) = F 0 ( x ) G 0 ( x ) P(x)=F_1(x)G_1(x)+i(F_1(x)G_0(x)+F_0(x)G_1(x)) \\ Q(x)=F_0(x)G_0(x) P(x)=F1​(x)G1​(x)+i(F1​(x)G0​(x)+F0​(x)G1​(x))Q(x)=F0​(x)G0​(x)

做兩次 IDFT 即可。

#define clr(f, s, t) memset(f + (s), 0x00, sizeof(int) * ((t) - (s)))
#define cpy(f, g, n) memcpy(g, f, sizeof(int) * (n))
const int MAXN = (1 << 19) + 5, bas = 1 << 19;
const db PI = acos(-1.0);
int P;
int pls(int a, int b) {return a + b < P ? a + b : a + b - P;}
int mns(int a, int b) {return a < b ? a + P - b : a - b;}
int mul(int a, int b) {return 1ll * a * b % P;}
int qpow(int a, int n) {int ret = 1; for(; n; n >>= 1, a = mul(a, a)) if(n & 1) ret = mul(ret, a); return ret;}
struct cp {db x, y;};
cp operator + (const cp& a, const cp& b) {return (cp){a.x + b.x, a.y + b.y};}
cp operator - (const cp& a, const cp& b) {return (cp){a.x - b.x, a.y - b.y};}
cp operator * (const cp& a, const cp& b) {return (cp){a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};}
cp operator * (const cp& a, const db& k) {return (cp){a.x * k, a.y * k};}
const cp I = (cp){0, 1};
cp _g[2][MAXN];
int tr[MAXN], tf;
void init() {
	for(int i = 0; i < bas; i++) {
		db a = cos(2 * PI * i / bas), b = sin(2 * PI * i / bas);
		_g[1][i] = (cp){a, b};
		_g[0][i] = (cp){a, -b};
	}
}
int getlim(int n) {
	int lim = 1; for(; lim < n + n; lim <<= 1);
	return lim;
}
void tpre(int lim) {
	if(tf == lim) return;
	tf = lim; for(int i = 0; i < lim; i++) tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
}
ll tran(db x) {return ((ll)(x > 0 ? x + .5 : x - .5) % P + P) % P;}
void FFT(cp* f, int lim, int fl) {
	tpre(lim); for(int i = 0; i < lim; i++) if(i < tr[i]) swap(f[i], f[tr[i]]);
	for(int l = 2, k = 1; l <= lim; l <<= 1, k <<= 1)
		for(int i = 0; i < lim; i += l)
			for(int j = i; j < i+k; j++) {
				cp tt = f[j+k] * _g[fl][(j-i) * (bas / l)];
				f[j+k] = f[j] - tt;
				f[j] = f[j] + tt;
			}
	if(!fl) for(int i = 0; i < lim; i++) f[i].x /= lim, f[i].y /= lim;
}
void Mul(int* f, int* g, int* h, int n) {
	static cp f0[MAXN], f1[MAXN], g0[MAXN], g1[MAXN];
	int lim = getlim(n);
	for(int i = 0; i < n; i++) f0[i].x = f[i] >> 15, f0[i].y = f[i] & 32767;
	for(int i = 0; i < n; i++) g0[i].x = g[i] >> 15, g0[i].y = g[i] & 32767;
	for(int i = n; i < lim; i++) f0[i] = (cp){0, 0};
	for(int i = n; i < lim; i++) g0[i] = (cp){0, 0};
	FFT(f0, lim, 1); FFT(g0, lim, 1);
	for(int i = 0; i < lim; i++) {
		f1[i] = f0[i ? lim - i : 0], f1[i].y *= -1;
		g1[i] = g0[i ? lim - i : 0], g1[i].y *= -1;
	}
	for(int i = 0; i < lim; i++) {
		cp a = (f0[i] + f1[i]) * 0.5;		//f0
		cp b = (f1[i] - f0[i]) * 0.5 * I;	//f1
		cp c = (g0[i] + g1[i]) * 0.5;		//g0
		cp d = (g1[i] - g0[i]) * 0.5 * I;	//g1
		f0[i] = a * c + I * (a * d + b * c);
		g0[i] = b * d;
	}
	FFT(f0, lim, 0); FFT(g0, lim, 0);
	for(int i = 0; i < n; i++)
		h[i] = (1ll * tran(f0[i].x) * (1 << 30) + 1ll * tran(f0[i].y) * (1 << 15) % P + tran(g0[i].x)) % P;
	clr(h, n, lim);
}
           

繼續閱讀