天天看點

樹鍊剖分 [ZJOI2008]樹的統計

前置芝士

dfs序,線段樹

正文

樹鍊剖分就是通過劃分輕重邊将樹分割成許多鍊,然後利用資料結構(線段樹)來維護這些鍊

使得在樹上可以用非常優秀的複雜度去周遊一些資訊

(本質上是一種優化暴力(就像LCA)(其實所有資料結構都是優化的暴力))

樹鍊剖分 [ZJOI2008]樹的統計
樹鍊剖分 [ZJOI2008]樹的統計

看一個模闆題

首先明确的概念

重兒子:父親節點的所有兒子中子樹結點數目最多(size最大)的結點;(節點數目包括自身)

輕兒子:父親節點中除了重兒子以外的兒子;

重邊:父親結點和重兒子連成的邊;

輕邊:父親節點和輕兒子連成的邊;

重鍊:由多條重邊連接配接而成的路徑;

輕鍊:由多條輕邊連接配接而成的路徑;

樹鍊剖分 [ZJOI2008]樹的統計

比如上面這幅圖中,用黑線連接配接的結點都是重結點,其餘均是輕結點,

2-11就是重鍊,2-5就是輕鍊,用紅點标記的就是該結點所在重鍊的起點,也就是下文提到的top結點,

還有每條邊的值其實是進行dfs時的執行序号。

樹鍊剖分的思路

将一棵每個節點的兒子按照兒子大小劃分成重兒子和輕兒子(其他兒子),将樹劃分成一條條鍊(重鍊和輕鍊)

利用dfs序,将同一個鍊上的點放在一起,在建出線段樹

使得在調用兩點的簡單路徑時,可以一跳跳過多個節點(類比LCA思考)

進而達到減小複雜度的目的

如何實作

有不了解的地方,手模是個不錯的選擇

1、先跑第一遍DFS(初始化)

每周遊到一個點,讓siz為1,記錄父親與深度

然後回溯的時候加上其子樹的點的大小

并順便在周遊到的子樹中挑出重兒子

1 void dfs(int x, int fa){//
 2         siz[x] = 1, fath[x] = fa, dep[x] = dep[fa] + 1;//确定以x為根的子樹的大小,父親,深度 
 3         //cout<<cnt<<"lzx"<<x<<" "<<fa<<endl;
 4         for(int i = head[x]; i; i = e[i].nxt){//類似于lca初始化的周遊 
 5             int v = e[i].to;
 6             if(v == fa) continue;
 7             dfs(v, x);
 8             siz[x] += siz[v];//回溯的時候更新子樹大小 
 9             if(siz[son[x]] < siz[v]) son[x] = v;//挑出重兒子 
10         } 
11     }      

2、在跑一遍DFS(分鍊)

确定dfs序,并把dfs序所對應的元素用pre數組存起來

注意周遊順序,因為開始我們提到劃分重鍊,是以我們要優先周遊重兒子,并把鍊頂元素也傳下去(先周遊重兒子感覺珂以使複雜度最優)

周遊完重兒子後,再周遊其他兒子,并新開一條鍊

1 void dfs2(int x, int tp){//分鍊,tp表示該鍊的頂端 
 2         top[x] = tp, dfn[x] = ++cnt, pre[cnt] = x;//确定x節點的鍊的頂端是tp,x的dfs序及反dfs序 
 3         if(son[x]) dfs2(son[x], tp);//為了使重鍊的dfn在一起,要先周遊重兒子 
 4         for(int i = head[x]; i; i = e[i].nxt){
 5             int v = e[i].to;
 6             if(v == fath[x] || son[x] == v) continue;//如果一個點的fa等于自己或者下一個點是它的重兒子就跳過
 7             //(如果是重兒子的話應該在以前就已經周遊了,是以還有防止在周遊一遍的作用 
 8             dfs2(v, v);//新開一條鍊 
 9         }
10     }      

3、資料維護

