天天看點

Palindromic Tree

一個很厲害的算法,具體怎麼實作的不是很清楚,本着不求甚解的态度,知道怎麼用就好了。

next[][]:類似于字典樹,指向目前字元串在兩段同時加上一個字元

fail[] fail指針:類似于AC自動機,傳回失配後與目前i結尾的最長回文串本質上不同的最長回文字尾

cnt[]: 在最後統計後它可以表示形如以i為結尾的回文串中最長的那個串個數

num[]: 表示以i結尾的回文串的種類數

len[]: 表示以i為結尾的最長回文串長度

s[]: 存放添加的字元

last: 表示上一個添加的字元的位置

n: 表示字元數組的第幾位

p: 表示樹中節點的指針

本質不同的回文字元串:p-2 (減去兩個根節點)

統計所有回文串的個數: ∑ 2 p − 1 n u m [ i ] \sum^{p-1}_{2} num[i] ∑2p−1​num[i]

2019牛客暑期多校訓練營(第六場)C Palindrome Mouse

對于每一個節點p來說,他的nextt節點都是以p節點為子串的回文串,他的fail節點都是p的回文子串,那麼假設對于p來說,如果它向下有numn[p],向上有numc[p]個,那麼這個節點的貢獻就是numn[p]*numc[p];

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1000 * 100 + 10;
const int MAXN = 100005;
char s[MAXN];

struct Palindromic_Tree{
	int nextt[N][26];
	int cnt[N];
	int fail[N];
	int len[N];
	int num[N];
	int s[N];
	int n;
	int p;
	int last;

	int newnode(int lent) {
		for (int i = 0; i < 26; i++)nextt[p][i] = 0;
		cnt[p] = 0;
		len[p] = lent;
		num[p] = 0;
		return p++;
	}

	void init() {
		p = 0; 
		newnode(0); newnode(-1);
		last = 0;
		n = 0;
		s[n] = -1;
		fail[0] = 1;
	}

	int get_fail(int x) {
		while (s[n - len[x] - 1] != s[n])x = fail[x];
		return x;
	}

	void add(int c) {
		c -= 'a';
		s[++n] = c;
		int cur = get_fail(last);
		if (!nextt[cur][c]) {
			int now = newnode(len[cur]+2);
			fail[now] = nextt[get_fail(fail[cur])][c];
			nextt[cur][c] = now;
			num[now] = num[fail[now]] + 1;
		}
		last = nextt[cur][c];
		cnt[last]++;
	}

	void count() {
		for (int i = p - 1; i >= 0; i--)cnt[fail[i]] += cnt[i];
	}
	int numc[N], numn[N],vis[N];

	int dfs(int x) {
		numn[x] = 1;
		numc[x] = 0;
		for (int t = x; !vis[t] && t >1; t = fail[t])vis[t] = x,numc[x]++;
		for (int i = 0; i < 26; i++) {
			if (nextt[x][i] == 0)continue;
			numn[x] += dfs(nextt[x][i]);
		}
		for (int t = x; vis[t] ==x && t >1; t = fail[t]) vis[t] = 0;
		return numn[x];
	}

	ll solve() {
		ll ans = 0;
		dfs(0); dfs(1);
		for (int i = 2; i < p; i++)ans += 1LL * numc[i] * numn[i];
		return ans-p+2;
	}
}T;

int main() {
	int t; scanf("%d",&t);
	int casen=1;
	while (t--){
		scanf("%s",s);
		int len = strlen(s);
		T.init();
		for (int i = 0; i < len; i++)T.add(s[i]);
		printf("Case #%d: %lld\n",casen++,T.solve());
	}
	return 0;
}