天天看點

【POJ】3415 Common Substrings 【字尾數組+單調棧】

傳送門:【POJ】3415 Common Substrings

題目分析:

題目要求的實質是計算 A 的所有字尾和 B 的所有字尾之間的最長公共字首的長度,然後把最長公共字首長度不小于 k 的部分全部加起來(即A的字尾ai和B的字尾bj的最長公共字首為x且x大于等于k,則答案累加上x-k+1)。

由于枚舉所有ai和bj的話是O(n^2)的複雜度,是以我們需要一些技巧來優化。

比較好的方法是單調棧。

首先将兩個串連接配接到一起,中間用'$'隔開。然後構造字尾數組。

然後我們周遊height[i],求A的每一個字尾和在其之前的所有B的字尾對答案的貢獻。由LCP的性質得字尾i和字尾j的最長公共字首長度為i+1~j中height的最小值,如果i+1~j上所有的B的字尾和字尾A的LCP相同,那麼我們将他們縮成一塊,記錄個數以及高度即可。那麼這樣我們就得到了一塊塊的B字尾。使其為升序,保證單調,那麼新加入一個高度時我們就向前縮塊直到該高度比之前的塊的高度大位置,同時所有的B字尾的個數都累加到這個塊中,且如果sa[i - 1]為B字尾時我們額外的+1。然後此時如果sa[i]為A字尾,那麼在這個A字尾之前的所有B字尾對其的貢獻為每一塊的個數*(每一塊的高度-k+1)的和,為了快速的到答案,我們儲存字首和。

最後我們倒過來求一次B的每一個字尾和在其之前的所有A的字尾對答案的貢獻。

這樣所有的情況便都考慮到了。

這個思想甚妙啊!

代碼如下:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;

typedef long long LL ;

#define rep( i , a , b ) for ( int i = ( a ) ; i <  ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define clr( a , x ) memset ( a , x , sizeof a )
#define cpy( a , x ) memcpy ( a , x , sizeof a )

const int MAXN = 200005 ;

struct Node {
	int height , cnt ;
	LL sum ;
} ;

Node S[MAXN] ;
char s[MAXN] ;
int t1[MAXN] , t2[MAXN] , c[MAXN] , xy[MAXN] ;
int sa[MAXN] , rank[MAXN] , height[MAXN] ;
int k ;

int cmp ( int *r , int a , int b , int d ) {
	return r[a] == r[b] && r[a + d] == r[b + d] ;
}

void getHeight ( int n , int k = 0 ) {
	For ( i , 0 , n ) rank[sa[i]] = i ;
	rep ( i , 0 , n ) {
		if ( k ) -- k ;
		int j = sa[rank[i] - 1] ;
		while ( s[i + k] == s[j + k] ) ++ k ;
		height[rank[i]] = k ;
	}
}

void da ( int n , int m = 128 ) {
	int i , d , p , *x = t1 , *y = t2 , *t ;
	for ( i = 0 ; i < m ; ++ i ) c[i] = 0 ;
	for ( i = 0 ; i < n ; ++ i ) ++ c[x[i] = s[i]] ;
	for ( i = 1 ; i < m ; ++ i ) c[i] += c[i - 1] ;
	for ( i = n - 1 ; i >= 0 ; -- i ) sa[-- c[x[i]]] = i ;
	for ( d = 1 , p = 0 ; p < n ; d <<= 1 , m = p ) {
		for ( i = n - d , p = 0 ; i < n ; ++ i ) y[p ++] = i ;
		for ( i = 0 ; i < n ; ++ i ) if ( sa[i] >= d ) y[p ++] = sa[i] - d ;
		for ( i = 0 ; i < m ; ++ i ) c[i] = 0 ;
		for ( i = 0 ; i < n ; ++ i ) ++ c[xy[i] = x[y[i]]] ;
		for ( i = 1 ; i < m ; ++ i ) c[i] += c[i - 1] ;
		for ( i = n - 1 ; i >= 0 ; -- i ) sa[-- c[xy[i]]] = y[i] ;
		for ( t = x , x = y , y = t , p = 0 , x[sa[0]] = p ++ , i = 1 ; i < n ; ++ i ) {
			x[sa[i]] = cmp ( y , sa[i - 1] , sa[i] , d ) ? p - 1 : p ++ ;
		}
	}
	getHeight ( n - 1 ) ;
}

void solve () {
	int n1 , n2 , n ;
	LL ans = 0 ;
	scanf ( "%s" , s ) ;
	n1 = strlen ( s ) ;
	s[n1] = '$' ;
	scanf ( "%s" , s + n1 + 1 ) ;
	n2 = strlen ( s + n1 + 1 ) ;
	n = n1 + 1 + n2 ;
	da ( n + 1 ) ;
	int top = 0 ;
	For ( i , 2 , n ) {
		if ( height[i] < k ) top = 0 ;
		else {
			int cnt = 0 ;
			while ( top && height[i] <= S[top - 1].height ) cnt += S[-- top].cnt ;
			S[top].cnt = cnt + ( sa[i - 1] > n1 ) ;
			S[top].height = height[i] ;
			S[top].sum = top ? S[top - 1].sum : 0 ;
			S[top].sum += ( LL ) ( S[top].height - k + 1 ) * S[top].cnt ;
			if ( sa[i] < n1 ) ans += S[top].sum ;
			top ++ ;
		}
	}
	top = 0 ;
	For ( i , 2 , n ) {
		if ( height[i] < k ) top = 0 ;
		else {
			int cnt = 0 ;
			while ( top && height[i] <= S[top - 1].height ) cnt += S[-- top].cnt ;
			S[top].cnt = cnt + ( sa[i - 1] < n1 ) ;
			S[top].height = height[i] ;
			S[top].sum = top ? S[top - 1].sum : 0 ;
			S[top].sum += ( LL ) ( S[top].height - k + 1 ) * S[top].cnt ;
			if ( sa[i] > n1 ) ans += S[top].sum ;
			top ++ ;
		}
	}
	printf ( "%lld\n" , ans ) ;
}

int main () {
	while ( ~scanf ( "%d" , &k ) && k ) solve () ;
	return 0 ;
}