天天看點

ST(RMQ)算法(線上)求LCA

在此之前,我寫過另一篇部落格,是​​倍增(線上)求LCA​​。有興趣的同學可以去看一看。概念以及各種暴力就不在這裡說了,那篇部落格已經有介紹了。

ST(RMQ)算法線上求LCA

這個算法的思想,就是将LCA問題轉化成RMQ問題。

怎麼将LCA轉成RMQ?

我們首先用dfsO(N)周遊一遍。比如下圖:

ST(RMQ)算法(線上)求LCA

得到一個dfs序(從兒子回到父親也要算一遍):

1->2->4->7->4->8->4->2->5->2->6->9->6->10->6->2->1->3->1

可以簡單地了解成這樣:你一開始在根節點,一直向下走,發現盡頭就倒退,向另一個方向走。最後你還會回到根節點。你周遊這個樹的順序就是一個這樣的dfs序。

有沒有發現什麼規律?

設r[x]表示x在這個dfs序當中第一次出現的位置,deep[x]表示x的深度。

那麼可以發現,如果要求x和y的LCA,r[x]~r[y]這一段區間内一定有它們的LCA,而且還是區間中深度最小的那個。

這是為什麼?

隻要你懂dfs,簡單思考一下就能明白。到達x點後,再到y點,必須經過過它們的LCA,因為這是一棵樹,兩個點之間有且隻有一條路徑。

為什麼它在區間中深度最小?

因為dfs的原因,周遊以LCA(x,y)為根的子樹時,不周遊完所有以LCA(x,y)為根的點是不會回去的。然而x、y一定在以LCA(x,y)為根的子樹當中,是以這也是成立的。

具體怎麼做?

首先,用dfsO(n)求出dfs序、r數組和deep數組。

然後,套一個純的ST(RMQ)。設f[i][j]表示j~j+2^i-1的點當中,deep值最小的是哪個。

預處理做完了,接下來就可以線上O(1)回答詢問了。

注意事項

這個dfs序長度是2n-1的,原因:每個點經過的次數=兒子個數+1。那麼所有點的兒子個數一共有n-1,因為沒有根節點。所有是2n-1的。

線上O(1)回答的時候,有的人求對數使用log(x)/log(2)的形式。實際上沒必要,因為C++中有個東西叫log2,直接用就好。

代碼實作

​​例題 P3379【模闆】最近公共祖先(LCA)​​

#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
int n,_n,m,s;//_n是用來放元素進dfs序裡,最終_n=2n-1
struct EDGE
{
    int to;
    EDGE* las;
} e[1000001];//前向星存邊
EDGE* last[500001];
int sx[1000001];//順序,為dfs序
int f[21][1000001];//用于ST算法
int deep[500001];//深度
int r[500001];//第一次出現的位置
void dfs(int,int,int);
int min(int a,int b){return deep[a]<deep[b]?a:b;}
int query(int,int);
int main()
{
    scanf("%d%d%d",&n,&m,&s);
    int i,j=0,x,y;
    for (i=1;i<n;++i)
    {
        scanf("%d%d",&x,&y);
        e[++j]={y,last[x]};
        last[x]=e+j;
        e[++j]={x,last[y]};
        last[y]=e+j;
    }
    dfs(s,0,0);
    //以下是ST算法
    for (i=1;i<=_n;++i)
        f[0][i]=sx[i];
    int ni=int(log2(_n)),nj,tmp;
    for (i=1;i<=ni;++i)
    {
        nj=_n+1-(1<<i);
        tmp=1<<i-1;
        for (j=1;j<=nj;++j)
            f[i][j]=min(f[i-1][j],f[i-1][j+tmp]);
    }
    //以下是詢問,對于每次詢問,可以O(1)回答
    while (m--)
    {
        scanf("%d%d",&x,&y);
        printf("%d\n",query(r[x],r[y]));
    }
}
void dfs(int t,int fa,int de)
{
    sx[++_n]=t;
    r[t]=_n;
    deep[t]=de;
    EDGE* ei;
    for (ei=last[t];ei;ei=ei->las)
        if (ei->to!=fa)
        {
            dfs(ei->to,t,de+1);
            sx[++_n]=t;
        }
}
int query(int l,int r)
{
    if (l>r)
    {
        //交換
        l^=r;
        r^=l;
        l^=r;
    }
    int k=int(log2(r-l+1));
    return min(f[k][l],f[k][r-(1<<k)+1]);
}