天天看點

Codeforces 438E - The Child and Binary Tree 多項式求逆+開根

Description

我們的小朋友很喜歡計算機科學,而且尤其喜歡二叉樹。

考慮一個含有n個互異正整數的序列c[1],c[2],…,c[n]。如果一棵帶點權的有根二叉樹滿足其所有頂點的權值都在集合{c[1],c[2],…,c[n]}中,我們的小朋友就會将其稱作神犇的。并且他認為,一棵帶點權的樹的權值,是其所有頂點權值的總和。

給出一個整數m,你能對于任意的s(1<=s<=m)計算出權值為s的神犇二叉樹的個數嗎?請參照樣例以更好的了解什麼樣的兩棵二叉樹會被視為不同的。

我們隻需要知道答案關于998244353(7*17*2^23+1,一個質數)取模後的值。

Solution

設答案序列的生成函數為 F ( x ) F(x) F(x),即 F ( x ) F(x) F(x)的第 i i i項表示權值為 i i i的答案。

而 G ( x ) G(x) G(x)的第 i i i項系數為 1 1 1,當且僅當存在某個 c = i c=i c=i,其他項系數都為 0 0 0。

那麼顯然有 F ( x ) = G ( x ) F 2 ( x ) + 1 F(x)=G(x)F^2(x)+1 F(x)=G(x)F2(x)+1, + 1 +1 +1是空樹的情況。

然後可以直接用求根公式解方程,得到 F ( x ) = 1 ± 1 − 4 G ( x ) 2 G ( x ) F(x)={{1\pm \sqrt {1-4G(x)}}\over 2G(x)} F(x)=2G(x)1±1−4G(x)

​​。

上面隻能取 − - −号,是以 F ( x ) = 2 1 + 1 − 4 G ( x ) F(x)={2\over{1+\sqrt {1-4G(x)}}} F(x)=1+1−4G(x)

​2​。

無論是符号的選取,還是分子有理化,都是為了求逆的時候 0 0 0次項不為 0 0 0。

然後直接套模闆就行了。

Code

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=200010;
const int inf=2147483647;
const int mod=998244353,gn=3,inv2=499122177;
int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x*f;
}
int Pow(int x,int y)
{
    if(!y)return 1;
    int t=Pow(x,y>>1),re=(LL)t*t%mod;
    if(y&1)re=(LL)re*x%mod;
    return re;
}
int rev[Maxn<<2],tmp[Maxn<<2],invb[Maxn<<2];
void ntt(int *a,int n,int o)
{
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1)
    {
        int wn;
        if(o==1)wn=Pow(gn,(mod-1)/(i<<1));
        else wn=Pow(gn,mod-1-(mod-1)/(i<<1));
        for(int j=0;j<n;j+=(i<<1))
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int t=(LL)a[i+j+k]*w%mod;w=(LL)w*wn%mod;
                a[i+j+k]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(o==-1)
    {
        int inv=Pow(n,mod-2);
        for(int i=0;i<n;i++)a[i]=(LL)a[i]*inv%mod;
    }
}
void Inv(int *a,int *b,int n)//%x^n
{
    if(n==1){b[0]=Pow(a[0],mod-2);return;}
    Inv(a,b,n>>1);
    for(int i=0;i<n;i++)tmp[i]=a[i],tmp[i+n]=0;
    rev[0]=0;for(int i=1;i<(n<<1);i++)rev[i]=((rev[i>>1]>>1)|((i&1)*n));
    ntt(tmp,n<<1,1),ntt(b,n<<1,1);
    for(int i=0;i<(n<<1);i++)tmp[i]=(LL)b[i]*((2-(LL)tmp[i]*b[i]%mod+mod)%mod)%mod;
    ntt(tmp,n<<1,-1);
    for(int i=0;i<n;i++)b[i]=tmp[i],b[i+n]=0;
}
void Sqrt(int *a,int *b,int n)
{
    if(n==1){b[0]=1;return;}
    Sqrt(a,b,n>>1);
    for(int i=0;i<n;i++)invb[i]=0;
    Inv(b,invb,n);
    for(int i=0;i<n;i++)tmp[i]=a[i],tmp[i+n]=0;
    rev[0]=0;for(int i=1;i<(n<<1);i++)rev[i]=((rev[i>>1]>>1)|((i&1)*n));
    ntt(tmp,n<<1,1),ntt(invb,n<<1,1);
    for(int i=0;i<(n<<1);i++)tmp[i]=(LL)tmp[i]*inv2%mod*invb[i]%mod;
    ntt(tmp,n<<1,-1);
    for(int i=0;i<n;i++)b[i]=((LL)b[i]*inv2%mod+tmp[i])%mod;
}
int n,m,a[Maxn<<2],b[Maxn<<2],c[Maxn<<2];
int main()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++)
    {
        int x=read();
        if(x<=m)a[x]=mod-4;
    }a[0]=1;
    int t=1;while(t<=(m<<1))t<<=1;
    Sqrt(a,b,t);
    b[0]++;
    Inv(b,c,t);
    for(int i=1;i<=m;i++)printf("%lld\n",(LL)c[i]*2%mod);
}
           

繼續閱讀