
Study Notes on Arbitrary-Modulus NTT

Problem

Compute the product of two polynomials $A(x)$ and $B(x)$, with every coefficient of the result taken modulo a number that is not an NTT-friendly modulus.

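Split-Coefficient FFT

Choose a threshold $W$ (usually $2^{15}$) and split $A(x)$ and $B(x)$ as $A=aW+b$ and $B=cW+d$, where $a,b,c,d$ are polynomials with small coefficients (the quotients and remainders of the original coefficients on division by $W$). Then

$$AB=acW^2+(ad+bc)W+bd$$

Done directly, this takes $7$ DFTs: four forward transforms (of $a,b,c,d$) and three inverse transforms (of $ac$, $ad+bc$ and $bd$).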
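The decomposition can be checked without any FFT at all. The following is a minimal sketch, assuming non-negative coefficients below $2^{30}$; it replaces the DFTs with a naive quadratic convolution, and the names `convolve` and `split_multiply` are illustrative, not taken from the code at the end of this post.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;

// Naive O(n*m) product of two integer polynomials, without any modulus.
// Safe here because the inputs are the small parts produced by the split.
static std::vector<i64> convolve(const std::vector<i64>& x, const std::vector<i64>& y) {
    std::vector<i64> z(x.size() + y.size() - 1, 0);
    for (std::size_t i = 0; i < x.size(); ++i)
        for (std::size_t j = 0; j < y.size(); ++j)
            z[i + j] += x[i] * y[j];
    return z;
}

// Multiply A and B modulo p by splitting every coefficient at W = 2^15.
std::vector<i64> split_multiply(const std::vector<i64>& A, const std::vector<i64>& B, i64 p) {
    const i64 W = 1 << 15;
    assert(!A.empty() && !B.empty() && p > 0);
    std::vector<i64> a, b, c, d;
    for (i64 x : A) { a.push_back(x / W); b.push_back(x % W); }  // A = a*W + b
    for (i64 x : B) { c.push_back(x / W); d.push_back(x % W); }  // B = c*W + d
    std::vector<i64> ac = convolve(a, c), ad = convolve(a, d);
    std::vector<i64> bc = convolve(b, c), bd = convolve(b, d);
    std::vector<i64> res(ac.size());
    for (std::size_t i = 0; i < res.size(); ++i) {
        i64 hi = ac[i] % p, mid = (ad[i] + bc[i]) % p, lo = bd[i] % p;
        // AB = ac*W^2 + (ad + bc)*W + bd, reduced modulo p after every multiplication.
        res[i] = (hi * W % p * W % p + mid * W % p + lo) % p;
    }
    return res;
}
```

For example, `split_multiply({1, 2}, {3, 4}, 1000)` returns the coefficients of $(1+2x)(3+4x)=3+10x+8x^2$. In the FFT version each call to `convolve` is replaced by pointwise multiplication of transformed sequences.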

Reducing the Number of DFTs in an FFT Multiplication to Two

Let $A(x)$ and $B(x)$ be two polynomials with real coefficients whose product we want to compute. Define

$$F(x)=A(x)+iB(x),\quad G(x)=A(x)-iB(x)$$

Let $\omega$ be a primitive $n$-th root of unity, and write $F_{DFT}(k)=F(\omega^k)$ and $G_{DFT}(k)=G(\omega^k)$. With $a_j$ and $b_j$ denoting the coefficients of $A$ and $B$, we clearly have

$$F_{DFT}(k)=A(\omega^k)+iB(\omega^k)=\sum_{j=0}^{n-1}(a_j+ib_j)\omega^{jk}$$

Now consider how to obtain the point values of $G$ from those of $F$. Let $\operatorname{conj}(w)$ denote the complex conjugate of $w$. Then

$$
\begin{aligned}
G_{DFT}(k)&=\sum_{j=0}^{n-1}(a_j-ib_j)\omega^{jk}\\
&=\sum_{j=0}^{n-1}(a_j-ib_j)\Big(\cos\frac{2\pi jk}{n}+i\sin\frac{2\pi jk}{n}\Big)\\
&=\sum_{j=0}^{n-1}\Big[\Big(a_j\cos\frac{2\pi jk}{n}+b_j\sin\frac{2\pi jk}{n}\Big)+i\Big(a_j\sin\frac{2\pi jk}{n}-b_j\cos\frac{2\pi jk}{n}\Big)\Big]\\
&=\operatorname{conj}\Big(\sum_{j=0}^{n-1}\Big[\Big(a_j\cos\frac{-2\pi jk}{n}-b_j\sin\frac{-2\pi jk}{n}\Big)+i\Big(a_j\sin\frac{-2\pi jk}{n}+b_j\cos\frac{-2\pi jk}{n}\Big)\Big]\Big)\\
&=\operatorname{conj}\Big(\sum_{j=0}^{n-1}(a_j+ib_j)\Big(\cos\frac{-2\pi jk}{n}+i\sin\frac{-2\pi jk}{n}\Big)\Big)\\
&=\operatorname{conj}\Big(\sum_{j=0}^{n-1}(a_j+ib_j)\omega^{-jk}\Big)\\
&=\operatorname{conj}\Big(\sum_{j=0}^{n-1}(a_j+ib_j)\omega^{j(n-k)}\Big)\\
&=\operatorname{conj}\Big(F_{DFT}(n-k)\Big)
\end{aligned}
$$

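Once the point values of $F$ and $G$ are known, the point values of $A$ and $B$ follow from $A=(F+G)/2$ and $B=(F-G)/(2i)$. Multiplying two real polynomials therefore needs only two DFTs (one forward, one inverse) instead of three.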
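The identity $G_{DFT}(k)=\operatorname{conj}(F_{DFT}(n-k))$ can be tried out on small inputs. Below is a minimal sketch that uses a naive $O(n^2)$ DFT in place of a real FFT so that it stays short; `naive_dft` and `two_real_dfts` are hypothetical helper names introduced only for this illustration.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Naive O(n^2) DFT: out[k] = sum_j in[j] * omega^(j*k), omega = e^(2*pi*i/n).
std::vector<cd> naive_dft(const std::vector<cd>& in) {
    const std::size_t n = in.size();
    const double two_pi = 2.0 * std::acos(-1.0);
    std::vector<cd> out(n);  // value-initialized to zero
    for (std::size_t k = 0; k < n; ++k)
        for (std::size_t j = 0; j < n; ++j)
            out[k] += in[j] * std::polar(1.0, two_pi * double(j * k % n) / double(n));
    return out;
}

// Point values of two real sequences A and B from a single transform of F = A + iB.
void two_real_dfts(const std::vector<double>& A, const std::vector<double>& B,
                   std::vector<cd>& dftA, std::vector<cd>& dftB) {
    assert(A.size() == B.size() && !A.empty());
    const std::size_t n = A.size();
    std::vector<cd> F(n);
    for (std::size_t j = 0; j < n; ++j) F[j] = cd(A[j], B[j]);
    F = naive_dft(F);
    dftA.assign(n, cd());
    dftB.assign(n, cd());
    for (std::size_t k = 0; k < n; ++k) {
        cd G = std::conj(F[(n - k) % n]);     // G_DFT(k) = conj(F_DFT(n - k))
        dftA[k] = (F[k] + G) / 2.0;           // A = (F + G) / 2
        dftB[k] = (F[k] - G) / cd(0.0, 2.0);  // B = (F - G) / (2i)
    }
}
```

The index `(n - k) % n` handles $k=0$, where $n-k$ wraps around to $0$. Up to floating-point error, the results should match transforming $A$ and $B$ separately.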

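Optimizing Split-Coefficient FFT

The derivation above shows that the point values of a polynomial determine the point values of its conjugate polynomial. We now apply this optimization to split-coefficient FFT.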

Note that

$$(a+bi)(c+di)=(ac-bd)+i(ad+bc)$$

$$(a-bi)(c+di)=(ac+bd)+i(ad-bc)$$

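First compute the point values of $(a+bi)$ and $(c+di)$. Since $(a+bi)$ and $(a-bi)$ are conjugates of each other, the point values of $(a-bi)$ can be derived from those of $(a+bi)$. Form the two products above pointwise and apply the IDFT to each: the real parts give $ac-bd$ and $ac+bd$, from which $ac$ and $bd$ follow, and the imaginary part of the first product gives $ad+bc$. This takes four DFTs in total (two forward, two inverse).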
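The recovery step can be written down on its own. This is a minimal sketch, assuming the two products are already available in coefficient form as complex numbers whose real and imaginary parts are close to integers; the `Parts` struct and the `recover` function are illustrative names, not part of the code below.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <vector>

using cd = std::complex<double>;
using i64 = std::int64_t;

struct Parts { std::vector<i64> ac, bd, adbc; };

// Given the coefficients of P = (a + bi)(c + di) and Q = (a - bi)(c + di),
// recover ac, bd and ad + bc by rounding to the nearest integer.
Parts recover(const std::vector<cd>& P, const std::vector<cd>& Q) {
    assert(P.size() == Q.size());
    Parts r;
    for (std::size_t i = 0; i < P.size(); ++i) {
        i64 re_p = std::llround(P[i].real());  // ac - bd
        i64 re_q = std::llround(Q[i].real());  // ac + bd
        i64 im_p = std::llround(P[i].imag());  // ad + bc
        r.ac.push_back((re_p + re_q) / 2);
        r.bd.push_back((re_q - re_p) / 2);
        r.adbc.push_back(im_p);
    }
    return r;
}
```

With the scalars $a=3,b=5,c=7,d=2$ we get $P=11+41i$ and $Q=31-29i$, so the function yields $ac=21$, $bd=10$ and $ad+bc=41$. The imaginary part of $Q$ is not needed.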

Code

#include<bits/stdc++.h>
using namespace std;

typedef long long LL;

const int W = 32768;
const int N = 400005;
const long double pi = acos(-1.0L); // long double argument, so pi is not truncated to double precision

int n, m, p, rev[N], L, ans[N];

struct com
{
    long double x,y;
    com operator + (const com & d) const {return (com){x + d.x, y + d.y};}
    com operator - (const com & d) const {return (com){x - d.x, y - d.y};}
    com operator * (const com & d) const {return (com){x * d.x - y * d.y, x * d.y + y * d.x};}
    com operator / (const long double & d) const {return (com){x / d, y / d};}
    com operator ~ () const {return (com){x, -y};} // complex conjugate
}A[N], B[N], C[N];

void pre()
{
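	int lg = 0;
	// L becomes the smallest power of two strictly greater than n + m, the degree of the product.
	for (L = 1; L <= n + m; L <<= 1, lg++);
	// Bit-reversal permutation; starting at 1 avoids a shift by -1 when lg == 0 (rev[0] is already 0).
	for (int i = 1; i < L; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));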
}

LL num(long double x)
{
	return x < 0 ? (LL)(x - 0.5) : (LL)(x + 0.5);
}

// In-place iterative FFT of length L; f = 1 is the forward transform, f = -1 the inverse.
void fft(com * a, int f)
{
	for (int i = 0; i < L; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int i = 1; i < L; i <<= 1)
	{
		com wn = (com){cos(pi / i), f * sin(pi / i)};
		for (int j = 0; j < L; j += (i << 1))
		{
			com w = (com){1, 0};
			for (int k = 0; k < i; k++)
			{
				com u = a[j + k], v = a[j + k + i] * w;
				a[j + k] = u + v; a[j + k + i] = u - v;
				w = w * wn;
			}
		}
	}
	if (f == -1) for (int i = 0; i < L; i++) a[i] = a[i] / L;
}

int main()
{
	scanf("%d%d%d", &n, &m, &p);
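	for (int i = 0; i <= n; i++)
	{
		int x; scanf("%d", &x);
		// Store a + bi with a = x / W (high part) and b = x % W (low part).
		A[i] = (com){(long double)(x / W), (long double)(x % W)};
	}
	for (int i = 0; i <= m; i++)
	{
		int x; scanf("%d", &x);
		// Store c + di in the same way.
		C[i] = (com){(long double)(x / W), (long double)(x % W)};
	}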
	pre();
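	fft(A, 1); fft(C, 1);
	// B holds the point values of (a - bi), derived from those of (a + bi) by conjugation.
	for (int i = 0; i < L; i++) B[i] = ~A[(L - i) % L];
	// A becomes (a + bi)(c + di) and B becomes (a - bi)(c + di), both in point-value form.
	for (int i = 0; i < L; i++) A[i] = A[i] * C[i], B[i] = B[i] * C[i];
	fft(A, -1); fft(B, -1);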
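	for (int i = 0; i <= n + m; i++)
	{
		// A = (ac - bd) + i(ad + bc), B = (ac + bd) + i(ad - bc)
		LL x1 = num(A[i].x), x2 = num(B[i].x), y1 = num(A[i].y);
		LL ac = (x1 + x2) / 2, bd = x2 - ac, bcad = y1;
		// Reduce bcad modulo p before multiplying by W so the product stays far from the LL limit.
		ans[i] = (ac % p * W % p * W % p + bcad % p * W % p + bd % p) % p;
	}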
	for (int i = 0; i <= n + m; i++) printf("%d ", ans[i]);
	return 0;
}