算法詳解
很長時間内都沒有能夠很了解KMP算法的精髓,尤其是很多書上包括《算法導論》沒有把next函數(亦或 π函數)講解的很透徹。
今天去看了matrix67大牛部落格中關于kmp部分的講解,有點兒醍醐灌頂的感覺,當然也隻是了解了一點淺層次罷了。我嘗試着用自己的語言說一下自己的了解,順便鍛煉一下自己渣一般的邏輯組織能力。。。。。。
下面開始正題吧~~~
我們知道單模字元串比對基本就是三種方法:
一、樸素枚舉。最壞時間複雜度O(mn)。
二、Rabin-Karp。需要O(m)的預處理。雖然最壞時間複雜度也是O(mn),但出現最壞情況的幾率比樸素法小很多,是以這種方法實際應用還是比較廣泛的。
三、Knuth-Morris-Pratt。即KMP算法。O(m)的預處理時間,O(n)的比對時間,非常高效。
先從樸素枚舉法說起吧。枚舉法就是從字元串的第一位開始,把每一位都作為開頭來與模式串逐位比對一次,如果比對失敗則開始下一位做開頭試,直到找到比對為止。
i = 1 2 3 4 5 6 ……
A = a a a b a a …
B = a a a b a c b
j = 1 2 3 4 5 6 7
我們知道,要優化一個算法就要知道它哪裡做了多餘的事情。從上面情況來看,在i == j == 6時目前比對失敗,樸素法會把模式串試着與i == 2做開頭比對,但是從我們前面已經得到的資訊已經可以得出,i == 2做開頭是不可能比對成功的,是以樸素枚舉在這裡就做了無用功(也可以說 i 指針做了無用的回溯)。(而且我們要發現在某種壞的情況下它做了相當多的無用功!比如這種情況:000000000000000000000001,模式串00000000001)
那麼我們應該想到一個優化了。抽象一下就是,當我們前面已經比對A[i-j..i-1] == B[1..j],而遇到A[i] != B[j+1]時,我們可以快速(O(1)時間内)找到一個新的j,使得新的A[i-j..i-1] == B[1..j],并且A[i] == B[j+1],這樣我們的 i 指針就可以不動,緊接着向下比對就好了。為了能夠快速找到這樣的j,我們使用一個輔助的next[]數組(π數組),next[j]記錄當B[j+1] != A[i]時,新的j的位置。
這樣我們的KMP算法的大緻代碼就是:
[cpp]
string T,P;
bool KMP()
{
bool flag = false;
int n = T.length();
int m = P.length();
int j = -1;
for (int i = 0; i < n; i ++)
{
while(j > -1 && B[j+1] != A[i]) j = next[j];
if (B[j+1] == A[i]) j ++;
if (j == m - 1)
{
flag = true; //比對成功
break;
//j = next[j];
}
}
return flag;
}[/cpp]
前面說過,next[]的值是由前面比對時已經得到的資訊得出來的,那麼我們發現next[]的值隻與模式串有關。(因為前面已經比對的地方A、B串是一樣的,是以隻由一個B串是可以得到資訊的),這樣我們就可以通過對模式串(B串)進行預處理來求出next[]。
那麼怎麼樣根據模式串求出next[]呢?舉個例子來說明:
A = a b a b a b …
B = a b a b a c b
當i == 6 ,j == 5時A[i] != B[j+1],那麼這時我們應該調整j了,手動調整一下可以發現新的j == 3,即next[5] = 3。
A = a b a b a b …
B = a b a b a c b
j = 1 2 3 4 5 6 7
發現沒有,新的j'要滿足情況,則需要滿足B[1..j‘]與B[j-j'+1..j]比對。
那麼我們可以得出一個樸素的O(m
2)的求next[]的方法了。對于每一個j,枚舉j',找出符合條件的最大的j'就行了。但是明顯這個算法是不夠優的(如果m很大怎麼辦?)
實際上上面一句話已經啟示我們到底該怎麼求next[]了,“滿足B[1..j']與B[j-j'+1..j]比對”,這就是一個模式串自身比對的過程啊!
我們看個求next[](這裡是π[])的圖:

