天天看点

树链剖分——重链剖分1.树链剖分的用处2.实现原理:洛谷重链剖分模板题

1.树链剖分的用处

对如下问题我们可以采用树链剖分的方法去做

1.把某个节点的子树的每个节点都加上一个值z

2.查询某个节点的子树的所有节点的值的和求出来

3.把一个节点x到y之间最短路径(经过边的条数最少)上的每个节点都加上某个值z

4.把一个节点x到y之间最短路径(经过边的条数最少)上的所有节点的和求出来

由于我们需要解决这些问题,所以我们要使用树链剖分这种算法。

2.实现原理:

树链剖分——重链剖分1.树链剖分的用处2.实现原理:洛谷重链剖分模板题

1.知识储备:

重儿子:该节点的所有儿子中,子树中节点个数最多的儿子。

举例:节点A有两个儿子,G所形成的树中有6个结点分别为 G B E V J T {GBEVJT} GBEVJT,C所形成的树中有5个结点 C F D Z X {CFDZX} CFDZX。所以A的重儿子为G。对于j结点E,它的两个儿子的大小相同,于是重儿子可以是任何一个儿子。

2.操作

我们按照每一个结点的重儿子走就形成了一条重链,如图

树链剖分——重链剖分1.树链剖分的用处2.实现原理:洛谷重链剖分模板题

我们要找出这个图上的所有重链,如图

树链剖分——重链剖分1.树链剖分的用处2.实现原理:洛谷重链剖分模板题

这个书上就两条重链,你找不出第3个了。

然后我们按重链的顺序为结点打上时间戳。重链按顺序打,其余的按照dfs的顺序打。如图

树链剖分——重链剖分1.树链剖分的用处2.实现原理:洛谷重链剖分模板题
我们定义这样几个数组完成重链剖分
int
fa[N],//记录该节点的父亲
dep[N],//记录该节点的深度
son[N],// 记录重儿子
siz[N],//记录该节点子树的大小(包含该节点)
top[N],//记录该重链的顶部
dfn[N],//时间戳
w[N],//
tim;//计数器
void dfs1(int u,int f)//处理fa,dep,siz,so
{
    dep[u]=dep[f]+1;
    fa[u]=f;
    siz[u]=1;//当前节点的大小初始化为1(因为该节点本身算一个嘛)
    int maxx=-1;
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==f)
            continue;//因为存的是无向图,所以一个边存两次,可能会回到该节点的父亲节点
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxx)
        {
            son[u]=v;
            maxx=siz[v];
        }

    }

}
void dfs2(int u,int t)//处理dfn,top,w,t为该链的头
    dfn[u]=++tim;
    top[u]=t;
    w[tim]=v[u];
    if(!son[u])
        return;
    dfs2(son[u],t);
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==fa[u]||v==son[u])
            continue;
        dfs2(v,v);
    }
}
           

操作完成。解释一下为什么要这样做。

通过这样做我们强行把树的结点拆成一个连续的数组: A G E J T B V C F D X Z {AGEJTBVCFDXZ} AGEJTBVCFDXZ

假设 A G E J T B V C F D X Z {AGEJTBVCFDXZ} AGEJTBVCFDXZ每个字母代表该节点存的值,然后你按照我说的来观察一下如果我要求以G为根的子树上的和(即实现操作2)是不是就是对数组 A G E J T B V C F D X Z {AGEJTBVCFDXZ} AGEJTBVCFDXZ从下标2到2+size-1求和。(2是我们按照数链剖分为G打的时间戳,size为以G为根的子树的大小。)

同理适用于所有的结点,你可以看一下。

那么我要实现操作1,与上面同理。

为什么可以这样呢?因为我们强行把树搞成了一个连续的,你观察一下是不是每棵子树都是一组连续的数。

到此为止,操作1和操作2就解释完了。下面来解释一下操作3和操作4怎么通过我们以上的处理来解决

3.把一个节点x到y之间最短路径(经过边的条数最少)上的每个节点都加上某个值z

4.把一个节点x到y之间最短路径(经过边的条数最少)上的所有节点的和求出来

以操作4为例,可以分成以下几种情况

看图:

树链剖分——重链剖分1.树链剖分的用处2.实现原理:洛谷重链剖分模板题

我们可以发现:任何一条路径都是由重链的一部分和叶子结点组成的。

1. x和y在一条重链上:

因为每个重链都是连续的所以我们可以直接求。eg:求G+E+J我们就可以在数组 A G E J T B V C F D X Z AGEJTBVCFDXZ AGEJTBVCFDXZ中求下标2~4的和。

2. 如果不在一条重链上:

引理:除根节点以外的任何一个结点的父亲都在一条重链上。

证明:因为父亲节点存在儿子所以一定存在重儿子,所以一定在一条重链上。

我嘴笨,还是举例子比较好讲。

对X到J最短路径进行求和,我们要维护两个指针

