SP10707 COT2 - Count on a tree II
參考:樹上莫隊
樹上莫隊和普通的莫隊差不多,隻是把區間從普通的數組,轉到歐拉序上(其實也就是括号序)
該問題求解的是 x,y 兩個點之間的最短路徑上,有多少個不同顔色的點
對于這個問題,分兩種情況讨論
- \(lca(x,y)=x\ or\ y\),我們隻需要記錄\([b[x],b[y]]\),這段區間的貢獻即可,(預設\(b[x]<b[y]\),不然交換)
- 若不成立,則記錄\(e[x],b[y]\)的貢獻
b[i] 表示 i 這個點 dfs 時開始的時間,e[i] 表示 i 這個點 dfs 結束的時間。
需要注意的幾個點
- 塊的大小為\(\frac{n}{\sqrt{m}}\)時最佳,本題為\(\frac{2n}{\sqrt{m}}\)
- 在寫分塊,求 belo 數組的時候,不要寫假了,不然複雜度也是假的。
- 在比較兩個點 x,y 的先後次序的時候,要用 b[x] 和 b[y] 來進行比較,不能用 dep 來直接進行比較。
- dfs 取根的時候,可以用 rand 來随機取根,這樣可以防止出題人卡資料。
//Created by CAD
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
int b[maxn],e[maxn],id[maxn<<1],a[maxn];
int belo[maxn<<1];
vector<int> g[maxn];
struct query{
int l,r,x,y,LCA,id;
bool operator<(const query& q){
if(belo[l]!=belo[q.l]) return belo[l]<belo[q.l];
return (belo[l]&1)?r<q.r:r>q.r;
}
}q[maxn];
int fa[maxn][30],dep[maxn],lg[maxn];
int tin=0;
void dfs(int x,int o){
fa[x][0]=o,dep[x]=dep[o]+1;
for(int i=1;i<=lg[dep[x]];++i)
fa[x][i]=fa[fa[x][i-1]][i-1];
id[b[x]=++tin]=x;
for(int i:g[x])
if(i!=o) dfs(i,x);
id[e[x]=++tin]=x;
}
inline int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]-1];
if(x==y) return x;
for(int k=lg[dep[x]]-1;k>=0;--k){
if(fa[x][k]!=fa[y][k])
x=fa[x][k],y=fa[y][k];
}
return fa[x][0];
}
int now=0;
int ans[maxn],f[maxn],cnt[maxn];
inline void modify(int &x){
int num=a[x];
if(f[x]){
cnt[num]--;
if(!cnt[num]) now--;
}
else{
if(!cnt[num]) now++;
cnt[num]++;
}
f[x]^=1;
}
unordered_map<int,int> vis;
int main(){
int n,m;
scanf("%d%d",&n,&m);
int blo=double(2*n)/sqrt(m);
for(int i=1;i<=n;++i){
scanf("%d",a+i);
if(!vis.count(a[i]))
vis[a[i]]=vis.size();
a[i]=vis[a[i]];
belo[i]=(i-1)/blo+1;
belo[i*2]=(i*2-1)/blo+1;
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=m;++i){
int l,r;
scanf("%d%d",&l,&r);
if(b[l]>b[r])
swap(l,r);
int LCA=lca(l,r);
if(LCA==l) q[i]={b[l],b[r],l,r,LCA,i};
else q[i]={e[l],b[r],l,r,LCA,i};
}
sort(q+1,q+m+1);
int l=1,r=0;
for(int i=1;i<=m;++i){
int ql=q[i].l,qr=q[i].r,x=q[i].x,y=q[i].y,LCA=q[i].LCA;
while(l<ql) modify(id[l++]);
while(l>ql) modify(id[--l]);
while(r<qr) modify(id[++r]);
while(r>qr) modify(id[r--]);
int bj=0;
if(!f[LCA]) modify(LCA),bj=1;
ans[q[i].id]=now;
if(bj) modify(LCA);
}
for(int i=1;i<=m;++i)
printf("%d\n",ans[i]);
}