天天看點

學習筆記::樹上莫隊

王室聯邦:樹分塊,參見popoqqq大神的部落格,講得很詳細

莫隊:小z的襪子

學習筆記::樹上莫隊
學習筆記::樹上莫隊

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
#define N 200010
struct edge
{
    int to,nxt;
}e[N];
struct data
{
    int u,v,a,b,id;
}q[N];
int n,m,tot,cnt=1,Time,ans,top,size;
int dfn[N],belong[N],head[N],dep[N],used[N],c[N],p[N],s[N];
int fa[30][N],answer[N];
void link(int u,int v)
{
    e[++cnt].nxt=head[u];
    head[u]=cnt;
    e[cnt].to=v;
}
bool cp(data x,data y)
{
    if(belong[x.u]!=belong[x.v]) return belong[x.u]<belong[x.v];
    return dfn[x.u]<dfn[x.v];
}
void reverse(int u)
{
    if(!used[u]) 
    {
        used[u]=1; p[c[u]]++; if(p[c[u]]==1) ans++;
    }
    else
    {
        used[u]=0; p[c[u]]--; if(!p[c[u]]) ans--;
    }
}
void dfs(int u,int last)
{
    int bottom=top+1; dfn[u]=++Time;
    for(int i=head[u];i;i=e[i].nxt) if(e[i].to!=last)
    {
        int v=e[i].to;
        dep[v]=dep[u]+1; fa[0][v]=u;
        dfs(v,u);        
        if(top-bottom+1>=size)
        {
            ++tot; ++top;
            while(top>=bottom) belong[s[--top]]=tot;
        }
    }
    s[++top]=u;
}
void solve(int u,int v)
{
    while(u!=v) 
        if(dep[u]>dep[v])
        {
            reverse(u); u=fa[0][u];
        }  
        else
        {
            reverse(v); v=fa[0][v];
        } 
}
void init()
{
    for(int i=1;i<=22;i++)
        for(int j=1;j<=n;j++) if(fa[i-1][j]!=-1) fa[i][j]=fa[i-1][fa[i-1][j]];
}
int lca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    int temp=dep[u]-dep[v];
    for(int i=22;i>=0;i--) 
        if(temp&(1<<i)) u=fa[i][u];
    if(u==v) return u;
    for(int i=22;i>=0;i--)
    {
        if(fa[i][u]!=fa[i][v]) 
        {
            u=fa[i][u];
            v=fa[i][v];
        }
    }
    return fa[0][u];
}
int main()
{
    memset(fa,-1,sizeof(fa));
    scanf("%d%d",&n,&m);
    size=(int)(sqrt(n));
    int root=1;
    for(int i=1;i<=n;i++) scanf("%d",&c[i]);
    for(int i=1;i<=n;i++)
    {
        int u,v; scanf("%d%d",&u,&v);
        if(!u) root=v; else if(!v) root=u;
        else 
        {
            link(u,v); link(v,u);
        }
    }
    dfs(root,0);
    init();
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d%d",&q[i].u,&q[i].v,&q[i].a,&q[i].b);
        q[i].id=i;
    }    
    sort(q+1,q+m+1,cp);
    solve(q[1].u,q[1].v);
    int x=lca(q[1].u,q[1].v);
    reverse(x);
    answer[q[1].id]=ans;
    if(p[q[1].a]&&p[q[1].b]&&q[1].a!=q[1].b) answer[q[1].id]--;
    reverse(x);
    for(int i=2;i<=m;i++)
    {
        solve(q[i-1].u,q[i].u);
        solve(q[i-1].v,q[i].v);
        int x=lca(q[i].u,q[i].v);
        reverse(x);
        answer[q[i].id]=ans;
        if(p[q[i].a]&&p[q[i].b]&&q[i].a!=q[i].b) answer[q[i].id]--;
        reverse(x);
    }
    for(int i=1;i<=m;i++) printf("%d\n",answer[i]);
    return 0;
}      

View Code