天天看點

[HEOI2018] 秘密襲擊coat

Description

給定一棵 \(n\) 個點的樹,每個點有點權 \(d_i\) ,請對于樹上所有大于等于 \(k\) 個點的聯通塊,求出聯通塊中第 \(k\) 大的點權之和。\(n\le 1666,d_i\leq 1666\)。對 \(64123\) 取模。

Sol

先轉化一下題目:

如果權值為 \(i\) 的點在某個聯通塊中是第 \(k\) 大,那麼它對該聯通塊的貢獻就是 \(i\),不妨對于 \(1\sim i\),統計一下有多少聯通塊的第 \(k\) 大是 \(\ge i\) 的,發現對于每個聯通塊,設它的第 \(k\) 大為 \(j\),那麼這個 \(j\) 被統計了 \(j\) 遍,驚奇的發現這就是它本應有的貢獻。這就等價于 求有多少個聯通塊,使得 \(\ge i\) 的數有 \(\ge k\) 個。

是以問題變成了,對于每個權值 \(i\),求有多少聯通塊有 \(\ge k\) 個 \(\ge i\) 的數。

可以設計一個\(\text{DP}\),\(f[i][j][k]\) 表示點 \(i\) 為根的子樹中,有 \(k\) 個 \(\ge j\) 的包含點 \(i\) 的聯通塊數。

那麼轉移就是

