天天看點

codeforces gym 100548G

題意:給定兩個字串,求「相同回文子串」的出現位置對數,也就是對每個同時出現在兩串中的回文串,累加它在兩串中出現次數的乘積。做法:對兩個字串各建一棵回文樹並統計每個節點的出現次數,再從兩棵樹的奇根與偶根同步 DFS 一遍即可。

#pragma comment(linker, "/STACK:102400000,102400000")
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<bitset>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<functional>
using namespace std;
// Shared scalar type and limits for this program.
typedef long long LL;

// Lowest set bit of x, e.g. low(12) == 4 (unused by the code below).
const int low(int x) { return x & (-x); }

const int INF = 2147483647;    // INT_MAX (unused by the code below)
const int mod = 1000000007;    // common prime modulus (unused by the code below)
const int maxn = 200010;       // input buffer size: max string length + slack

LL ans = 0;           // answer for the current test case
int T = 0;            // number of test cases
char s[maxn] = {};    // scanf buffer for the current input string

// Palindromic tree (eertree) over the lowercase alphabet 'a'..'z'.
//
// Node 1 is the odd root (len -1) and node 2 the even root (len 0); both fail
// to node 1. Real palindromes are nodes 3..sz-1, numbered in creation order,
// and every real node's fail link points to an already existing node
// (fail[i] < i), which is what lets work() propagate counts in one backward
// sweep.
//
// Usage: clear(); add(c) for every character; work(); afterwards tree[v] is
// the number of occurrences of node v's palindrome in the inserted string.
struct PalindromicTree
{
  const static int maxn = 2e5 + 10;  // capacity: max string length / node count
  const static int size = 26;        // alphabet size
  int next[maxn][size], sz, tot;     // next[v][c]: node of c + pal(v) + c; sz: next free id; tot: chars inserted
  int fail[maxn], len[maxn], last;   // fail: longest proper palindromic suffix; last: longest palindromic suffix so far
  long long cnt[maxn];               // end-position counts; occurrence counts after work()
  char s[maxn];                      // inserted text, 1-based; s[0] is a sentinel

  // Occurrence count of node x (final only after work()).
  long long operator[](const int &x) const { return cnt[x]; }

  // Reset to an empty tree that contains only the two roots. The s[0]
  // sentinel is written explicitly instead of relying on the object living in
  // zero-initialized static storage.
  void clear()
  {
    len[1] = -1; len[2] = 0;
    fail[1] = fail[2] = 1;
    cnt[1] = cnt[2] = tot = 0;
    s[0] = 0;  // differs from every letter, so no palindrome extends past the text start
    last = (sz = 3) - 1;
    std::memset(next[1], 0, sizeof(next[1]));
    std::memset(next[2], 0, sizeof(next[2]));
  }

  // Allocate a fresh node for a palindrome of the given length; returns its id.
  int Node(int length)
  {
    std::memset(next[sz], 0, sizeof(next[sz]));
    len[sz] = length;
    cnt[sz] = 0;
    return sz++;
  }

  // Follow fail links from x until pal(x) can be wrapped by the character just
  // stored at s[tot] (the odd root always succeeds).
  int getfail(int x)
  {
    while (s[tot] != s[tot - len[x] - 1]) x = fail[x];
    return x;
  }

  // Append character pos (expected in 'a'..'z'). Creates the node for the new
  // longest palindromic suffix if needed and counts one end position for it.
  // Returns that node's running end-position count.
  int add(char pos)
  {
    int x = (s[++tot] = pos) - 'a', y = getfail(last);
    if (next[y][x]) { last = next[y][x]; }
    else {
      last = next[y][x] = Node(len[y] + 2);
      // Single letters fail to the even root; longer palindromes fail to the
      // extension of the next shorter suffix palindrome that can be wrapped.
      fail[last] = len[last] == 1 ? 2 : next[getfail(fail[y])][x];
    }
    return ++cnt[last];
  }

