天天看点

FFT/NTT 总结(HDU 4656)

        FFT/NTT可以说是非常有名的两个东西,从通信到数学再到计算机领域……

        然后我本人接触这个东西也就是在今年暑假的时候。其实呢,到现在,它的变化本质和算法的实现方法我还并不是很了解。但是呢,作为一个ACM选手,能够做到会熟练的运用算法模板,懂得一些变通也就差不多能够算是合格了。而且,FFT/NTT本身也没什么,就是套用模板即可。

        下面我就具体说说FFT/NTT怎么用,在什么情况下使用,以及一些注意事项。

        首先从最简单的开始。我们在初学的时候,学过一些大整数的乘法,高精度,然后显然,这个复杂度是O(N^2)的,如果数字的位数长一点的话很容易就TLE了,那么有没有什么方法能够优化呢?答案是肯定的。根据大整数乘法的运算方式,我们很容易可以发现,运算发生正好符合卷积的运算,即每一位乘以另一个数字所有位。于是可以用FFT迎刃而解,复杂度为O(NlogN)。

        可以看出,FFT/NTT说白了,就是用来计算卷积的一个快速工具。有了这个工具可以解决很多问题。但是,关键是如何把一些问题转换为卷积的形式来计算。在之前多校赛的时候就曾经出过一道题目,具体可以看我的博客。

        下面,我就HDU 4656具体的来推导一下。我还是上图吧,字丑勿喷……

FFT/NTT 总结(HDU 4656)

        在以上的推导中,由于最后的表达式比较复杂,所以我折这了一个辅助的和式P[j],首先利用一次FFT求出,然后再把这个作为一个小项带入最后表达式中。最后可以化成最后那个较为明显的卷积表达式。注意这题比较特殊,有一个巧妙的地方,那就是2kj=j^2+k^2-(j-k)^2,这样子就把2jk这个骨头给弄掉了。可以看出在刚才的推导过程中,那三个定理我都用到了。

        接下来就是注意事项了。关于FFT和NTT,这两个非常的相似,那么什么时候用FFT什么时候用NTT呢?其实很容易区分,一般来说,如果要取模的话,那么就要用NTT,否则就用FFT。然后关于NTT,它的模数可不是可以随便乱取的。根据和wh学长所说:

FFT/NTT 总结(HDU 4656)

        然后,一般来说常用的模数会给你998244353或者1004535809。那么问题来了,如果模数不是这两个,而是任意取的怎么办。这里参考网上的做法,就是利用中国剩余定理CRT,首先用刚刚所说的两个模数求出在他们剩余系下的两个解,然后再用CRT解同于方程即可。这个可以说非常的巧妙咯。具体还是见代码:

#include<bits/stdc++.h>
#define LL long long
#define N (1<<20)+10
#define Mod 1000003
#define M 100010
using namespace std;

int n,m,b,c,d,len,a[M],inv[M],fac[M];
int Pj[N],X[N],Y[N],A[N],B[N],ans[N];

namespace NTT
{
    const int m1=998244353,m2=1004535809,pwMod=Mod-1;
    const LL P=1002772198720536577LL;

    inline int CRT(int r1,int r2){
        return (((r1-r2)*334845110LL%m2*m1+r1)%P+P)%P%Mod;
    }

    int qpow(int a,int t,int P){
        int r=1;
        while(t){
            if(t&1)r=(LL)r*a%P;
            a=(LL)a*a%P;t>>=1;
        }
        return r;
    }

    int _wn[25],w1[N],w2[N],w3[N],w4[N];

    void NTT(int*A,int len,int dft,int P)
    {
        int i,j=len>>1,k,l,w,wn,c=0,u,v;
        for(i=0;i<=21;i++)_wn[i]=qpow(3,P-1>>i,P);
        for(i=1;i<len-1;i++)
        {
            if(i<j)swap(A[i],A[j]);
            for(k=len>>1;(j^=k)<k;k>>=1);
        }
        for(l=2;l<=len;l<<=1)
        {
            i=l>>1,wn=_wn[++c];
            for(j=0;j<len;j+=l){
                w=1;
                for(k=j;k<j+i;k++){
                    u=A[k],v=(LL)A[k+i]*w%P;
                    A[k]=(u+v)%P,A[k+i]=(u-v+P)%P;
                    w=(LL)w*wn%P;
                }
            }
        }
        if(dft==-1)
        {
            int inv_len=qpow(len,P-2,P);
            for(int i=0;i<len;i++)A[i]=(LL)A[i]*inv_len%P;
            for(int i=1;i<len/2;i++)swap(A[i],A[len-i]);
        }
    }

    void convol(int*A,int*B,int*R,int len)
    {
        memcpy(w1,A,len<<2);
        memcpy(w2,B,len<<2);
        NTT(w1,len,1,m1); NTT(w2,len,1,m1);
        for(int i=0;i<len;i++)w3[i]=(LL)w1[i]*w2[i]%m1;
        NTT(w3,len,-1,m1);
        memcpy(w1,A,len<<2);
        memcpy(w2,B,len<<2);
        NTT(w1,len,1,m2); NTT(w2,len,1,m2);
        for(int i=0;i<len;i++)w4[i]=(LL)w1[i]*w2[i]%m2;
        NTT(w4,len,-1,m2);
        for(int i=0;i<len;i++)R[i]=CRT(w3[i],w4[i]);
    }
}

void init()
{
    inv[0]=inv[1]=1;
    fac[0]=fac[1]=1;
    for(int i=2;i<M;i++)
    {
        fac[i]=(LL)fac[i-1]*i%Mod;
        inv[i]=Mod-(int)(Mod/i*(LL)inv[Mod%i]%Mod);                //???i???
    }
    for(int i=2;i<M;i++)
        inv[i]=(LL)inv[i-1]*inv[i]%Mod;
}

int main()
{
    init();
    while(~scanf("%d%d%d%d",&n,&b,&c,&d))
    {
        for(len=1;len<2*n;len<<=1);
        memset(X,0,len<<2); memset(Y,0,len<<2);
        for(int i=0;i<n;i++) scanf("%d",&a[i]);
        for(int i=0,y=1;i<n;i++,y=(LL)y*d%Mod)
            X[i]=(LL)a[n-i-1]*fac[n-i-1]%Mod,Y[i]=(LL)inv[i]*y%Mod;
        using namespace NTT; convol(X,Y,Pj,len);
        LL Inv=(LL)b;
        for(int i=0;i<n/2;i++) swap(Pj[i],Pj[n-i-1]);
        len<<=1; memset(A,0,len<<2); memset(B,0,len<<2);
        for(int i=0,x=1;i<n;i++,x=(LL)x*Inv%Mod)
            A[i]=(LL)x*Pj[i]%Mod*inv[i]%Mod*qpow(c,(LL)i*i%pwMod,Mod)%Mod;
        for(int i=0;i<2*n;i++)
            B[i]=qpow(c,(-(LL)(n-i)*(n-i)%pwMod+pwMod)%pwMod,Mod);
        convol(A,B,ans,len);
        for(int i=0;i<n;i++)
            printf("%d\n",(LL)ans[i+n]*qpow(c,(LL)i*i%pwMod,Mod)%Mod);
    }
    return 0;
}