我們不難發現,每個重鍊的dfs序是連在一起的,那麼我們是不是可以考慮用線段樹來維護它,因為線段樹剛好可以維護一段連續的區間

線段樹闆中闆

樹鍊剖分 [ZJOI2008]樹的統計
樹鍊剖分 [ZJOI2008]樹的統計
1 #define lson i << 1
 2     #define rson i << 1 | 1
 3     struct Tree{//和,懶标記,長度 
 4         int sum, lazy, len;
 5     }tree[MAXN << 2];
 6     void push_up(int i){//上傳标記 
 7         tree[i].sum = (tree[lson].sum + tree[rson].sum) % p;
 8         return ;
 9     }
10     void build(int i, int l , int r){//建樹 
11         tree[i].lazy = 0, tree[i].len = r - l + 1;
12         if(l == r) {    
13             tree[i].sum = a[pre[l]] % p;
14             return ;
15         }
16         int mid = l + r >> 1;
17         build(lson, l, mid), build(rson, mid + 1, r);
18         push_up(i);
19         return ;
20     }
21     void pushdown(int i){//下傳懶标記 
22         if(tree[i].lazy){
23             tree[lson].lazy = (tree[lson].lazy + tree[i].lazy) % p;
24             tree[rson].lazy = (tree[rson].lazy + tree[i].lazy) % p;
25             tree[lson].sum = (tree[lson].sum + tree[i].lazy * tree[lson].len) % p;
26             tree[rson].sum = (tree[rson].sum + tree[i].lazy * tree[rson].len) % p;
27             tree[i].lazy = 0;
28         }
29         return ;
30     }
31     void add(int i, int l, int r, int L, int R, int k){
32     //lr表示周遊到的區間,LR表示查詢到的區間 
33         if(L <= l && r <= R) {
34             tree[i].sum = (tree[i].sum + (k * tree[i].len) % p) % p;
35             tree[i].lazy += k;
36             return ;
37         }
38         //cout<<l<<" "<<R << " "<< r<< " "<<L<<"lkp"<<endl;
39         if(l > R || r < L) return ;
40         pushdown(i);
41         int mid = (l + r) >> 1;
42         if(L <= mid) add(lson, l, mid, L, R, k);
43         if(R > mid) add(rson, mid + 1, r, L, R, k);
44         push_up(i);
45         return ;
46     }
47     int get(int i, int l, int r, int L, int R){
48         int sum = 0;
49         if(L <= l && r <= R) {
50             return tree[i].sum % p;
51         }
52         if(l > R || r < L) return 0;
53         pushdown(i);
54         int mid = (l + r) >> 1;
55         if(mid >= L) sum = (sum + get(lson, l, mid, L, R)) % p;
56         if(mid < R) sum = (sum + get(rson, mid + 1, r, L, R)) % p;
57         return sum % p;
58     }      

View Code

那麼怎麼更改資訊呢

(更改方式有點像倍增求LCA,珂以類比了解)

如果兩個元素不在同一條鍊上,

将鍊頂深的元素一直向上跳,并線上段樹中進行修改(提取)資訊的操作

如果兩個元素在同一條鍊上,直接進行修改(提取)資訊的操作

1 void change(int x, int y, int k){
 2         while (top[x] != top[y]){//如果兩個點的鍊頂不相同(感覺和LCA的處理有點類似 
 3             if(dep[top[x]] < dep[top[y]]) swap(x, y); 
 4             Seg::add(1, 1, n, dfn[top[x]], dfn[x], k);//先改變深度淺的 
 5             x = fath[top[x]];//向上跳到鍊頂的父親 
 6         }
 7         if(dfn[x] > dfn[y]) swap(x, y);//最後肯定是在一條鍊上 
 8         Seg::add(1, 1, n, dfn[x], dfn[y], k);
 9         return ;
10     }
11     int ask(int x, int y){
12         int ans = 0;
13         while(top[x] != top[y]){//道理和change函數類似 
14             if(dep[top[x]] < dep[top[y]]) swap(x, y);//先跳深度深度 
15             ans = (ans + Seg::get(1, 1, n, dfn[top[x]], dfn[x])) % p;
16             x = fath[top[x]];
17         }
18         if(dfn[x] > dfn[y]) swap(x, y);
19         ans = (ans + Seg::get(1, 1, n, dfn[x], dfn[y])) % p;
20         return ans % p;
21     }      

