天天看點

洛谷P4719 【模闆】動态dp 矩陣乘法+樹鍊剖分+線段樹題目描述題目分析代碼

題目描述

給定一棵 n n 個點的樹,點帶點權。

有 mm 次操作,每次操作給定 x,y x , y ,表示修改點 x x 的權值為 yy 。

你需要在每次操作之後求出這棵樹的最大權獨立集的權值大小。

題目分析

假如沒有修改操作,這題怎麼做呢?設 ax a x 為 x x 的點權,f(x,0/1)f(x,0/1)表示 x x 這個點不選/選的情況下,其子樹中的最大權獨立集權值大小。那麼就有:

f(x,0)=∑max(f(son,0),f(son,1))f(x,0)=∑max(f(son,0),f(son,1))

f(x,1)=ax+∑f(son,0) f ( x , 1 ) = a x + ∑ f ( s o n , 0 )

現在有了修改操作,就很頭疼。一想到動态修改,就想到以線段樹為代表的資料結構,一想到在樹上,就想到樹鍊剖分。既然是動态DP,那麼有一個比較“動感”的東西可以處理DP——

“矩陣乘法”!

為什麼說它比較動感呢,是因為将一個DP寫成矩乘的形式後,矩乘又是有結合率的,是以可以先算前一半再算後一半再合起來什麼的。

考慮将矩乘中的乘法改成加法,加法改成取max操作。

matrix operator * (matrix a,matrix b) {
    matrix c;
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    return c;
}
           

設 g(x,0/1) g ( x , 0 / 1 ) 表示對于一個點 x x ,在不選/選它的情況下,其輕兒子對其造成的貢獻。又因為重兒子的dfs序是xx的dfs序+1,設 x x 的dfs序為1,則有:

[f(i,0)f(i,1)]=[g(i,0)g(i,1)g(i,0)0][f(i+1,0)f(i+1,1)][f(i,0)f(i,1)]=[g(i,0)g(i,0)g(i,1)0][f(i+1,0)f(i+1,1)]

然後每一次修改操作,我們修改若幹條重鍊,每條重鍊修改一個點的轉移矩陣即可。

一條重鍊頂端的答案,直接查詢這條鍊對應的dfs序區間的轉移矩陣全部相乘的結果。這樣也可以獲得1的答案。

代碼

#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
    int q=,w=;char ch=' ';
    while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch=='-') w=-,ch=getchar();
    while(ch>='0'&&ch<='9') q=q*+ch-'0',ch=getchar();
    return q*w;
}
typedef long long LL;
const int N=;
int n,m,tot,tim;
int h[N],ne[N<<],to[N<<],dep[N],fa[N],sz[N];
int pos[N],repos[N],ed[N],top[N],son[N];
LL a[N],f[N][];
void add(int x,int y) {to[++tot]=y,ne[tot]=h[x],h[x]=tot;}
void dfs1(int x,int las) {
    fa[x]=las,dep[x]=dep[las]+,sz[x]=;
    for(RI i=h[x];i;i=ne[i])
        if(to[i]!=las) dfs1(to[i],x),sz[x]+=sz[to[i]];
}
void dfs2(int x,int las) {
    int bj=,mx=;
    pos[x]=++tim,repos[tim]=x;
    for(RI i=h[x];i;i=ne[i])
        if(to[i]!=las&&sz[to[i]]>mx) mx=sz[to[i]],bj=to[i];
    if(!bj) {ed[top[x]]=pos[x];return;}
    son[x]=bj,top[bj]=top[x],dfs2(bj,x);
    for(RI i=h[x];i;i=ne[i])
        if(to[i]!=las&&to[i]!=bj) top[to[i]]=to[i],dfs2(to[i],x);
}
void dp(int x,int las) {
    f[x][]=a[x];
    for(RI i=h[x];i;i=ne[i]) {
        if(to[i]==las) continue;
        dp(to[i],x);
        f[x][]+=max(f[to[i]][],f[to[i]][]);
        f[x][]+=f[to[i]][];
    }
}

struct matrix{LL t[][];}tr[N<<],QvQ[N];
matrix operator * (matrix a,matrix b) {
    matrix c;
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    c.t[][]=max(a.t[][]+b.t[][],a.t[][]+b.t[][]);
    return c;
}
void build(int s,int t,int i) {
    if(s==t) {
        int x=repos[s];LL g0=,g1=;
        for(RI j=h[x];j;j=ne[j])
            if(to[j]!=fa[x]&&to[j]!=son[x])
                g0+=max(f[to[j]][],f[to[j]][]),g1+=f[to[j]][];
        tr[i].t[][]=tr[i].t[][]=g0,tr[i].t[][]=g1+a[x];
        QvQ[s]=tr[i];
        return;
    }
    int mid=(s+t)>>;
    build(s,mid,i<<),build(mid+,t,(i<<)|);
    tr[i]=tr[i<<]*tr[(i<<)|];
}
void chan(int x,int s,int t,int i) {
    if(s==t) {tr[i]=QvQ[s];return;}
    int mid=(s+t)>>;
    if(x<=mid) chan(x,s,mid,i<<);
    else chan(x,mid+,t,(i<<)|);
    tr[i]=tr[i<<]*tr[(i<<)|];
}
matrix query(int l,int r,int s,int t,int i) {
    if(l<=s&&t<=r) return tr[i];
    int mid=(s+t)>>;
    if(r<=mid) return query(l,r,s,mid,i<<);
    if(mid+<=l) return query(l,r,mid+,t,(i<<)|);
    return query(l,r,s,mid,i<<)*query(l,r,mid+,t,(i<<)|);
}
matrix getans(int x) {return query(pos[x],ed[x],,n,);}//獲得一條鍊頂端的dp值
void work(int x,LL num) {//修改的主體
    QvQ[pos[x]].t[][]+=num-a[x],a[x]=num;
    matrix k1,k2;
    while(x) {//往上跳
        k1=getans(top[x]),chan(pos[x],,n,),k2=getans(top[x]);
        x=fa[top[x]];if(!x) break;
        //修改一個點的轉移矩陣
        QvQ[pos[x]].t[][]+=max(k2.t[][],k2.t[][])-max(k1.t[][],k1.t[][]);
        QvQ[pos[x]].t[][]=QvQ[pos[x]].t[][];
        QvQ[pos[x]].t[][]+=k2.t[][]-k1.t[][];
    }
}
int main()
{
    int x,y;
    n=read(),m=read();
    for(RI i=;i<=n;++i) a[i]=read();
    for(RI i=;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
    dfs1(,),top[]=,dfs2(,),dp(,);
    build(,n,);
    while(m--) {
        x=read(),y=read();
        work(x,y);matrix kl=getans();
        printf("%lld\n",max(kl.t[][],kl.t[][]));
    }
    return ;
}
           

繼續閱讀