天天看點

fft的疊代實作與ntt模闆

前言:

重看了下fft遞歸實作,好像不難了解,以前真的太naive。然後被疊代版的各種吊打,趕緊補下,順便學下ntt。

哈爾濱真的冷,讓我這個GD蒟蒻怎麼碼代碼

FFT:

先貼連結:快速傅裡葉變換FFT的疊代實作

這篇部落格講的很清楚了,本質上是一樣的,就是将底層排好序後,自底向上一層層求。

至于為什麼這麼排:顯然

代碼留坑。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<cmath>
#include<iostream>
using namespace std;
const double pi=acos(-);
int n,m,bin[];
complex <double>a[],b[],c[];
void fft(complex <double> *a,int n,int op)
{
    for(int i=;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
    for(int i=;i<n;i<<=)
    {
        complex <double> wn(cos(pi/i),op*sin(pi/i)),t;
        for(int j=;j<n;j+=i<<)
        {
            complex <double>w(,);
            for(int k=;k<i;k++)
            {
                t=w*a[i+j+k];w*=wn;
                a[i+j+k]=a[j+k]-t;a[j+k]=a[j+k]+t;
            }
        }
    }
}
int main()
{
    scanf("%d %d",&n,&m);
    for(int i=;i<=n;i++) scanf("%lf",&a[i]);
    for(int i=;i<=m;i++) scanf("%lf",&b[i]);
    m+=n;n=;while(n<=m) n<<=;
    for(int i=;i<n;i++) bin[i]=(bin[i>>]>>)|((i&)*(n>>));
    fft(a,n,);fft(b,n,);
    for(int i=;i<n;i++) c[i]=a[i]*b[i];
    fft(c,n,-);
    for(int i=;i<=m;i++) printf("%d ",(int)(c[i].real()/n+));
}
           

NTT:

用原根代替機關複根,可以支援 mod m o d 操作。

用NTT時要求模數比較特殊也許是我太弱

設模數為 p=x∗2N+1(N>=logn) p = x ∗ 2 N + 1 ( N >= l o g n ) 且是個質數, g g 為pp的原根。

因為 gp−1=1(mod p) g p − 1 = 1 ( m o d   p ) 設 gn=gp−1n g n = g p − 1 n 代替 wn w n

顯然也是滿足 wn w n 的幾條性質的。(相消,折半等)

于是愉快的上代碼。

code:

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#define LL long long
using namespace std;
LL a[],b[],c[];
int p=,g=,n,m,bin[];
LL pow(LL a,int b,int mod)
{
    LL ans=;
    while(b)
    {
        if(b&) ans=ans*a%mod;
        a=a*a%mod;b>>=;
    }
    return ans;
}
void ntt(LL *a,int n,int op)
{
    for(int i=;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
    for(int i=;i<n;i<<=)
    {
        LL wn=pow((LL)g,op==?(p-)/(*i):p--(p-)/(*i),p),t,w;
        for(int j=;j<n;j+=i<<)
        {
            w=;
            for(int k=;k<i;k++)
            {
                t=w*a[i+j+k]%p;w=w*wn%p;
                a[i+j+k]=(a[j+k]-t+p)%p;a[j+k]=(a[j+k]+t)%p;
            }
        }
    }
    if(op==-)
    {
        LL inv=pow(n,p-,p);
        for(int i=;i<n;i++) a[i]=a[i]*inv%p;
    }
}
int main()
{
    scanf("%d %d",&n,&m);
    for(int i=;i<=n;i++) scanf("%lld",&a[i]);
    for(int i=;i<=m;i++) scanf("%lld",&b[i]);
    m+=n;n=;while(n<=m) n<<=;
    for(int i=;i<n;i++) bin[i]=(bin[i>>]>>)|((i&)*(n>>));
    ntt(a,n,);ntt(b,n,);
    for(int i=;i<n;i++) c[i]=a[i]*b[i];ntt(c,n,-);
    for(int i=;i<=m;i++) printf("%lld ",c[i]);
}