例題的AC代碼

namespace相當于把一部分函數進行組合包裝,珂以有效區分函數作用,并避免重變量名

調用的時候和std類似,用***::即可

1 /*
  2 Work by: Suzt_ilymics
  3 Knowledge: 樹鍊剖分 
  4 Time: O(nlog^2n)
  5 */
  6 #include<iostream>
  7 #include<cstdio>
  8 #define int long long
  9 using namespace std;
 10 const int MAXN = 1e5+5;
 11 int n, m, r, p;
 12 int a[MAXN], pre[MAXN], siz[MAXN], son[MAXN], dep[MAXN], fath[MAXN], top[MAXN], dfn[MAXN];
 13 
 14 int read(){//因一個逗号寫挂了的快讀 
 15     /*int s=0,w=1;
 16        char ch=getchar();
 17       while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
 18        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
 19        return s*w;
 20    */
 21     int s = 0, w = 1;
 22     char ch = getchar();
 23     //while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
 24     while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
 25     while(ch >= '0' && ch <= '9') 
 26     s = s * 10 + ch - '0', ch = getchar();
 27     return s * w;
 28 }
 29 
 30 namespace Seg{//線段樹闆中闆 
 31     #define lson i << 1
 32     #define rson i << 1 | 1
 33     struct Tree{//和,懶标記,長度 
 34         int sum, lazy, len;
 35     }tree[MAXN << 2];
 36     void push_up(int i){//上傳标記 
 37         tree[i].sum = (tree[lson].sum + tree[rson].sum) % p;
 38         return ;
 39     }
 40     void build(int i, int l , int r){//建樹 
 41         tree[i].lazy = 0, tree[i].len = r - l + 1;
 42         if(l == r) {    
 43             tree[i].sum = a[pre[l]] % p;
 44             return ;
 45         }
 46         int mid = l + r >> 1;
 47         build(lson, l, mid), build(rson, mid + 1, r);
 48         push_up(i);
 49         return ;
 50     }
 51     void pushdown(int i){//下傳懶标記 
 52         if(tree[i].lazy){
 53             tree[lson].lazy = (tree[lson].lazy + tree[i].lazy) % p;
 54             tree[rson].lazy = (tree[rson].lazy + tree[i].lazy) % p;
 55             tree[lson].sum = (tree[lson].sum + tree[i].lazy * tree[lson].len) % p;
 56             tree[rson].sum = (tree[rson].sum + tree[i].lazy * tree[rson].len) % p;
 57             tree[i].lazy = 0;
 58         }
 59         return ;
 60     }
 61     void add(int i, int l, int r, int L, int R, int k){
 62     //lr表示周遊到的區間,LR表示查詢到的區間 
 63         if(L <= l && r <= R) {
 64             tree[i].sum = (tree[i].sum + (k * tree[i].len) % p) % p;
 65             tree[i].lazy += k;
 66             return ;
 67         }
 68         //cout<<l<<" "<<R << " "<< r<< " "<<L<<"lkp"<<endl;
 69         if(l > R || r < L) return ;
 70         pushdown(i);
 71         int mid = (l + r) >> 1;
 72         if(L <= mid) add(lson, l, mid, L, R, k);
 73         if(R > mid) add(rson, mid + 1, r, L, R, k);
 74         push_up(i);
 75         return ;
 76     }
 77     int get(int i, int l, int r, int L, int R){
 78         int sum = 0;
 79         if(L <= l && r <= R) {
 80             return tree[i].sum % p;
 81         }
 82         if(l > R || r < L) return 0;
 83         pushdown(i);
 84         int mid = (l + r) >> 1;
 85         if(mid >= L) sum = (sum + get(lson, l, mid, L, R)) % p;
 86         if(mid < R) sum = (sum + get(rson, mid + 1, r, L, R)) % p;
 87         return sum % p;
 88     }
 89 }
 90 
 91 namespace Cut{
 92     int num_edge = 0, cnt = 0, head[MAXN << 1] = {0};
 93     struct edge{
 94         int nxt, to, from;
 95     }e[MAXN << 1];
 96     void add(int from, int to){ 
 97         e[++num_edge].to = to;
 98         e[num_edge].from = from;
 99         e[num_edge].nxt = head[from];
100         head[from] = num_edge;
101     }
102     void dfs(int x, int fa){//
103         siz[x] = 1, fath[x] = fa, dep[x] = dep[fa] + 1;//确定以x為根的子樹的大小,父親,深度 
104         //cout<<cnt<<"lzx"<<x<<" "<<fa<<endl;
105         for(int i = head[x]; i; i = e[i].nxt){//類似于lca初始化的周遊 
106             int v = e[i].to;
107             if(v == fa) continue;
108             dfs(v, x);
109             siz[x] += siz[v];//回溯的時候更新子樹大小 
110             if(siz[son[x]] < siz[v]) son[x] = v;//挑出重兒子 
111         } 
112     }
113     //引入重鍊這個概念會使分的鍊最少,複雜度更優秀 
114     void dfs2(int x, int tp){//分鍊,tp表示該鍊的頂端 
115         top[x] = tp, dfn[x] = ++cnt, pre[cnt] = x;//确定x節點的鍊的頂端是tp,x的dfs序及反dfs序 
116         if(son[x]) dfs2(son[x], tp);//為了使重鍊的dfn在一起,要先周遊重兒子 
117         for(int i = head[x]; i; i = e[i].nxt){
118             int v = e[i].to;
119             if(v == fath[x] || son[x] == v) continue;//如果一個點的fa等于自己或者下一個點是它的重兒子就跳過
120             //(如果是重兒子的話應該在以前就已經周遊了,是以還有防止在周遊一遍的作用 
121             dfs2(v, v);//新開一條鍊 
122         }
123     }
124     void change(int x, int y, int k){
125         while (top[x] != top[y]){//如果兩個點的鍊頂不相同(感覺和LCA的處理有點類似 
126             if(dep[top[x]] < dep[top[y]]) swap(x, y); 
127             Seg::add(1, 1, n, dfn[top[x]], dfn[x], k);//先改變深度深的 
128             x = fath[top[x]];//向上跳到鍊頂的父親 
129         }
130         if(dfn[x] > dfn[y]) swap(x, y);//最後肯定是在一條鍊上 
131         Seg::add(1, 1, n, dfn[x], dfn[y], k);
132         return ;
133     }
134     int ask(int x, int y){
135         int ans = 0;
136         while(top[x] != top[y]){//道理和change函數類似 
137             if(dep[top[x]] < dep[top[y]]) swap(x, y);//先跳深度深的 
138             ans = (ans + Seg::get(1, 1, n, dfn[top[x]], dfn[x])) % p;
139             x = fath[top[x]];
140         }
141         if(dfn[x] > dfn[y]) swap(x, y);
142         ans = (ans + Seg::get(1, 1, n, dfn[x], dfn[y])) % p;
143         return ans % p;
144     }
145 }
146 
147 signed main()
148 {
149     //輸入 
150     n = read(), m = read(), r = read(), p = read();
151     for(int i = 1; i <= n; ++i) a[i] = read();
152     for(int i = 1, u, v; i <= n - 1; ++i) {
153         u = read(), v = read();
154     //cout<<"bilibili";
155         Cut::add(u, v), Cut::add(v, u);
156     }
157     //for(int i = 1; i <= Cut::num_edge; ++i)    printf("%d %dwzd\n", Cut::e[i].from, Cut::e[i].to);
158     //初始化 
159     Cut::dfs(r,0), Cut::dfs2(r, r), Seg::build(1, 1, n);
160     //操作 
161     for(int i = 1, opt, x, y, k; i <= m; ++i){
162         opt = read();
163         if(opt == 1){
164             x = read(), y = read(), k = read();
165             Cut::change(x, y, k);
166         }
167         if(opt == 2){
168             x = read(), y = read();
169             printf("%lld\n", Cut::ask(x, y));
170         }
171         if(opt == 3){
172             x = read(), k = read();
173             //cout<<dfn[x]<<" "<<siz[x]<<"zsf"<<endl;
174             Seg::add(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, k);
175         }
176         if(opt == 4){
177             x = read(); 
178             printf("%lld\n", Seg::get(1, 1, n, dfn[x], dfn[x] + siz[x] - 1));
179         }
180         
181     }
182     return 0;
183 }      

 [ZJOI2008]樹的統計

