天天看点

线段树树链剖分(点权)

模板题:​​树链剖分​​

参考博客:树链剖分详解(洛谷模板 P3384)

前置技能:线段树
#define lson (p<<1)
#define rson (p<<1|1)      
建树的时候需要注意的:
void build(int s,int t,int p)
{
    if(s==t)
    {
        d[p]=wt[s]%mod;
        /*此处wt[]是排序好之后的序号对应的权重*/
        return;
    }
    int m=(s+t)>>1;
    build(s,m,lson),build(m+1,t,rson);
    d[p]=(d[lson]+d[rson])%mod;
}      

树链剖分:

初始化工作:

int dep[maxn],f[maxn],siz[maxn],son[maxn];
int top[maxn],w[maxn<<1],id[maxn],tot=0,n;
void dfs1(int x,int fa,int deep)
{
    dep[x]=deep,f[x]=fa,siz[x]=1;
    int maxson=-1;
    for(int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa) continue;
        dfs1(v,x,deep+1);
        siz[x]+=siz[v];
        if(siz[v]>maxson) maxson=siz[v],son[x]=v;
    }
}
void dfs2(int u,int topf)
{
    id[u]=++tot;
    wt[tot]=w[u];
    top[u]=topf;
    if(!son[u]) return;
    dfs2(son[u],topf);
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}      
区间修改(最短路径上所有节点的值加 k):
void updrange(int x,int y,int k)
{
    k%=mod;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        update(id[top[x]],id[x],1,n,k,1);
        x=f[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    update(id[x],id[y],1,n,k,1);
}      
区间查询(最短路径上所有节点的值之和):
ll qrange(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans=(ans+getsum(id[top[x]],id[x],1,n,1)%mod)%mod;
        x=f[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans=(ans+getsum(id[x],id[y],1,n,1)%mod)%mod;
    return ans;
}      
子树内所有节点值都加上 k:
void updson(int x,int k)
{
    update(id[x],id[x]+siz[x]-1,1,n,k,1);
}      
子树内所有节点值之和:
ll qson(int x)
{
    return getsum(id[x],id[x]+siz[x]-1,1,n,1)%mod;
}      

模板题代码:

// Created by CAD on 2019/8/11.
#include <bits/stdc++.h>

#define lson (p<<1)
#define rson (p<<1|1)
using namespace std;
using pii=pair<int, int>;
using piii=pair<pair<int, int>, int>;
using ll=long long;
const int maxn=1e5+5;
/*线段树*/
int d[maxn<<2],wt[maxn],laz[maxn << 2];
ll mod;
void build(int s,int t,int p)
{
    if(s==t)
    {
        d[p]=wt[s]%mod;
        return;
    }
    int m=(s+t)>>1;
    build(s,m,lson),build(m+1,t,rson);
    d[p]=(d[lson]+d[rson])%mod;
}
void pushdown(int s,int t,int p)
{
    int m=(s+t)>>1;
    d[lson]=(d[lson]+(m-s+1)*laz[p])%mod,d[rson]=(d[rson]+(t-m)*laz[p])%mod;
    laz[lson]=(laz[lson]+laz[p])%mod,laz[rson]=(laz[rson]+laz[p])%mod;
    laz[p]=0;
}
void update(int l,int r,int s,int t,int c,int p)
{
    if(l<=s&&t<=r)
    {
        d[p]+=c*(t-s+1);
        laz[p]+=c;
        return ;
    }
    if(laz[p]) pushdown(s,t,p);
    int m=(s+t)>>1;
    if(l<=m) update(l,r,s,m,c,lson);
    if(r>m) update(l,r,m+1,t,c,rson);
    d[p]=(d[rson]+d[lson])%mod;
}
ll getsum(int l,int r,int s,int t,int p)
{
    if(l<=s&&t<=r) return d[p]%mod;
    if(laz[p]) pushdown(s,t,p);
    int m=(s+t)>>1;
    ll sum=0;
    if(l<=m) sum=(sum+getsum(l,r,s,m,lson))%mod;
    if(r>m) sum=(sum+getsum(l,r,m+1,t,rson))%mod;
    return sum%mod;
}
int cnt=0,head[maxn<<1];
struct edge{
    int to,next;
}e[maxn<<1];
void add(int u,int v)
{
    e[++cnt].to=v;
    e[cnt].next=head[u];
    head[u]=cnt;
}
int dep[maxn],f[maxn],siz[maxn],son[maxn];
int top[maxn],w[maxn<<1],id[maxn],tot=0,n;
void dfs1(int x,int fa,int deep)
{
    dep[x]=deep,f[x]=fa,siz[x]=1;
    int maxson=-1;
    for(int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa) continue;
        dfs1(v,x,deep+1);
        siz[x]+=siz[v];
        if(siz[v]>maxson) maxson=siz[v],son[x]=v;
    }
}
void dfs2(int u,int topf)
{
    id[u]=++tot;
    wt[tot]=w[u];
    top[u]=topf;
    if(!son[u]) return;
    dfs2(son[u],topf);
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
ll qrange(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans=(ans+getsum(id[top[x]],id[x],1,n,1)%mod)%mod;
        x=f[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans=(ans+getsum(id[x],id[y],1,n,1)%mod)%mod;
    return ans;
}
void updrange(int x,int y,int k)
{
    k%=mod;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        update(id[top[x]],id[x],1,n,k,1);
        x=f[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    update(id[x],id[y],1,n,k,1);
}
ll qson(int x)
{
    return getsum(id[x],id[x]+siz[x]-1,1,n,1)%mod;
}
void updson(int x,int k)
{
    update(id[x],id[x]+siz[x]-1,1,n,k,1);
}
ll m,r;
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n>>m>>r>>mod;
    for(int i=1;i<=n;++i) cin>>w[i];
    for(int i=1,u,v;i<n;++i) cin>>u>>v,add(u,v),add(v,u);
    dfs1(r,0,1);
    dfs2(r,r);
    build(1,n,1);
    while(m--)
    {
        int k,x,y,z;
        cin>>k;
        if(k==1)
            cin>>x>>y>>z,updrange(x,y,z);
        else if(k==2)
            cin>>x>>y,cout<<qrange(x,y)<<endl;
        else if(k==3)
            cin>>x>>y,updson(x,y);
        else if(k==4)
            cin>>x,cout<<qson(x)<<endl;
    }
    return 0;
}