天天看點

JZOJ 7036. 2021.03.30【2021省賽模拟】淩亂平衡樹(平衡樹單旋+權值線段樹)JZOJ 7036. 2021.03.30【2021省賽模拟】淩亂平衡樹

JZOJ 7036. 2021.03.30【2021省賽模拟】淩亂平衡樹

題目大意

  • 給出兩棵Treap,大小分别為 n , m n,m n,m,每個點的 p r i o r i t y priority priority值為子樹大小(是以滿足大根堆性質), Q Q Q次修改(修改是永久的),每次單旋一個節點,求修改前和每次修改後後兩樹合并之後的所有節點深度之和。合并按照Treap的合并方式,左樹根為 x x x,右樹根為 y y y時,當 s i z e x ≥ s i z e y size_x\ge size_y sizex​≥sizey​時以 x x x為根,否則反之。
  • 1 ≤ n , m , Q ≤ 2 ∗ 1 0 5 1\le n,m,Q\le2*10^5 1≤n,m,Q≤2∗105

題解

  • 考慮合并的過程,記錄目前深度 d p dp dp,左樹根每次向右走,就加上左兒子 F + G ∗ d p F+G*dp F+G∗dp,含義是所有點到左子樹根的深度加上到實際的根深度內插補點。右邊同理。這樣需要在每次單旋後重新計算每個子樹的大小 G G G及以該子樹根為根的深度和 F F F。 G G G可以在常數複雜度内維護,但 F F F不行。
  • 換一種思路,記錄總的深度和 s u m sum sum,每次求出合并後增加的內插補點。這樣合并的過程中,左樹根每次向右走,則加上右樹根的 G G G,含義是它子樹内所有點的深度都會被增加 1 1 1。右邊同理。
  • 而合并時左樹根始終向右,右樹根始終向左,其它的節點是不會經過的,且與它相關的值也不會調用到,是以可以把左根向右和右根向左兩條鍊(以下稱為鍊)單獨看,設鍊上 G G G序列左邊依次為 A A A,右邊為 B B B。 A i A_i Ai​對答案的貢獻次數為 ( A i , A i − 1 ] (A_i,A_{i-1}] (Ai​,Ai−1​]中 B B B的個數, B i B_i Bi​對答案的貢獻次數為 [ B i , B i − 1 ) [B_i,B_{i-1}) [Bi​,Bi−1​)中 A A A的個數,注意這裡區間的開閉情況。
  • 那麼可以用權值線段樹維護,把初始的 A A A和 B B B都存進同一棵權值線段樹中,在單旋時進行修改。
  • 隻有兩種情況需要修改:
  • 1、單旋的節點 x x x和 x x x的父親都在鍊中;
  • 2、單旋的節點 x x x不在鍊中, x x x的父親在鍊中。
  • 修改時因為 G G G值會改變,是以需要先删除該點及其貢獻,修改完 G G G後再加入回來。修改的貢獻不僅有它自己的貢獻,還有 A A A和 B B B中它們前驅的貢獻。

