前言:
重看了下fft遞歸實作,好像不難了解,以前真的太naive。然後被疊代版的各種吊打,趕緊補下,順便學下ntt。
哈爾濱真的冷,讓我這個GD蒟蒻怎麼碼代碼
FFT:
先貼連結:快速傅裡葉變換FFT的疊代實作
這篇部落格講的很清楚了,本質上是一樣的,就是将底層排好序後,自底向上一層層求。
至于為什麼這麼排:顯然
代碼留坑。
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<cmath>
#include<iostream>
using namespace std;
const double pi=acos(-);
int n,m,bin[];
complex <double>a[],b[],c[];
void fft(complex <double> *a,int n,int op)
{
for(int i=;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
for(int i=;i<n;i<<=)
{
complex <double> wn(cos(pi/i),op*sin(pi/i)),t;
for(int j=;j<n;j+=i<<)
{
complex <double>w(,);
for(int k=;k<i;k++)
{
t=w*a[i+j+k];w*=wn;
a[i+j+k]=a[j+k]-t;a[j+k]=a[j+k]+t;
}
}
}
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=;i<=n;i++) scanf("%lf",&a[i]);
for(int i=;i<=m;i++) scanf("%lf",&b[i]);
m+=n;n=;while(n<=m) n<<=;
for(int i=;i<n;i++) bin[i]=(bin[i>>]>>)|((i&)*(n>>));
fft(a,n,);fft(b,n,);
for(int i=;i<n;i++) c[i]=a[i]*b[i];
fft(c,n,-);
for(int i=;i<=m;i++) printf("%d ",(int)(c[i].real()/n+));
}
NTT:
用原根代替機關複根,可以支援 mod m o d 操作。
用NTT時要求模數比較特殊也許是我太弱
設模數為 p=x∗2N+1(N>=logn) p = x ∗ 2 N + 1 ( N >= l o g n ) 且是個質數, g g 為pp的原根。
因為 gp−1=1(mod p) g p − 1 = 1 ( m o d p ) 設 gn=gp−1n g n = g p − 1 n 代替 wn w n
顯然也是滿足 wn w n 的幾條性質的。(相消,折半等)
于是愉快的上代碼。
code:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#define LL long long
using namespace std;
LL a[],b[],c[];
int p=,g=,n,m,bin[];
LL pow(LL a,int b,int mod)
{
LL ans=;
while(b)
{
if(b&) ans=ans*a%mod;
a=a*a%mod;b>>=;
}
return ans;
}
void ntt(LL *a,int n,int op)
{
for(int i=;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
for(int i=;i<n;i<<=)
{
LL wn=pow((LL)g,op==?(p-)/(*i):p--(p-)/(*i),p),t,w;
for(int j=;j<n;j+=i<<)
{
w=;
for(int k=;k<i;k++)
{
t=w*a[i+j+k]%p;w=w*wn%p;
a[i+j+k]=(a[j+k]-t+p)%p;a[j+k]=(a[j+k]+t)%p;
}
}
}
if(op==-)
{
LL inv=pow(n,p-,p);
for(int i=;i<n;i++) a[i]=a[i]*inv%p;
}
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=;i<=n;i++) scanf("%lld",&a[i]);
for(int i=;i<=m;i++) scanf("%lld",&b[i]);
m+=n;n=;while(n<=m) n<<=;
for(int i=;i<n;i++) bin[i]=(bin[i>>]>>)|((i&)*(n>>));
ntt(a,n,);ntt(b,n,);
for(int i=;i<n;i++) c[i]=a[i]*b[i];ntt(c,n,-);
for(int i=;i<=m;i++) printf("%lld ",c[i]);
}