  // Push counts up the fail tree (children have larger ids than their fail
  // targets), turning per-end-position counts into total occurrence counts.
  void work()
  {
    for (int i = sz - 1; i > 2; i--)
    {
      if (fail[i] > 2) cnt[fail[i]] += cnt[i];
    }
  }
} work[2];

// Walk both palindromic trees in lock-step starting from the node pair (x, y)
// and, for every palindrome that exists in both trees, add
// (occurrences in string 0) * (occurrences in string 1) to the global ans.
//
// Uses an explicit stack instead of recursion: each edge wraps the palindrome
// by two characters, so the depth can reach about n/2 (~1e5 for "aaa...a"),
// and the /STACK linker pragma at the top of the file only takes effect with
// MSVC. The summation order differs from the recursive version, but the total
// is the same.
void dfs(int x, int y)
{
  std::stack<std::pair<int, int> > pending;
  pending.push(std::make_pair(x, y));
  while (!pending.empty())
  {
    const int u = pending.top().first;
    const int v = pending.top().second;
    pending.pop();
    ans += work[0][u] * work[1][v];
    for (int c = 0; c < 26; c++)
    {
      const int nu = work[0].next[u][c];
      const int nv = work[1].next[v][c];
      if (nu && nv) pending.push(std::make_pair(nu, nv));
    }
  }
}

int main()
{
  scanf("%d", &T);
  for (int cas = 1; cas <= T; cas++)
  {
    for (int i = 0; i < 2; i++)
    {
      work[i].clear();
      scanf("%s", s);
      for (int j = 0; s[j]; j++)
      {
        work[i].add(s[j]);
      }
      work[i].work();
    }
    ans = 0;
    dfs(1, 1);
    dfs(2, 2);
    printf("Case #%d: %lld\n", cas, ans);
  }
  return 0;
}      

最近和同學討論時又重新思考了這個問題,這裡再提供兩種新的解法:第一種把第二個字串也插入同一棵回文樹一起統計;第二種只用第一個字串建樹,讓第二個字串沿著這棵樹匹配,不建立新節點。

解法一:把第二個字串也插入同一棵回文樹,用兩組計數分別統計兩個字串中每個回文串的出現次數,最後對每個節點累加兩者的乘積。

#pragma comment(linker, "/STACK:102400000,102400000")
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<bitset>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<functional>
using namespace std;
typedef long long LL;
const int low(int x) { return x&-x; }
const int INF = 0x7FFFFFFF;
const int mod = 1e9 + 7;
const int maxn = 2e5 + 10;
LL ans;  // answer accumulator; PalindromicTree::work() adds to it
int T;
char s[maxn];

// One palindromic tree (eertree) shared by BOTH input strings.
//
// The first string is inserted with kind == 0 and counted in cnt[]. The caller
// then resets tot/last and inserts the second string with kind != 0, counted
// in t[]. Palindromes that occur only in the second string get fresh nodes
// with cnt == 0, so they contribute nothing. work() propagates both counters
// along fail links and adds cnt[i] * t[i] for every real node to the global
// ans.
//
// Node 1 = odd root (len -1), node 2 = even root (len 0); real nodes are
// 3..sz-1 and satisfy fail[i] < i.
struct PalindromicTree
{
  const static int maxn = 4e5 + 10;  // room for the nodes of both strings
  const static int size = 26;        // alphabet 'a'..'z'
  int next[maxn][size], sz, tot;     // next[v][c]: node of c + pal(v) + c; sz: next free id; tot: chars in current text
  int fail[maxn], len[maxn], last;   // fail: longest proper palindromic suffix; last: current longest palindromic suffix
  LL cnt[maxn], t[maxn];             // per-node counts for the first / second string
  char s[maxn];                      // current text, 1-based; s[0] is a sentinel

