可持久化并查集
題目連結:ybt金牌導航4-6-4 / luogu P3402
題目大意
要你支援可持久化的并查集。
即可以退回到第 k 次操作後的并查集。
思路
你考慮并查集的過程。
首先是合并,其實就是找到它們所在的集合,然後把一個集合的父親連向另一個。
即 f a fa fa 數組發生了改變,而我們要維護的既是 f a fa fa 數組的可持久化。
那其實就是要維護可持續化數組,那主席樹來維護即可。
然後你也可以看出,那你就不能路徑壓縮。
那複雜都就會不優,我們考慮按秩合并,即把小的連向大的。(有點啟發式合并的感覺)
那盡可能的讓大小一樣,鍊的深度就是 l o g n logn logn。
那 f i n d find find 函數還是一樣,你就不斷往上跳,直到它的父親是它自己。
找它的父親要在主席樹上找, l o g n logn logn,然後跳是跳鍊深度次數, l o g n logn logn,是以複雜度是 O ( m l o g 2 n ) O(mlog^2n) O(mlog2n)。
代碼
#include<cstdio>
#include<algorithm>
using namespace std;
int n, m, op, x, y, rt[200001], tot;
int fa[200001 << 5], ls[200001 << 5], rs[200001 << 5], deg[200001 << 5];
void build(int &now, int l, int r) {
now = ++tot;
if (l == r) {
fa[now] = l;//一開始獨立,自己父親是自己
return ;
}
int mid = (l + r) >> 1;
build(ls[now], l, mid);
build(rs[now], mid + 1, r);
}
int query(int now, int l, int r, int pl) {
if (l == r) return now;
int mid = (l + r) >> 1;
if (pl <= mid) return query(ls[now], l, mid, pl);
else return query(rs[now], mid + 1, r, pl);
}
int find(int root, int pl) {
int now = query(root, 1, n, pl);
if (fa[now] == pl) return now;
return find(root, fa[now]);
}
int merge(int bef, int l, int r, int X, int Y) {
int now = ++tot;//記得新開點
ls[now] = ls[bef];
rs[now] = rs[bef];
if (l == r) {
fa[now] = Y;//合并
deg[now] = deg[bef];
return now;
}
int mid = (l + r) >> 1;
if (X <= mid) ls[now] = merge(ls[bef], l, mid, X, Y);
else rs[now] = merge(rs[bef], mid + 1, r, X, Y);
return now;
}
void adddeg(int now, int l, int r, int pl) {
if (l == r) {
deg[now]++;
return ;
}
int mid = (l + r) >> 1;
if (pl <= mid) adddeg(ls[now], l, mid, pl);
else adddeg(rs[now], mid + 1, r, pl);
}
int main() {
scanf("%d %d", &n, &m);
build(rt[0], 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d", &op);
if (op == 1) {
rt[i] = rt[i - 1];
scanf("%d %d", &x, &y);
int X = find(rt[i], x), Y = find(rt[i], y);
if (fa[X] == fa[Y]) continue;
if (deg[X] > deg[Y]) swap(X, Y);//讓深度小的連向深度大的
rt[i] = merge(rt[i - 1], 1, n, fa[X], fa[Y]);
if (deg[X] == deg[Y]) adddeg(rt[i], 1, n, fa[Y]);//如果兩個深度都一樣,就一定要加深度了
continue;
}
if (op == 2) {
scanf("%d", &x);
rt[i] = rt[x];
continue;
}
if (op == 3) {
rt[i] = rt[i - 1];
scanf("%d %d", &x, &y);
int X = find(rt[i], x), Y = find(rt[i], y);
if (fa[X] == fa[Y]) printf("1\n");
else printf("0\n");
continue;
}
}
return 0;
}