天天看点

Codeforces 438E - The Child and Binary Tree 多项式求逆+开根

Description

我们的小朋友很喜欢计算机科学,而且尤其喜欢二叉树。

考虑一个含有n个互异正整数的序列c[1],c[2],…,c[n]。如果一棵带点权的有根二叉树满足其所有顶点的权值都在集合{c[1],c[2],…,c[n]}中,我们的小朋友就会将其称作神犇的。并且他认为,一棵带点权的树的权值,是其所有顶点权值的总和。

给出一个整数m,你能对于任意的s(1<=s<=m)计算出权值为s的神犇二叉树的个数吗?请参照样例以更好的理解什么样的两棵二叉树会被视为不同的。

我们只需要知道答案关于998244353(7*17*2^23+1,一个质数)取模后的值。

Solution

设答案序列的生成函数为 F ( x ) F(x) F(x),即 F ( x ) F(x) F(x)的第 i i i项表示权值为 i i i的答案。

而 G ( x ) G(x) G(x)的第 i i i项系数为 1 1 1,当且仅当存在某个 c = i c=i c=i,其他项系数都为 0 0 0。

那么显然有 F ( x ) = G ( x ) F 2 ( x ) + 1 F(x)=G(x)F^2(x)+1 F(x)=G(x)F2(x)+1, + 1 +1 +1是空树的情况。

然后可以直接用求根公式解方程,得到 F ( x ) = 1 ± 1 − 4 G ( x ) 2 G ( x ) F(x)={{1\pm \sqrt {1-4G(x)}}\over 2G(x)} F(x)=2G(x)1±1−4G(x)

​​。

上面只能取 − - −号,所以 F ( x ) = 2 1 + 1 − 4 G ( x ) F(x)={2\over{1+\sqrt {1-4G(x)}}} F(x)=1+1−4G(x)

​2​。

无论是符号的选取,还是分子有理化,都是为了求逆的时候 0 0 0次项不为 0 0 0。

然后直接套模板就行了。

Code

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=200010;
const int inf=2147483647;
const int mod=998244353,gn=3,inv2=499122177;
int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x*f;
}
int Pow(int x,int y)
{
    if(!y)return 1;
    int t=Pow(x,y>>1),re=(LL)t*t%mod;
    if(y&1)re=(LL)re*x%mod;
    return re;
}
int rev[Maxn<<2],tmp[Maxn<<2],invb[Maxn<<2];
void ntt(int *a,int n,int o)
{
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1)
    {
        int wn;
        if(o==1)wn=Pow(gn,(mod-1)/(i<<1));
        else wn=Pow(gn,mod-1-(mod-1)/(i<<1));
        for(int j=0;j<n;j+=(i<<1))
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int t=(LL)a[i+j+k]*w%mod;w=(LL)w*wn%mod;
                a[i+j+k]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(o==-1)
    {
        int inv=Pow(n,mod-2);
        for(int i=0;i<n;i++)a[i]=(LL)a[i]*inv%mod;
    }
}
void Inv(int *a,int *b,int n)//%x^n
{
    if(n==1){b[0]=Pow(a[0],mod-2);return;}
    Inv(a,b,n>>1);
    for(int i=0;i<n;i++)tmp[i]=a[i],tmp[i+n]=0;
    rev[0]=0;for(int i=1;i<(n<<1);i++)rev[i]=((rev[i>>1]>>1)|((i&1)*n));
    ntt(tmp,n<<1,1),ntt(b,n<<1,1);
    for(int i=0;i<(n<<1);i++)tmp[i]=(LL)b[i]*((2-(LL)tmp[i]*b[i]%mod+mod)%mod)%mod;
    ntt(tmp,n<<1,-1);
    for(int i=0;i<n;i++)b[i]=tmp[i],b[i+n]=0;
}
void Sqrt(int *a,int *b,int n)
{
    if(n==1){b[0]=1;return;}
    Sqrt(a,b,n>>1);
    for(int i=0;i<n;i++)invb[i]=0;
    Inv(b,invb,n);
    for(int i=0;i<n;i++)tmp[i]=a[i],tmp[i+n]=0;
    rev[0]=0;for(int i=1;i<(n<<1);i++)rev[i]=((rev[i>>1]>>1)|((i&1)*n));
    ntt(tmp,n<<1,1),ntt(invb,n<<1,1);
    for(int i=0;i<(n<<1);i++)tmp[i]=(LL)tmp[i]*inv2%mod*invb[i]%mod;
    ntt(tmp,n<<1,-1);
    for(int i=0;i<n;i++)b[i]=((LL)b[i]*inv2%mod+tmp[i])%mod;
}
int n,m,a[Maxn<<2],b[Maxn<<2],c[Maxn<<2];
int main()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++)
    {
        int x=read();
        if(x<=m)a[x]=mod-4;
    }a[0]=1;
    int t=1;while(t<=(m<<1))t<<=1;
    Sqrt(a,b,t);
    b[0]++;
    Inv(b,c,t);
    for(int i=1;i<=m;i++)printf("%lld\n",(LL)c[i]*2%mod);
}
           

继续阅读