  // Reset to the two roots only. Also clears the roots' t[] counters and
  // writes the s[0] sentinel, so correctness does not depend on the object
  // living in zero-initialized static storage.
  void clear()
  {
    len[1] = -1; len[2] = 0;
    fail[1] = fail[2] = 1;
    cnt[1] = cnt[2] = 0;
    t[1] = t[2] = 0;
    tot = 0;
    s[0] = 0;  // differs from every letter, so no palindrome extends past the text start
    last = (sz = 3) - 2;
    std::memset(next[1], 0, sizeof(next[1]));
    std::memset(next[2], 0, sizeof(next[2]));
  }

  // Allocate a fresh node of the given palindrome length with zeroed counters.
  int Node(int length)
  {
    std::memset(next[sz], 0, sizeof(next[sz]));
    len[sz] = length;
    t[sz] = cnt[sz] = 0;
    return sz++;
  }

  // Follow fail links from x until pal(x) can be wrapped by s[tot].
  int getfail(int x)
  {
    while (s[tot] != s[tot - len[x] - 1]) x = fail[x];
    return x;
  }

  // Append character pos (expected in 'a'..'z') to the current text, creating
  // the node for the new longest palindromic suffix if needed. The end
  // position is counted in t[] when kind != 0, otherwise in cnt[]; the
  // updated counter is returned.
  int add(char pos, int kind)
  {
    int x = (s[++tot] = pos) - 'a', y = getfail(last);
    if (!(last = next[y][x]))
    {
      last = next[y][x] = Node(len[y] + 2);
      fail[last] = len[last] == 1 ? 2 : next[getfail(fail[y])][x];
    }
    return kind ? ++t[last] : ++cnt[last];
  }

  // Backward sweep: push both counters up the fail tree. A node's counters are
  // final once all higher-numbered nodes are processed, so its product can be
  // added to ans right away.
  void work()
  {
    for (int i = sz - 1; i > 2; i--)
    {
      if (fail[i] > 2)
      {
        cnt[fail[i]] += cnt[i];
        t[fail[i]] += t[i];
      }
      ans += cnt[i] * t[i];
    }
  }
} work;

// Driver for approach 1: both strings go into the single shared tree. The
// first is counted with kind 0; the second restarts the text at the odd root
// and is counted with kind 1.
int main()
{
  scanf("%d", &T);
  int caseNo = 1;
  while (caseNo <= T)
  {
    work.clear();

    scanf("%s", s);
    for (int pos = 0; s[pos] != '\0'; ++pos)
      work.add(s[pos], 0);

    // New text: position counter back to 0, longest suffix = odd root.
    scanf("%s", s);
    work.tot = 0;
    work.last = 1;
    for (int pos = 0; s[pos] != '\0'; ++pos)
      work.add(s[pos], 1);

    ans = 0;
    work.work();
    printf("Case #%d: %lld\n", caseNo, ans);
    ++caseNo;
  }
  return 0;
}
#pragma comment(linker, "/STACK:102400000,102400000")
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<bitset>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<functional>
using namespace std;
typedef long long LL;
const int low(int x) { return x&-x; }
const int INF = 0x7FFFFFFF;
const int mod = 1e9 + 7;
const int maxn = 2e5 + 10;
LL ans;  // answer accumulator; PalindromicTree::work() adds to it
int T;
char s[maxn];