我們看看next[j]可不可以由next[1..j]的資訊得出。假設目前要求next[6]。由j == next[5] == 3可得到,B[1..3] == B[3..5],而此時B[4] != B[6],是以此時比對失敗,需要回溯重新比對了。但我們不想回溯啊!前面說了回溯是做無用功啊!那麼我們是不是要找一個新的j'使得B[1..j'] == B[5-j'+1..5]。那麼怎麼找呢?由B[1..3] == B[3..5]可以知道我們可以這麼求:B[1..j'] == B[3-j'+1..3](想想為什麼可以這樣,可以看看上面那個圖),那不就是next[j]了麼~~~
寫一下代碼來加深了解:
void GetNext()
next[0] = -1;
int m = p.length();
for (int i = 1; i < m; i ++)
while(j > -1 && p[j+1] != p[i]) j = next[j];
if (p[j+1] == p[i]) j++;
next[i] = j;
}
[/cpp]
這樣我們的KMP算法就描述完了~~~
專題訓練:
♦POJ 3461 Oulipo(入門題)
光看樣例就一臉純模闆題的樣兒。。。。。。
#include
#include
#include
#include
using namespace std;
const int N = 1000100;
string A, B;
int pi[N];
void getp()
{
memset(pi,0,sizeof(pi));
int m = B.length();
pi[0] = -1;
int j = -1;
for (int i = 1;i < m; i ++)
{
while(j > -1 &&B[j+1] != B[i]) j = pi[j];
if (B[j+1] == B[i]) j++;
pi[i] = j;
}
}
int kmp()
{
int res = 0;
getp();
int n = A.length();
int m = B.length();
int j = -1;
for (int i = 0; i < n; i ++)
{
while(j > -1 && A[i] !=B[j+1]) j = pi[j];
if (A[i] == B[j+1]) j++;
if (j == m - 1)
{
res ++;
j = pi[j];
}
}
return res;
}
int main()
{
int tt;
scanf("%d",&tt);
while (tt--)
{
cin>>B;
cin>>A;
cout<
♦POJ 1226 Substrings (最長公共子串。KMP + 二分)
資料範圍很小,随便搞吧?=。=。其實就是二分答案長度,枚舉出該長度的每一個串,然後用KMP驗證。總複雜度O(n*len2*log(len))吧,可以接受的。
但是我怎麼可以這麼弱?各種小錯誤啊調試了2個小時了吧靠靠靠靠這麼個水題。。。。。。
#include
#include
#include
#include
using namespace std;
char s[110][110];
char p[110];
int pi[110];
void getpi()
{
int m = strlen(p);
int j = -1;
pi[0] = -1;
for (int i = 1; i < m; i ++)
{
while(j > -1 && p[j+1] != p[i]) j = pi[j];
if (p[j+1] == p[i]) j++;
pi[i] = j;
}
}
bool kmp(int x)
{
getpi();
int n = strlen(s[x]);
int m = strlen(p);
int j = -1;
for (int i = 0; i < n; i ++)
{
while(j > -1 && s[x][i] != p[j+1]) j = pi[j];
if (s[x][i] == p[j+1]) j++;
if (j == m - 1)
{
return true;
}
}
return false;
}
int BS(int n)
{
if (n == 1) return strlen(s[0]);
int h = 0, t = strlen(s[0]) + 1;
while(h <= t)
{
memset(p,0,sizeof(p));
int fg = 1;
int mid = (h + t) >> 1;
for (int i = 0; i < strlen(s[0]) - mid + 1; i ++)
{
if (fg == n) break;
else fg = 1;
for (int k = 1; k < n; k ++)
{
for (int j = 0; j < mid; j ++)
p[j] = s[0][i + j];
if (kmp(k)) fg++;
else
{
for (int j = 0; j < mid; j ++)
p[mid - j - 1] = s[0][i + j];
if (kmp(k)) fg++;
else break;
}
}
}
if (fg == n)
h = mid + 1;
else t = mid - 1;
}
return h - 1;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
memset(p,0,sizeof(p));
int n;
scanf("%d",&n);
for (int i = 0; i < n; i ++)
scanf("%s",s[i]);
printf("%d\n",BS(n));
}
return 0;
}
♦POJ 2406 Power String (最小周期(重複)子串。加深對next函數的了解)
首先要了解next數組表示的是字元串字首的“對稱”程度。
然後記住這個結論:對于一個字元串s,如果len是(len - next[len])的倍數,那麼len - next[len]就是s的最小周期子串。
證明一下(解釋不太清楚>.<……):如果len是len-next[len]的倍數,假設m = len-next[len] ,那麼str[1-m] = str[m-2*m],……,以此類推下去,m肯定是str的最小重複單元的長度。假如len不是len-next[len]的倍數, 如果字首和字尾重疊,那麼最小重複單元肯定str本身了,如果字首和字尾不重疊,那麼str[m-2*m] != str[len-m,len],是以str[1-m] != str[m-2*m] ,最終肯定可以推理出最小重複單元是str本身,因為隻要不斷遞增m證明即可。
還是自己在紙上好好推演一下比較好。
#include
#include
#include
using namespace std;
const int N = 1000010;
int pi[N];
char p[N];
int getpi()
{
int m = strlen(p);
pi[0] = -1;
int j = -1;
for (int i = 1; i < m; i ++)
{
while(j > -1 && p[j+1] != p[i]) j = pi[j];
if (p[j+1] == p[i]) j++;
pi[i] = j;
}
int x = m - 1 - pi[m - 1];
if (m % x == 0)
return x;
else return m;
}
int main()
{
while(scanf("%s",p)!=EOF)
{
if (p[0] == '.') break;
int l = strlen(p);
printf("%d\n",l / getpi());
}
return 0;
}
♦HDU 3336 Count the String (KMP+DP)
問題抽象:求所有字首在字元串中出現的次數。
暴力枚舉會達到O(n3)是不行的,枚舉字首然後KMP也會達到O(n2)。當然對于字首的情況我們應該利用好“字首數組”------next數組。實際上現在很多題也都不是考KMP而是考next數組的靈活運用。(KMP裸模闆題有什麼好考的。。。)
我們把問題分成幾個子問題來看,用DP解決:f[i]表示以第i位為結尾的字元串比對數。則sum = ∑f[i] 。
怎麼利用next數組呢?我們知道next[i] = j表示串B[1...j] == B[i-j+1...i],那麼一部分串(B[i-j+1...i]的字尾串)與字首的比對是可以通過j來求出來的,因為相等關系,是以這部分f[i]等價于f[next[i]]。這隻是一部分以i結尾的啊,那麼以[1...i-j]某處開頭、以 i 結尾的串有沒有可能呢?答案是不可能的,如果與字首比對成功那麼next[i]就不是j了(想想是不是~),當然要加上他本身(B[1..i])是整個串字首的情況,是以得出f[i] = f[next[i]] + 1。然後再算出sum就行了~~~
#include
#include
#include
#include
using namespace std;
const int N = 200010;
string s;
int f[N],pi[N];
void getpi()
{
int m = s.length();
pi[0] = -1;
int j = -1;
for (int i = 1; i < m; i ++)
{
while(j > -1 && s[j+1] != s[i]) j = pi[j];
if (s[j+1] == s[i]) j++;
pi[i] = j;
}
}
int ff()
{
getpi();
int sum = 1;
int m = s.length();
f[0] = 1;
for (int i = 1; i < m; i ++)
{
f[i] = f[pi[i]] + 1;
sum += f[i];
sum %= 10007;
}
return sum;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n;
scanf("%d",&n);
cin>>s;
printf("%d\n",ff()%10007);
}
return 0;
}
(未完待續。。。)()
舉杯獨醉,飲罷飛雪,茫然又一年歲。 ------AbandonZHANG