我们选所在链顶端较深的那一个进行操作,J的顶端为G,X的顶端为F,显然F的深度较深,所以我们把指向X的指针从X跳到F,边跳边求和,加到F后再往上跳,跳到F的父亲节点。由引理可知:P还是在一条重链上。于是无限循环,直到两指针跳到同一个结点货同一个重链上就解决了问题。

划重点:我们可以套一个线段树实现上面的区间求和,和区间修改操作。

洛谷重链剖分模板题

题解:

#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
const double PI = atan(1.0)*4.0;
typedef long long ll;
const int N=1e5+50;
int mod;
int dfn[N],dep[N],top[N],fa[N],son[N],siz[N],w[N],tim;
struct node
{
    int l,r;
    ll data;
    int plz;
} a[4*N];

void build(int now,int l,int r)
{
    a[now].plz=0;
    a[now].l=l;
    a[now].r=r;
    if(l==r)
    {
        a[now].data=w[l];
        return;
    }
    int mid=l+r>>1;
    build(now<<1,l,mid);
    build(now<<1|1,mid+1,r);
    a[now].data=a[now<<1].data+a[now<<1|1].data;
}
void pushdown(int now)
{
    if(a[now].plz!=0)
    {
        a[now<<1].plz+=a[now].plz;
        a[now<<1|1].plz+=a[now].plz;
        a[now<<1].data+=a[now].plz*(a[now<<1].r-a[now<<1].l+1);
        a[now<<1|1].data+=a[now].plz*(a[now<<1|1].r-a[now<<1|1].l+1);
        a[now].plz=0;
    }
    return;
}
void updata(int now,int l,int r,int val)
{
    if(l>a[now].r||r<a[now].l)
        return ;
    if(a[now].l>=l&&a[now].r<=r)
    {
        a[now].plz+=val;
        a[now].data+=val*(a[now].r-a[now].l+1);
        return;

    }

    pushdown(now);
    if(a[now<<1].r>=l)
        updata(now<<1,l,r,val);
    if(a[now<<1|1].l<=r)
        updata(now<<1|1,l,r,val);
        a[now].data=a[now<<1].data+a[now<<1|1].data;

}
ll query(int now,int l,int r)
{
    if(l>a[now].r||r<a[now].l)
        return 0;
    if(a[now].l>=l&&a[now].r<=r)
        return a[now].data;
    pushdown(now);
    ll ans=0;
    if(a[now<<1].r>=l)
        ans+=query(now<<1,l,r);
    if(a[now<<1|1].l<=r)
        ans+=query(now<<1|1,l,r);
    return ans%mod;
}
struct Node
{
    int u,v,next;
} edge[2*N];
int head[N],cnt;
void add(int u,int v)
{
    edge[cnt].u=u;
    edge[cnt].v=v;
    edge[cnt].next=head[u];
    head[u]=cnt;
    cnt++;
}
int v[N];

void dfs1(int u,int f)//fa,dep,son siz
{
    dep[u]=dep[f]+1;
    fa[u]=f;
    siz[u]=1;
    int maxx=-1;
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==f)
            continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxx)
        {
            son[u]=v;
            maxx=siz[v];
        }

    }

}
void dfs2(int u,int t)//dfn,top
{
    dfn[u]=++tim;
    top[u]=t;
    w[tim]=v[u];
    if(!son[u])
        return;
    dfs2(son[u],t);
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==fa[u]||v==son[u])
            continue;
        dfs2(v,v);
    }
}
void updatachain(int x,int y,int z)
{
    z=z%mod;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        updata(1,dfn[top[x]],dfn[x],z);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    updata(1,dfn[x],dfn[y],z);
}
int querychain(int x,int y)
{
    int res=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        res=(res+query(1,dfn[top[x]],dfn[x]))%mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    res=(res+query(1,dfn[x],dfn[y]))%mod;
    return res%mod;
}

void init()
{
    for(int i=1; i<=N; i++)
        head[i]=-1;
    cnt=0;
    tim=0;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    freopen("E:\\in.txt","r",stdin);
    init();
    int n,m,r;
    cin>>n>>m>>r>>mod;
    for(int i=1; i<=n; i++)
    {
        cin>>v[i];
    }

    for(int i=1; i<n; i++)
    {
        int x,y;
        cin>>x>>y;
        add(x,y);
        add(y,x);
    }
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);
    for(int i=1; i<=m; i++)
    {
        int p,x,y,z;
        cin>>p;
        if(p==1)
        {
            cin>>x>>y>>z;
            updatachain(x,y,z);
        }
        if(p==2)
        {
            cin>>x>>y;
            cout<<querychain(x,y)<<endl;
        }
        if(p==3)
        {
            cin>>x>>z;
            updata(1,dfn[x],dfn[x]+siz[x]-1,z);
        }
        if(p==4)
        {
            cin>>x;
            cout<<query(1,dfn[x],dfn[x]+siz[x]-1)<<endl;
        }
    }


}
           

继续阅读