天天看點

The Preliminary Contest for ICPC Asia Xuzhou 2019 G. Colorful String(回文自動機)

​​The Preliminary Contest for ICPC Asia Xuzhou 2019 G. Colorful String​​

題意

給一個字元串,找回文子串中不同字母數量,最後相加。(回文子串可以相同)

分析

處理字元串回文子串,很顯然回文自動機。

​​回文自動機​​
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define fuck(x) cout << (x) << endl
#define lson l, m, rt<<1
#define rson m+1, r, rt<<1|1
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int ALP = 30;
const int N = 3e5 + 10;
const int M = 1e4 + 10;

char str[N];
ll ans;     // 結果
int ans_a[30];  // dfs 的時候維護 字母出現的次數

struct PAM{
    int next[N][ALP];   // next指針,指向的串為目前串兩端加上同一個字元構成
    int fail[N];   // fail指針,失配後跳的地方
    int len[N];     // i 結點回文串的長度
    int s[N];      // 添加的字元
    int cnt[N];     //表示節點i表示的本質不同的串的個數(建樹時求出的不是完全的,最後count()函數跑一遍以後才是正确的)
    int last, n, p; // 上一個結點,字元個數,結點個數
    
    int newnode(int l) {
        for (int i = 0; i < 10; i++) 
            next[p][i] = 0;
        len[p] = l;
        return p++;
    }
    void init(){
        ans = p = 0;
        newnode(0);   // 偶跟
        newnode(-1);   // 奇根
        last = n = 0;
        cnt[0] = 0;
        s[n] = -1;     //開頭放一個字元集中沒有的字元,減少特判
        fail[0] = 1;    
    }
    int get_fail(int x){   // 失配後找盡量最長的
        while(s[n-len[x]-1] != s[n]) x = fail[x];
        return x;
    }
    void add(int c){
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);   //通過上一個回文串找這個回文串的比對位置
        if(!next[cur][c]){     //如果這個回文串沒有出現過,說明出現了一個新的本質不同的回文串
            int now = newnode( len[cur] + 2 ); // 建立節點 
            fail[now] = next[get_fail(fail[cur])][c]; //建立fail 指針
            next[cur][c] = now;
        }
        last = next[cur][c];
        cnt[last]++;
    }   
    void count () {
        for ( int i = p - 1 ; i >= 0 ; -- i ) cnt[fail[i]] += cnt[i] ;
    }
    void dfs(int rt, ll s){
        for(int i = 0; i < 26; i++){
            if(next[rt][i]){
                ll ts;
                if(ans_a[i]){   // 目前節點出現過就不用 + 1
                    ts = s;
                    ans = ans +  (ts * cnt[next[rt][i]]);
                    dfs(next[rt][i], ts);
                }else{
                    ans_a[i]++;
                    ts = s + 1;
                    ans = ans + (ts * cnt[next[rt][i]]);
                    dfs(next[rt][i], ts);
                    ans_a[i]--; // 回溯
                }
            }
        }
    }

}pam;


int main(){
    pam.init();
    scanf("%s", str);
    int len = strlen(str);
    for(int i = 0; i < len; i++){
        pam.add(str[i]);
    }
    pam.count();
    pam.dfs(0, 0);   // 偶根
    memset(ans_a, 0, sizeof(ans_a));
    pam.dfs(1, 0);   // 奇根
    printf("%lld\n", ans);
    return  0;
}