天天看点

[BZOJ]4180: 字符串计数 SAM+矩阵乘法+二分

Description

SD有一名神犇叫做Oxer,他觉得字符串的题目都太水了,于是便出了一道题来虐蒟蒻yts1999。

他给出了一个字符串T,字符串T中有且仅有4种字符 ‘A’, ‘B’, ‘C’, ‘D’。现在他要求蒟蒻yts1999构造一个新的字符串S,构造的方法是:进行多次操作,每一次操作选择T的一个子串,将其加入S的末尾。

对于一个可构造出的字符串S,可能有多种构造方案,Oxer定义构造字符串S所需的操作次数为所有构造方案中操作次数的最小值。

Oxer想知道对于给定的正整数N和字符串T,他所能构造出的所有长度为N的字符串S中,构造所需的操作次数最大的字符串的操作次数。

蒟蒻yts1999当然不会做了,于是向你求助。

Solution

如果知道 S S S,那么最小操作次数肯定是在SAM上尽量走,直到没有对应的儿子为止。建出SAM后,可以一次拓扑排序处理出 f x , y f_{x,y} fx,y​表示 x x x这个字母走到没有 y y y儿子节点的最短路。直接求不好求,考虑二分答案,那么就可以用矩阵乘法求出走 m i d mid mid次的最短长度,看是否 &lt; n &lt;n <n即可。

Code

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=100010;
const LL inf=1e18+1;
LL read()
{
    LL x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x*f;
}
LL n;int m;
char str[Maxn];
int last=1,tot=1,son[Maxn<<1][4],par[Maxn<<1],mx[Maxn<<1];
void extend(int x)
{
    int p=last,np=++tot;mx[np]=mx[p]+1;
    while(p&&!son[p][x])son[p][x]=np,p=par[p];
    if(!p)par[np]=1;
    else
    {
        int q=son[p][x];
        if(mx[p]+1==mx[q])par[np]=q;
        else
        {
            int nq=++tot;mx[nq]=mx[p]+1;
            for(int i=0;i<4;i++)son[nq][i]=son[q][i];
            par[nq]=par[q];
            par[q]=par[np]=nq;
            while(son[p][x]==q)son[p][x]=nq,p=par[p];
        }
    }
    last=np;
}
LL f[Maxn<<1][4];int deg[Maxn<<1];
LL ans=0;int cnt[4];
struct Matrix{LL v[4][4];}one,M,st;
Matrix operator*(Matrix a,Matrix b)
{
    Matrix c;
    for(int i=0;i<4;i++)
    for(int j=0;j<4;j++)
    {
        c.v[i][j]=inf;
        for(int k=0;k<4;k++)
        c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
    }
    return c;
}
Matrix Pow(Matrix x,LL y)
{
    Matrix re=one,t=x;
    while(y)
    {
        if(y&1)re=re*t;
        y>>=1;t=t*t;
    }
    return re;
}
bool check(LL o)
{
    memset(st.v,0,sizeof(st.v));
    st=st*Pow(M,o);
    LL mn=inf;
    for(int i=0;i<4;i++)mn=min(mn,st.v[0][i]);
    if(mn<n)return true;
    return false;
}
vector<int>h[Maxn<<1];
int main()
{
    n=read();
    for(int i=0;i<4;i++)one.v[i][i]=1;
    scanf("%s",str+1);m=strlen(str+1);
    for(int i=1;i<=m;i++)extend(str[i]-'A'),cnt[str[i]-'A']++;
    for(int i=1;i<=tot;i++)
    for(int j=0;j<4;j++)
    f[i][j]=inf;
    for(int i=1;i<=tot;i++)
    for(int j=0;j<4;j++)
    if(son[i][j])h[son[i][j]].push_back(i),deg[i]++;
    else f[i][j]=1;
    queue<int>q;
    for(int i=1;i<=tot;i++)
    if(!deg[i])q.push(i);
    while(!q.empty())
    {
        int x=q.front();q.pop();
        for(int i=0;i<h[x].size();i++)
        {
            int y=h[x][i];
            deg[y]--;
            if(!deg[y])q.push(y);
            for(int j=0;j<4;j++)f[y][j]=min(f[y][j],f[x][j]+1);
        }
    }
    for(int i=0;i<4;i++)
    for(int j=0;j<4;j++)
    {
        if(!son[1][i])M.v[i][j]=inf;
        else M.v[i][j]=f[son[1][i]][j];
    }
    LL l=1,r=n;
    while(l<=r)
    {
        LL mid=l+r>>1;
        if(check(mid))l=mid+1;
        else r=mid-1;
    }
    printf("%lld",l);
}
           

继续阅读