\[

\begin{align}

f[i][j][k] &= \prod_{v \in son[i]} (f[v][j][k']+1) \ \ \ \ (d[i]<j,\sum k'=k)\\

f[i][j][k] &= \prod_{v \in son[i]} (f[v][j][k']+1) \ \ \ \ (d[i] \geqslant j,\sum k'=k-1)

\end{align}

\]

樹上背包可以優化為 \(O(n^3)\)。卡卡常即可通過。

考慮生成函數。

外層先枚舉一個權值 \(j\) ,設 \(F_a(x)=\sum\limits_{i=0}^n f[a][j][i]\cdot x^i\) ,這是一個 \(n\) 次多項式。

求答案時可以令 \(G_a(x)\) 表示 \(a\) 子樹中所有點的 \(F(x)\) 的和,然後求 \(G_{root}(x)\) 的第 \(k,\dots,n\) 項和。

那麼轉移就變成了

\[

F_a(x)=\left(\prod_{b\in son_a}(1+F_b(x))\right)*\begin{cases}1 & d_a \ge i\\x&d_a\lt i\end{cases}

\]

如果我們給定 \(x_0\),求出 \(F_a(x_0),G_a(x_0)\),那麼可以直接 \(O(n)\;\text{DP}\) 。

那我們可以枚舉 \(x_0=1\sim n+1\),每次\(\text{DP}\)一下求出點值,最後再用拉格朗日插值求答案不就好了。然而這樣并沒有跑的更快。

算一下現在的複雜度,外層枚舉權值 \(O(n)\),枚舉 \(x_0\) \(O(n)\),樹形\(\text{DP}\;O(n)\),總複雜度 \(O(n^3)\)。

寫一下僞代碼:

DP(now,i,x0)
    (f,g)=(1,0)
    for to in son(now)
        (f0,g0)=DP(to,i,x0)
        (f,g)=(f*(1+f0),g+g0)
    if(d[now]>=i)
        (f,g)=(f*x0,g)
    (f,g)=(f,g+f)
    return (f,g)           

可不可以把枚舉權值這個複雜度優化一下呢?

有個非常牛逼的科技叫整體\(\text{DP}\)。大概意思就是把許多次\(\text{DP}\)放在一起做。

一般用線段樹維護每個詢問,線段樹的第 \(i\) 個葉子結點存儲的值就是第 \(i\) 個詢問的答案,在合并的時候使用線段樹合并,更新一些\(\text{DP}\)值。

當然如果每個節點的線段樹都是一顆滿線段樹的話複雜度顯然不對,是以有個優化:如果線段樹上一個節點 \(x\) 的子樹中的所有詢問的答案都一樣,那麼隻需要保留 \(x\) 這個節點即可。

那回到這道題,就可以把枚舉 \(W\) 個權值當做 \(W\) 次詢問,然後就能用整體\(\text{DP}\)維護了。

具體來說,現在外層隻需要枚舉 \(x_0\),那麼在樹形\(\text{DP}\)到點 \(i\) 的時候,點 \(i\) 的線段樹中第 \(j\) 個葉子結點存儲的值 \(v_1,v_2\) 的含義就是,當 \(x=x_0\) 時,\(F_i(x)\) 的點值為 \(v_1\),\(G_i(x)\) 的點值為 \(v_2\)。

那我們看一下僞代碼中的每個操作都對應着線段樹的什麼操作:

  • (f,g)=(1,0)

    整體指派
  • (f,g)=(f*(1+f0),g+g0)

    對應項合并
  • if d[a]>=i / (f,g)=(f*x0,g)

    給 \(1\sim d[a]\) 項整體打标記
  • (f,g)=(f,g+f)

    整體打标記

是以問題就變成了,如何線上段樹上維護好标記。

考慮我們需要做什麼:

  1. 維護

    (f,g)

  2. f

    整體加 \(1\)
  3. f

    乘上

    f0

  4. f

    加到

    g

因為對應項相乘并不好做,是以我們考慮定義一個類似于矩陣乘法一樣的變換,\((a,b,c,d)\) 表示目前節點維護的

(f,g)=(a+b*f,g+c+d*f)

為什麼這麼定義大概是xjb湊出來的?

然後變換的乘法就可以根據定義輕松推出來了懶得寫了

機關變換就是 \((0,1,0,0)\),任何變換乘上該變換還為本身。

而每個點的

(f,g)

實際上就是維護出來的 \((b,d)\)。\((a,c)\) 的存在大概是維護标記的需要?

于是有了這個就能求出 \(1\sim n+1\) 的點值來了。

最後一步,就是插值,求出原多項式的系數了。

這一步可以多項式快速插值實作,但是太難寫,而且複雜度瓶頸不在這裡。直接拉格朗日插值就好。

然而怎麼求出來每項的系數呢?

拉格朗日的式子長這樣:\(\sum\limits_{i=1}^{n+1} y_i\left(\prod_{j\ne i}\frac{(x-x_j)}{(x_i-x_j)} \right)\)

觀察到分子之間的差别很小,可以提前背包算出來 \(\prod_j (x-x_j)\),轉移是枚舉目前選前面的 \(x\) 還是後邊的 \(-x_j\),\(f[i]=f[i-1]-f[i]*x_j\),分母可以預處理逆元求出來。

那現在問題就隻剩下了,分子多乘了一個 \(x-x_i\),我們要把這個退背包回去。

實際上也很簡單,因為 \(f[j]=f'[j-1]-f'[j]*x_i\),我們實際上要求的是 \(f'[j]\),那把 \(j\) 從小到大枚舉,然後移項一下就行了。

那求出來每項的系數之後,第 \(k\sim n\) 項的和就是答案了。

Code

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned int ui;
const int N=1670;
const int M=N*200;
const ui mod=64123;
#define ls ch[x][0]
#define rs ch[x][1]

int head[N],tot,dp[M],d[N];
ui ans[N],f[N],in[N],a[N],b[N];
int n,k,W,cnt,dc,rt[N],ch[M][2];

ui inc(ui x,ui y){
    return x+y>=mod?x+y-mod:x+y;
}

struct Node{
    ui a,b,c,d;
    Node(){}
    Node(ui aa,ui bb,ui cc,ui dd){a=aa,b=bb,c=cc,d=dd;}
    friend Node operator+(Node x,Node y){
        return Node(inc(y.a,x.a*y.b%mod),x.b*y.b%mod,inc(x.c+y.c,x.a*y.d%mod),inc(x.d,x.b*y.d%mod));
    }
}sum[M];

struct Edge{
    int to,nxt;
}edge[N<<1];

void add(int x,int y){
    edge[++cnt].to=y;
    edge[cnt].nxt=head[x];
    head[x]=cnt;
}

int newnode(){
    int t=dc?dp[dc--]:++tot;
    return sum[t]=Node(0,1,0,0),t;
}

void del(int x){
    if(!x) return;
    dp[++dc]=x;
    del(ls),del(rs);
    ls=rs=0;
}

void pushdown(int x){
    if(!ls) ls=newnode();
    if(!rs) rs=newnode();
    sum[ls]=sum[ls]+sum[x];
    sum[rs]=sum[rs]+sum[x];
    sum[x]=Node(0,1,0,0);
}

void merge(int &x,int &y){
    if(!ch[x][0] and !ch[x][1]) 
        swap(x,y);
    if(!ch[y][0] and !ch[y][1])
         return sum[x]=sum[x]+Node(0,sum[y].a,sum[y].c,0),void();
    pushdown(x),pushdown(y);
    merge(ch[x][0],ch[y][0]),merge(ch[x][1],ch[y][1]);
}

void modify(int x,int l,int r,int ql,int qr,Node p){
    if(ql<=l and r<=qr) return sum[x]=sum[x]+p,void();
    int mid=l+r>>1; pushdown(x);
    if(ql<=mid) modify(ls,l,mid,ql,qr,p);
    if(mid<qr) modify(rs,mid+1,r,ql,qr,p);
}

void dfs(int now,int x0,int fa=0){
    rt[now]=newnode();sum[rt[now]]=Node(1,0,0,0);
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(to==fa) continue;
        dfs(to,x0,now);
        merge(rt[now],rt[to]);
        del(rt[to]);
    }
    modify(rt[now],1,W,1,d[now],Node(0,x0,0,0));
    modify(rt[now],1,W,1,W,Node(0,1,0,1));
    modify(rt[now],1,W,1,W,Node(1,1,0,0));
}

ui query(int x,int l,int r){
    if(l==r) return sum[x].c;
    int mid=l+r>>1; pushdown(x);
    return (query(ls,l,mid)+query(rs,mid+1,r))%mod;
}

ui Lagrange(){
    in[1]=1; f[0]=1;
    for(int i=2;i<=n+1;i++)
        in[i]=(mod-mod/i)*in[mod%i]%mod;
    for(int i=1;i<=n+1;i++)
        for(int j=n+1;~j;j--){
            if(j) f[j]=inc(f[j-1],f[j]*(mod-i)%mod);
            else f[j]=f[j]*(mod-i)%mod;
        }
    for(int i=1;i<=n+1;i++){
        ui res=ans[i]; memcpy(b,f,sizeof 4*(n+2));
        for(int j=1;j<=n+1;j++){
            if(i==j) continue;
            if(i>j) res=res*in[i-j]%mod;
            else res=res*(mod-in[j-i])%mod;
        }
        for(int j=0;j<=n+1;j++){
            if(!j) b[j]=mod-b[j]*in[i]%mod;
            else b[j]=inc(mod,b[j-1]-b[j])*in[i]%mod;
        }
        for(int j=0;j<=n;j++)
            a[j]=inc(a[j],b[j]*res%mod);
    } ui ans=0;
    for(int i=k;i<=n;i++) ans=inc(ans,a[i]);
    return ans;
}

signed main(){
    scanf("%d%d%d",&n,&k,&W);
    for(int i=1;i<=n;i++) 
        scanf("%d",&d[i]);
    for(int x,y,i=1;i<n;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    for(int i=1;i<=n+1;i++){
        dfs(1,i);
        ans[i]=query(rt[1],1,W);
        del(rt[1]);
    } printf("%u\n",Lagrange()); return 0;
}