天天看點

樹上莫隊

在之前我已經發過了普通莫隊的部落格了。

傳送門

打了幾道莫隊的裸題後,我就學了一下樹上莫隊。

例題

這題的英文超好懂,我相信你的英語水準。

但我還是解釋一下吧。

題目大意:給你一棵NN個點的帶權的樹,有MM個詢問,詢問兩點之間不同的權值個數。

其中N≤10000,M≤400000N≤10000,M≤400000。

往區間的方向思考一下

處理樹上的資訊,有一個傳統的套路,就是将樹轉化成序列。

什麼dfs序,歐拉序,括号序。

那麼咋轉化成序列?

思考一下~

樹上莫隊

樹上莫隊用的是括号序。

不知道這是啥玩意?這就是dfs過程中,入棧和出棧各記一次的順序。顯然長度是N∗2N∗2的。

做法

對于uu和vv之間的路徑,假設uin<vinuin<vin。

分兩種情況:

1. LCA(u,v)=uinLCA(u,v)=uin時,可以詢問區間[uin,vin][uin,vin]。

2. 否則,詢問區間[uout,vin][uout,vin],并且另外加上LCA(u,v)LCA(u,v)的貢獻。

詢問區間是詢問區間中隻出現過一次的節點中的答案。

Why?

後面的事……

代碼

using namespace std;
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define MAXN 10000
#define MAXM 400000
int n,m;
int col[MAXN+1];
int *p[MAXN+1];
bool cmpp(int *x,int *y){
    return *x<*y;
}
struct EDGE{
    int to;
    EDGE *las;  
} e[MAXN*2+1];
int ne;
EDGE *last[MAXN+1];
void insert_edge(int u,int v){
    e[++ne]={v,last[u]};
    last[u]=e+ne;
}
int in[MAXN+1],out[MAXN+1],nowdfn;
int dy[MAXN*2+1];
int unit,be[MAXN*2+1];
int dep[MAXN+1];
int fa[MAXN+1][15];
void init(int);
int LCA(int,int);
struct Operation{
    int time,l,r,another;//如果要額外算LCA,another即為LCA的顔色,否則為0
} o[MAXM+1];
bool cmp(const Operation &x,const Operation &y){
    return be[x.l]<be[y.l] || be[x.l]==be[y.l] && x.r<y.r;
}
int num[MAXN+1];//表示某個顔色的出現次數
int gx[MAXN+1];//貢獻,表示是否隻出現一次(其實可以用bool數組)
int ans[MAXM+1];
int main(){
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;++i)
        scanf("%d",&col[i]),p[i]=col+i;
    sort(p+1,p+n+1,cmpp);
    for (int i=1,k=0,last=-2147483648;i<=n;++i){//離散化
        if (last!=*p[i])
            ++k,last=*p[i];
        *p[i]=k;
    }
    for (int i=1;i<n;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        insert_edge(u,v),insert_edge(v,u);
    }
    init(1);
    unit=sqrt(nowdfn);
    for (int i=1;i<=nowdfn;++i)
        be[i]=(i-1)/unit+1;
    for (int i=1;i<=m;++i){
        o[i].time=i;
        int u,v;
        scanf("%d%d",&u,&v);
        if (in[u]>in[v])
            swap(u,v);
        if (out[v]<out[u]){
            o[i].l=in[u];
            o[i].r=in[v];
            o[i].another=0;
        }
        else{
            o[i].l=out[u];
            o[i].r=in[v];
            o[i].another=col[LCA(u,v)];
        }
    }
    sort(o+1,o+m+1,cmp);
    int nowl=1,nowr=0,nowans=0;
    for (int i=1;i<=m;++i){
        while (nowr<o[i].r){
            nowr++;
            int lasnum=num[col[dy[nowr]]];
            num[col[dy[nowr]]]+=(gx[dy[nowr]]^1)-gx[dy[nowr]];
            gx[dy[nowr]]^=1;
            if (!lasnum && num[col[dy[nowr]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowr]]])
                nowans--;
        }
        while (nowl>o[i].l){
            nowl--;
            int lasnum=num[col[dy[nowl]]];
            num[col[dy[nowl]]]+=(gx[dy[nowl]]^1)-gx[dy[nowl]];
            gx[dy[nowl]]^=1;
            if (!lasnum && num[col[dy[nowl]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowl]]])
                nowans--;
        }
        while (nowr>o[i].r){
            int lasnum=num[col[dy[nowr]]];
            num[col[dy[nowr]]]+=(gx[dy[nowr]]^1)-gx[dy[nowr]];
            gx[dy[nowr]]^=1;
            if (!lasnum && num[col[dy[nowr]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowr]]])
                nowans--;
            nowr--;
        }
        while (nowl<o[i].l){
            int lasnum=num[col[dy[nowl]]];
            num[col[dy[nowl]]]+=(gx[dy[nowl]]^1)-gx[dy[nowl]];
            gx[dy[nowl]]^=1;
            if (!lasnum && num[col[dy[nowl]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowl]]])
                nowans--;
            nowl++;
        }   
        if (o[i].another)
            ans[o[i].time]=nowans+!num[o[i].another];
        else
            ans[o[i].time]=nowans;
    }
    for (int i=1;i<=m;++i)
        printf("%d\n",ans[i]);
    return 0;
}
void init(int x){
    in[x]=++nowdfn;
    dy[nowdfn]=x;
    dep[x]=dep[fa[x][0]]+1;
    for (int i=1;1<<i<dep[x];++i)
        fa[x][i]=fa[fa[x][i-1]][i-1];
    for (EDGE *ei=last[x];ei;ei=ei->las)
        if (ei->to!=fa[x][0])
            fa[ei->to][0]=x,init(ei->to);
    out[x]=++nowdfn;
    dy[nowdfn]=x;
}
int LCA(int u,int v){//這是利用倍增來求LCA的
    if (dep[u]<dep[v])
        swap(u,v);
    for (int k=dep[u]-dep[v],i=0;k;k>>=1,++i)
        u=fa[u][i];
    if (u==v)
        return u;
    for (int i=log2(dep[u]);i>=0;--i)
        if (fa[u][i]!=fa[v][i]){
            u=fa[u][i];
            v=fa[v][i];
        }
    return fa[u][0];
}