天天看點

hdu 5157 Harry and magic string(manacher+dp)

我們可以先用mancher算法對字元串進行處理,把以每個點為中心的回文串半徑求出來,然後進行處理。

加入對以p為中心的點,從p-r[i]+1~p都是回文串的開頭,那麼對于每個回文串(開頭是j)隻要記錄結尾從1~j-1的回文串個數,我們可以用dp記錄以每個點為結尾的回文串個數,s[i]=sigma(dp[i]),則是結尾從1~j-1的回文串個數。那麼對這個中心點來說一共的回文串對應該有:s[p-r[i]]+...+s[p-1]個,那麼我們可以繼續用一個數組s1[i]求s[i]的字首和,那麼總複雜度是O(n)。

至于dp[i]怎麼求,你已經知道了半徑,那從p~p+r[i]-1這些點的dp值都要加1,可以用樹狀數組來維護。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
using namespace std;
typedef long long LL;
char c[200005],f[100005];
int r[200005],bit[200005];
int d[200005];
LL s[200005],s1[200005];
int low(int n){return n&(-n);}
void merg(int p,int n,int k){
    while(p<=n){
        bit[p]+=k;
        p=p+low(p);
    }
}
int sum(int p){
    int s=0;
    while(p>0){
        s+=bit[p];
        p=p-low(p);
    }
    return s;
}
void mancher(int n){
    int i,id,mx;
    r[0]=1;
    mx=0;
    id=0;
    for(i=1;i<=2*n;i++){
        if(i>=mx) r[i]=1;
        else r[i]=min(r[id-(i-id)],mx-i);
        while(i-r[i]>=0&&i+r[i]<=2*n&&c[i-r[i]]==c[i+r[i]]) r[i]++;
        if(i+r[i]>mx){
            mx=i+r[i];
            id=i;
        }
    }
}
void get_back(int n){
    int i,j,p;
    memset(bit,0,sizeof(bit));
    memset(s,0,sizeof(s));
    memset(s1,0,sizeof(s1));
    for(i=1;i<=2*n;i++){
        p=i+r[i]-1;
        if(i%2==0){
            if(p>i){
                merg(i/2+1,n,1);
                merg(p/2+1,n,-1);
            }
        }
        else{
                merg((i+1)/2,n,1);
                merg(p/2+1,n,-1);
        }
    }
    s[0]=0;
    s1[0]=0;
    for(i=1;i<=n;i++){
        d[i]=sum(i);
        s[i]=s[i-1]+d[i];
        s1[i]=s1[i-1]+s[i];
    }
}
void work(int n){
    LL s=0;
    int i,j;
    for(i=1;i<=2*n;i++){
        if(i%2==0){
            if(r[i]>1){
                if((i-r[i]+1)/2!=0)
                s+=s1[i/2-1]-s1[(i-r[i]+1)/2-1];
                else s+=s1[i/2-1];
            }
        }
        else{
            if((i-r[i]+1)/2!=0)
            s+=s1[(i+1)/2-1]-s1[(i-r[i]+1)/2-1];
            else s+=s1[(i+1)/2-1];
        }
    }
    cout<<s<<endl;
}
int main()
{
    int i,j,n;
    while(scanf("%s",f)!=EOF){
        n=strlen(f);
        c[0]='#';
        for(i=1;i<=n;i++){
            c[2*i]='#';
            c[2*i-1]=f[i-1];
        }
        mancher(n);
        get_back(n);
        work(n);
    }
}