天天看點

【樹上倍增】最近公共祖先(LCA)

模版題:https://www.luogu.com.cn/problem/P3379

視訊講解1

視訊講解2

某佬題解

思路:

設目前節點為 x x x, f a [ x ] [ i ] fa[x][i] fa[x][i] 代表 x x x 的第 2 i 2^i 2i 個祖先節點,

顯然 f a [ x ] [ 0 ] fa[x][0] fa[x][0] 代表 x x x 的直接父節點,

而 f a [ x ] [ i ] fa[x][i] fa[x][i] = f a [ f a [ x ] [ i − 1 ] ] [ i − 1 ] fa[fa[x][i-1]][i-1] fa[fa[x][i−1]][i−1] , x x x 的 第 2 i 2^i 2i 個祖先節點是 x x x 的第 2 i − 1 2^{i-1} 2i−1 個祖先節點的第 2 i − 1 2^{i-1} 2i−1 個祖先節點( 2 i 2^i 2i = 2 i − 1 + 2 i − 1 2^{i-1} + 2^{i-1} 2i−1+2i−1)

d e p [ x ] dep[x] dep[x] 數組用來記錄 x x x 的深度

f a fa fa 數組需要用 d f s dfs dfs 來求解

void dfs(int u, int pre){

    dep[u] = dep[pre] + 1;
    fa[u][0] = pre;

    for (int i = 1; i <= (lg[dep[u]]); i++) {
        // u 的2^i個祖先是u的2^(i-1)個祖先的2^(i-1)個祖先
        // 2^i = 2^(i-1) + 2^(i-1)
        fa[u][i] = fa[fa[u][i-1]][i-1];
    }

    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if(j == pre) continue;
        dfs(j, u);
    }
} 
           

預處理

l o g 2 i log_2^i log2i​ 的值

for (int i = 0; i < N; ++i) {
    // lg 8 = lg7 + (1 << 3 == 8)
    lg[i] = lg[i-1] + (1 << (lg[i-1] + 1) == i);
}
           

求解 a a a 和 b b b 的 l c a lca lca 時,需要先 樹上倍增 讓 a a a 和 b b b 處于同一層,然後如果 a a a 和 b b b 是祖先關系的話,就傳回此時的 a a a 和 b b b 中任意的一個;

否則, a a a 和 b b b 同時向上倍增,找到 l c a lca lca 下面一層的節點,此時隻需傳回 f a [ a ] [ 0 ] fa[a][0] fa[a][0] 和 f a [ b ] [ 0 ] fa[b][0] fa[b][0] (即 a a a和 b b b任意一個的直接父節點)即可!

求證下面代碼為何一定能找出 a a a 和 b b b 的 l c a lca lca 的直接子節點 (即 l c a lca lca下面兩個節點)

首先此時 a a a 和 b b b 已經處于同一層;

設 d d d = d e p [ a ] dep[a] dep[a];

可以确定 a a a 和 b b b 的 l c a lca lca 的深度(設為 m d md md) 一定在 1 1 1 ~ 2 l o g 2 d = d 2^{log_2^d} = d 2log2d​=d 之間,

l c a lca lca 是 a a a 和 b b b 的第 d − m d d - md d−md 個父節點

設 k k k = l o g 2 d log_2^d log2d​ 即 l g [ d ] lg[d] lg[d],顯然 d d d <= 2 k 2^k 2k,

而 1 1 1 ~ d d d 任何一個數都可以由 2 0 2^0 20, 2 1 2^1 21, 2 2 2^2 22,…, 2 k 2^k 2k 其中的 c c c 個相加求得 !!!(相當于 k k k位二進制湊得 m d md md ) 顯然這是肯定的。

然而需要湊得 m d md md - 1 層(即 a a a 的第 d − m d − 1 d - md - 1 d−md−1 個 父節點) 的話,需要将 i = k i = k i=k 從大到小;如果 f a [ a ] [ i ] = = f a [ b ] [ i ] fa[a][i] == fa[b][i] fa[a][i]==fa[b][i],說明此時跳的話,會跳過頭

d − m d − 1 d - md - 1 d−md−1 一定小于目前的 2 i 2^i 2i,

是以 d − m d − 1 d - md - 1 d−md−1 一定能由 { 0 ∣ 1 } \{0 | 1\} {0∣1} * 2 i − 1 2^{i-1} 2i−1 + { 0 ∣ 1 } \{0 | 1\} {0∣1} * 2 i − 2 2^{i-2} 2i−2 + … + { 0 ∣ 1 } \{0 | 1\} {0∣1} * 2 1 2^{1} 21 + { 0 ∣ 1 } \{0 | 1\} {0∣1} * 2 0 2^0 20 湊得