代碼

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 200010
#define ll long long
struct {
	int p[2];
}f[N * 4];
int ns;
ll ans;
void is(int v, int l, int r, int x, int o, int c) {
	if(l == r) {
		f[v].p[o] += c;
	}
	else {
		int mid = (l + r) / 2;
		if(x <= mid) is(v * 2, l, mid, x, o, c); else is(v * 2 + 1, mid + 1, r, x, o, c);
		f[v].p[o] = f[v * 2].p[o] + f[v * 2 + 1].p[o];
	}
}
int get(int v, int l, int r, int x, int y, int o) {
	if(x > y) return 0;
	if(l == x && r == y) return f[v].p[o];
	int mid = (l + r) / 2;
	if(y <= mid) return get(v * 2, l, mid, x, y, o);
	if(x > mid) return get(v * 2 + 1, mid + 1, r, x, y, o);
	return get(v * 2, l, mid, x, mid, o) + get(v * 2 + 1, mid + 1, r, mid + 1, y, o);
}
int find(int v, int l, int r, int x, int y, int k, int o) {
	if(f[v].p[o] < k || x > y) return -1;
	if(l == r) return l;
	int mid = (l + r) / 2;
	if(y <= mid) return find(v * 2, l, mid, x, y, k, o);
	if(x > mid) return find(v * 2 + 1, mid + 1, r, x, y, k, o);
	int s = get(v * 2, l, mid, x, mid, o);
	if(s >= k) return find(v * 2, l, mid, x, mid, k, o);
	return find(v * 2 + 1, mid + 1, r, mid + 1, y, k - s, o);
}
int find0(int v, int l, int r, int x, int y, int k, int o) {
	if(f[v].p[o] < k || x > y) return -1;
	if(l == r) return l;
	int mid = (l + r) / 2;
	if(y <= mid) return find0(v * 2, l, mid, x, y, k, o);
	if(x > mid) return find0(v * 2 + 1, mid + 1, r, x, y, k, o);
	int s = get(v * 2 + 1, mid + 1, r, mid + 1, y, o);
	if(s >= k) return find0(v * 2 + 1, mid + 1, r, mid + 1, y, k, o);
	return find0(v * 2, l, mid, x, mid, k - s, o);
}
ll count(int x, int o) {
	if(!o) {
		int t = find(1, 1, ns, x, ns, 2, 0);
		if(t == -1) t = ns;
		return (ll)get(1, 1, ns, x + 1, t, 1) * x;
	}
	else {
		int t = find(1, 1, ns, x, ns, 2, 1);
		if(t == -1) t = ns + 1;
		return (ll)get(1, 1, ns, x, t - 1, 0) * x;
	}
}
int fr(int x, int o) {
	if(!o) {
		int t = find0(1, 1, ns, 1, x, 1, 1);
		return t == -1 ? 0 : t;
	}
	else {
		int t = find0(1, 1, ns, 1, x - 1, 1, 0);
		return t == -1 ? 0 : t;
	}
}
struct {
	int s, rt, p[N];
	ll F[N], si[N], sum;
	struct {
		int s[2], fa, p;	
	}f[N];
	void ins(int r, int l, int i) {
		f[i].s[0] = l, f[i].s[1] = r;
		f[l].fa = f[r].fa = i;
		f[l].p = 0, f[r].p = 1;
	}
	void ro(int x, int o) {
		int y = f[x].fa, z = f[y].fa, py = f[x].p, pz = f[y].p;
		f[z].s[pz] = x, f[x].fa = z, f[x].p = pz;
		f[y].s[py] = f[x].s[py ^ 1], f[f[x].s[py ^ 1]].fa = y, f[f[x].s[py ^ 1]].p = py;
		f[x].s[py ^ 1] = y, f[y].fa = x, f[y].p = py ^ 1;
		if(rt == y) rt = x;
		int tp;
		if(p[y] && p[x]) {
			ans -= count(si[x], o) + count(si[y], o);
			ans -= fr(si[y], o) + fr(si[x], o);
			tp = find0(1, 1, ns, 1, si[x], 2, o);
			if(tp > 0) ans -= count(tp, o);
			is(1, 1, ns, si[y], o, -1);
			is(1, 1, ns, si[x], o, -1);
		}
		else if(p[y] && !p[x]) {
			ans -= count(si[y], o);
			ans -= fr(si[y], o);
			tp = find0(1, 1, ns, 1, si[y], 2, o);
			if(tp > 0) ans -= count(tp, o);
			is(1, 1, ns, si[y], o, -1);
		}
		
		si[y] = si[f[y].s[0]] + si[f[y].s[1]] + 1;
		si[x] = si[f[x].s[0]] + si[f[x].s[1]] + 1;
		sum += si[f[y].s[py ^ 1]] - si[f[x].s[py]];
		
		if(p[y] && p[x]) {
			is(1, 1, ns, si[x], o, 1);
			ans += fr(si[x], o);
			ans += count(si[x], o);
			if(tp > 0) ans += count(tp, o);
			p[y] = 0;
		}
		else if(p[y] && !p[x]) {
			is(1, 1, ns, si[x], o, 1);
			is(1, 1, ns, si[y], o, 1);
			ans += fr(si[y], o) + fr(si[x], o);
			ans += count(si[y], o) + count(si[x], o);
			if(tp > 0) ans += count(tp, o);
			p[x] = 1;
		}
	}
	int find() {
		for(int i = 1; i <= s; i++) if(f[i].fa == 0) return i;
	}
	void dfs(int k) {
		F[k] = 1, si[k] = 1;
		if(f[k].s[0]) dfs(f[k].s[0]), si[k] += si[f[k].s[0]], F[k] += F[f[k].s[0]] + si[f[k].s[0]];
		if(f[k].s[1]) dfs(f[k].s[1]), si[k] += si[f[k].s[1]], F[k] += F[f[k].s[1]] + si[f[k].s[1]];
	}
}a, b;
void solve() {
	int x = a.rt;
	while(x) is(1, 1, ns, a.si[x], 0, 1), a.p[x] = 1, x = a.f[x].s[1];
	x = b.rt;
	while(x) is(1, 1, ns, b.si[x], 1, 1), b.p[x] = 1, x = b.f[x].s[0];
	ans = 0;
	x = a.rt;
	while(x) ans += count(a.si[x], 0) ,x = a.f[x].s[1];
	x = b.rt;
	while(x) ans += count(b.si[x], 1), x = b.f[x].s[0];
	printf("%lld\n", ans + a.sum + b.sum);
}
int read() {
	int s = 0;
	char x = getchar();
	while(x < '0' || x > '9') x = getchar();
	while(x >= '0' && x <= '9') s = s * 10 + x - 48, x = getchar();
	return s;
}
int main() {
	int Q, i;
	scanf("%d%d", &a.s, &b.s);
	for(i = 1; i <= a.s; i++) {
		a.ins(read(), read(), i);
	}
	for(i = 1; i <= b.s; i++) {
		b.ins(read(), read(), i);
	}
	ns = max(a.s, b.s) + 1;
	a.rt = a.find(), b.rt = b.find(); 
	a.dfs(a.rt), b.dfs(b.rt);
	a.sum = a.F[a.rt], b.sum = b.F[b.rt];
	scanf("%d", &Q);
	solve();
	while(Q--) {
		if(read() == 1) a.ro(read(), 0); else b.ro(read(), 1);
		printf("%lld\n", ans + a.sum + b.sum);
	}
	return 0;
}
           

自我小結

  • 細節比較多,各條語句中的順序很重要,需要理清楚。