一個維護最大值的例題

自己犯得**錯誤:

看清資料範圍,送出時把檢驗用的cout删掉,max在push_up的時候隻需要取它兩個兒子的最大值

樹鍊剖分 [ZJOI2008]樹的統計
樹鍊剖分 [ZJOI2008]樹的統計
1 /*
  2 Work by: Suzt_ilymics
  3 Knowledge: 樹鍊剖分 
  4 Time: O(nlog^2n)
  5 */
  6 #include<iostream>
  7 #include<cstdio>
  8 #include<string>
  9 #include<cstdio>
 10 #define int long long
 11 using namespace std;
 12 const int inf = -1000000000;
 13 const int MAXN = 3e4+5;
 14 int n, m;
 15 string s;
 16 int a[MAXN], pre[MAXN], siz[MAXN], son[MAXN], dep[MAXN], fath[MAXN], top[MAXN], dfn[MAXN];
 17 
 18 int read(){
 19     int s = 0, w = 1;
 20     char ch = getchar();
 21     while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
 22     while(ch >= '0' && ch <= '9') 
 23     s = s * 10 + ch - '0', ch = getchar();
 24     return s * w;
 25 }
 26 
 27 namespace Seg{
 28     #define lson i << 1
 29     #define rson i << 1 | 1
 30     struct Tree{
 31         int sum, lazy, len, max;
 32     }tree[MAXN << 2];
 33     void push_up(int i){
 34         tree[i].sum = tree[lson].sum + tree[rson].sum;
 35         tree[i].max = max(tree[lson].max, tree[rson].max);
 36         return ;
 37     }
 38     void build(int i, int l , int r){
 39         tree[i].lazy = 0, tree[i].len = r - l + 1;
 40         if(l == r) {    
 41             tree[i].sum = a[pre[l]];
 42             tree[i].max = a[pre[l]];
 43             return ;
 44         }
 45         int mid = (l + r) >> 1;
 46         build(lson, l, mid), build(rson, mid + 1, r);
 47         push_up(i);
 48         return ;
 49     }
 50     void add(int i, int l, int r, int L, int R, int k){
 51         if(L <= l && r <= R) {
 52             tree[i].sum = k;
 53             tree[i].max = k;
 54             return ;
 55         }
 56         if(l > R || r < L) return ;
 57         int mid = (l + r) >> 1;
 58         if(L <= mid) add(lson, l, mid, L, R, k);
 59         if(R > mid) add(rson, mid + 1, r, L, R, k);
 60         push_up(i);
 61         return ;
 62     }
 63     int get_sum(int i, int l, int r, int L, int R){
 64         int sum = 0;
 65         if(L <= l && r <= R) {
 66             return tree[i].sum;
 67         }
 68         if(l > R || r < L) return 0;
 69         int mid = (l + r) >> 1;
 70         if(mid >= L) sum += get_sum(lson, l, mid, L, R);
 71         if(mid < R) sum += get_sum(rson, mid + 1, r, L, R);
 72         return sum;
 73     }
 74     int get_max(int i, int l, int r, int L, int R){
 75         int maxm = inf;
 76         if(L <= l && r <= R){
 77             return tree[i].max;
 78         }
 79         if(l > R || r < L) return inf;
 80         int mid = (l + r) >> 1;
 81         if(mid >= L) maxm = max (maxm, get_max(lson, l, mid, L, R));
 82         if(mid < R) maxm = max (maxm, get_max(rson, mid + 1, r, L, R));
 83         return maxm;
 84     }
 85 }
 86 
 87 namespace Cut{
 88     int num_edge = 0, cnt = 0, head[MAXN << 1] = {0};
 89     struct edge{
 90         int nxt, to, from;
 91     }e[MAXN << 1];
 92     void add(int from, int to){ 
 93         e[++num_edge].to = to;
 94         e[num_edge].from = from;
 95         e[num_edge].nxt = head[from];
 96         head[from] = num_edge;
 97     }
 98     void dfs(int x, int fa){//
 99         siz[x] = 1, fath[x] = fa, dep[x] = dep[fa] + 1;
100         for(int i = head[x]; i; i = e[i].nxt){
101             int v = e[i].to;
102             if(v == fa) continue;
103             dfs(v, x);
104             siz[x] += siz[v];
105             if(siz[son[x]] < siz[v]) son[x] = v;
106         } 
107     }
108     void dfs2(int x, int tp){
109         top[x] = tp, dfn[x] = ++cnt, pre[cnt] = x;
110         if(son[x]) dfs2(son[x], tp);
111         for(int i = head[x]; i; i = e[i].nxt){
112             int v = e[i].to;
113             if(v == fath[x] || son[x] == v) continue;
114             dfs2(v, v);
115         }
116     }
117     int ask_sum(int x, int y){
118         int ans = 0;
119         while(top[x] != top[y]){
120             if(dep[top[x]] < dep[top[y]]) swap(x, y);
121             ans += Seg::get_sum(1, 1, n, dfn[top[x]], dfn[x]);
122             x = fath[top[x]];
123         }
124         if(dfn[x] > dfn[y]) swap(x, y);
125         ans += Seg::get_sum(1, 1, n, dfn[x], dfn[y]);
126         return ans;
127     }
128     int ask_max(int x, int y){
129         int maxm = inf;
130         while(top[x] != top[y]){
131             if(dep[top[x]] < dep[top[y]]) swap(x, y);
132             maxm = max (maxm, Seg::get_max(1, 1, n, dfn[top[x]], dfn[x]));
133             x = fath[top[x]];
134         }
135         if(dfn[x] > dfn[y]) swap(x, y);
136         maxm = max (maxm, Seg::get_max(1, 1, n, dfn[x], dfn[y]));
137         return maxm;
138     }
139 }
140 
141 signed main()
142 {
143     n = read();
144     for(int i = 1, u, v; i <= n - 1; ++i) {
145         u = read(), v = read();
146         Cut::add(u, v), Cut::add(v, u);
147     }
148     for(int i = 1; i <= n; ++i) a[i] = read();
149 
150     Cut::dfs(1,0), Cut::dfs2(1, 1), Seg::build(1, 1, n);
151     
152     m = read();
153     for(int i = 1, x, y, k; i <= m; ++i){
154         cin>>s;
155         if(s[1] == 'M'){//Qmax
156             x = read(), y = read();
157             if(x > y) swap(x, y);
158             printf("%lld\n", Cut::ask_max(x, y));
159         }
160         if(s[1] == 'H'){//Change
161             x = read(), k = read();
162             Seg::add(1, 1, n, dfn[x], dfn[x], k);
163         }
164         if(s[1] == 'S'){//Qsum
165             x = read(), y = read();
166             if(x > y) swap(x, y);
167             printf("%lld\n", Cut::ask_sum(x, y));
168         }
169     }
170     return 0;
171 }      

AC代碼