// Palindromic tree (eertree) built from the FIRST string only.
//
// The first string is inserted with kind == 0: nodes are created and end
// positions are counted in cnt[]. The caller then resets tot/last and feeds
// the second string with kind != 0: no nodes are created. Instead, `last`
// tracks the longest suffix palindrome of the second string's prefix that
// already exists in the tree (node 1 when there is none), and end positions
// are counted in t[]. work() propagates both counters along fail links and
// adds cnt[i] * t[i] for every real node to the global ans.
//
// Node 1 = odd root (len -1), node 2 = even root (len 0); real nodes are
// 3..sz-1 and satisfy fail[i] < i.
struct PalindromicTree
{
  const static int maxn = 2e5 + 10;  // only the first string creates nodes
  const static int size = 26;        // alphabet 'a'..'z'
  int next[maxn][size], sz, tot;     // next[v][c]: node of c + pal(v) + c; sz: next free id; tot: chars in current text
  int fail[maxn], len[maxn], last;   // fail: longest proper palindromic suffix; last: current matched suffix node
  LL cnt[maxn], t[maxn];             // per-node counts for the first / second string
  char s[maxn];                      // current text, 1-based; s[0] is a sentinel

  // Reset to the two roots only. Also clears t[1]/t[2] (add() counts unmatched
  // positions of the second string in t[1], which would otherwise carry over
  // between test cases) and writes the s[0] sentinel explicitly.
  void clear()
  {
    len[1] = -1; len[2] = 0;
    fail[1] = fail[2] = 1;
    cnt[1] = cnt[2] = 0;
    t[1] = t[2] = 0;
    tot = 0;
    s[0] = 0;  // differs from every letter, so no palindrome extends past the text start
    last = (sz = 3) - 2;
    std::memset(next[1], 0, sizeof(next[1]));
    std::memset(next[2], 0, sizeof(next[2]));
  }

  // Allocate a fresh node of the given palindrome length with zeroed counters.
  int Node(int length)
  {
    std::memset(next[sz], 0, sizeof(next[sz]));
    len[sz] = length;
    t[sz] = cnt[sz] = 0;
    return sz++;
  }

  // Follow fail links from x until pal(x) can be wrapped by s[tot].
  int getfail(int x)
  {
    while (s[tot] != s[tot - len[x] - 1]) x = fail[x];
    return x;
  }

  // Append character pos (expected in 'a'..'z') to the current text.
  // kind == 0: build mode -- create the node if needed and count in cnt[].
  // kind != 0: match mode -- never create nodes; move `last` to the longest
  //            suffix palindrome already in the tree and count in t[].
  // Returns the updated counter.
  int add(char pos, int kind)
  {
    int x = (s[++tot] = pos) - 'a', y = getfail(last);
    if (!(last = next[y][x]))
    {
      if (!kind)
      {
        last = next[y][x] = Node(len[y] + 2);
        fail[last] = len[last] == 1 ? 2 : next[getfail(fail[y])][x];
      }
      else
      {
        // Shorter suffix palindromes lie on y's fail chain; try each one that
        // can be wrapped by s[tot] until the wrapped palindrome is in the tree.
        while (!next[y][x] && y > 1) y = getfail(fail[y]);
        if (!(last = next[y][x])) last = 1;  // this letter never occurs in string 1
      }
    }
    return kind ? ++t[last] : ++cnt[last];
  }

  // Backward sweep: push both counters up the fail tree. A node's counters are
  // final once all higher-numbered nodes are processed, so its product can be
  // added to ans right away.
  void work()
  {
    for (int i = sz - 1; i > 2; i--)
    {
      if (fail[i] > 2)
      {
        cnt[fail[i]] += cnt[i];
        t[fail[i]] += t[i];
      }
      ans += cnt[i] * t[i];
    }
  }
} work;

// Driver for approach 2: build the tree from the first string only, then let
// the second string walk the existing tree (kind 1) without adding nodes.
int main()
{
  scanf("%d", &T);
  for (int caseNo = 1; caseNo <= T; ++caseNo)
  {
    work.clear();

    scanf("%s", s);
    for (const char *p = s; *p; ++p) work.add(*p, 0);

    // Restart the text for the second string at the odd root.
    scanf("%s", s);
    work.tot = 0;
    work.last = 1;
    for (const char *p = s; *p; ++p) work.add(*p, 1);

    ans = 0;
    work.work();
    printf("Case #%d: %lld\n", caseNo, ans);
  }
  return 0;
}