for (int i = lg[dep[a]]; i >= 0; i--) {

    if(fa[a][i] != fa[b][i]){
        a = fa[a][i], b = fa[b][i];
    }
}
           

AC Code

#include <iostream>
#include <cstring>

using namespace std;

const int N = 5e5 + 10;

int e[N << 1], ne[N << 1], h[N], cnt;

void add(int a, int b){
    e[cnt] = b, ne[cnt] = h[a], h[a] = cnt++;
}

int n, m, s;
// 用于預處理log(i) 的值
int lg[N];
int dep[N];  // 記錄某節點的深度
int fa[N][20];  // 記錄某節點的祖先節點; fa[u][0]是父節點

// dfs預處理出來fa數組
void dfs(int u, int pre){

    dep[u] = dep[pre] + 1;
    fa[u][0] = pre;

    for (int i = 1; i <= (lg[dep[u]]); i++) {
        // u的2^i個祖先是u的2^(i-1)個祖先的2^(i-1)個祖先
        // 2^i = 2^(i-1) + 2^(i-1)
        fa[u][i] = fa[fa[u][i-1]][i-1];
    }

    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if(j == pre) continue;
        dfs(j, u);
    }
}

int lca(int a, int b){
    // 設a的深度較深
    if(dep[a] < dep[b]) swap(a, b);

    // 讓a,b到同一深度
    for (int i = lg[dep[a]]; i >= 0; i--) {
        // 利用倍增讓a向上走
        // 如果a的2^i個父節點的深度大于等于b,則a向上走
        if(dep[fa[a][i]] >= dep[b]) a = fa[a][i];
    }
    if(a == b) return a;

    // 找出a和b的lca的直接子節點(即lca下面兩個節點)
    for (int i = lg[dep[a]]; i >= 0; i--) {

        if(fa[a][i] != fa[b][i]){
            a = fa[a][i], b = fa[b][i];
        }
    }

    return fa[a][0]; // fa[b][0]
}

int main(){

    memset(h, -1, sizeof h);

    lg[1] = 0;

    for (int i = 0; i < N; ++i) {
        // lg 8 = lg7 + (1 << 3 == 8)
        lg[i] = lg[i-1] + (1 << (lg[i-1] + 1) == i);
    }

    scanf("%d%d%d", &n, &m, &s);

    int x, y;
    for (int i = 0; i < n - 1; ++i) {
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }

    dep[s] = 0;
    dfs(s, s);

    int a, b;
    for (int i = 0; i < m; ++i) {
        scanf("%d%d", &a, &b);
//        printf("%d  %d\n", dep[a], dep[b]);
        printf("%d\n", lca(a, b));
    }

    return 0;
}
           

樹上倍增應用題

#include <iostream>
#include <cstring>

using namespace std;

int n, cnt, e[1000], ne[1000], h[1000];
int lg[1000], dep[1000], wd[1000], fa[1000][30];

void add(int a, int b){

    e[cnt] = b, ne[cnt] = h[a], h[a] = cnt++;
}

void dfs(int u, int pre){

    dep[u] = dep[pre] + 1;

    fa[u][0] = pre;

    for (int i = 1; i <= lg[dep[u]]; ++i) {
        fa[u][i] = fa[fa[u][i-1]][i-1];
    }

    for (int i = h[u]; i != -1; i = ne[i]) {

        int j = e[i];
        if(j == pre) continue;
        dfs(j, u);
    }

    wd[dep[u]]++;
}


void lca(int a, int b){

    int x = a, y = b;
    
    // a 為深度深的
    if(dep[a] < dep[b]) swap(a, b);

    // a 上升到和 b 一樣高
    for (int i = lg[a]; i >= 0; i--) {

        if(dep[fa[a][i]] >= dep[b]) a = fa[a][i];
    }

    for (int i = lg[a]; i >= 0; i--) {
        if(fa[a][i] != fa[b][i]){
            a = fa[a][i], b = fa[b][i];
        }
    }

    // lca
    a = fa[a][0];

    cout << (dep[x] - dep[a]) * 2 + (dep[y] - dep[a]) << endl;
}

int main(){

    memset(h, -1, sizeof h);

    lg[1] = 0;
    for (int i = 2; i < 1000; ++i) {
        lg[i] = lg[i - 1] + (1 << (lg[i-1] + 1) == i);
    }

    cin >> n;

    int x, y;

    for (int i = 0; i < n - 1; ++i) {

        cin >> x >> y;
        add(x, y);
        add(y, x);
    }

    dfs(1, 1);

    int d = 0, w = 0;

    for (int i = 0; i < 1000; ++i) {
        d = max(d, dep[i]);
        w = max(w, wd[i]);
    }

    cout << d << endl;
    cout << w << endl;

    int u, v;

    cin >> u >> v;

    lca(u, v);

    return 0;
}
           

繼續閱讀