题目大意
给定一个 1 至n的排列 A 。有m次操作,每次随机选择排列中的一个有序三元组轮换,求 m 次操作之内(包括m次)将其变成排列 B 的概率。
结果对998244353取模。
1≤n≤14,1≤m≤109
题目分析
首先可以发现,我们将 A 和B同时乘上同一个置换,从前者转移到后者的概率依然是不变的。因此我们考虑将 A 和B乘上其逆置换,使得 B 变为1到 n ,那么我们研究的就是新的A变为单位置换的问题。
其次,我们可以发现同构的置换可以合并。什么意思呢?我们将每一个排列的置换拆成若干个轮换,这个排列就用所有这些轮换的大小来表示。可以发现,虽然这样一种表示方法可以表示多个置换,但是这些置换对同一个三元组进行轮换都可以转移到同一个表示的状态里面。
那么这样会有多少种状态呢?其实就是 n 的划分数Dn。在 n≤14 时不超过 150 。
考虑使用矩阵乘法,那么我们需要预处理两两状态之间的转移:对于一个状态,我们可以构造出一个排列使其满足这个置换,然后在这个排列内枚举三元组,计算其转移到的状态,在转移矩阵累加上相应的概率( 1P3n )。
由于目标状态是单位置换,每个轮换大小都是 1 ,并且这样的表示能唯一表达目标状态,因此答案不会算重。
注意到达了目标状态就不能再转移了,因此我们要特殊处理转移矩阵中从目标状态出发的状态,让其只能以100%的概率转移到自己。
为了方便,我们可以先计算方案总数再乘上(P3n)m的逆元。
至于状态表示我们用哈希或map。
时间复杂度 O(Dnn3log2n+D3nlog2m) 。
代码实现
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <map>
using namespace std;
typedef vector<int> V;
const int P=;
const int N=;
const int S=;
struct matrix
{
int num[S][S];
int r,c;
}TRS,zero,f;
inline matrix operator*(matrix a,matrix b)
{
matrix ret;
memset(ret.num,,sizeof ret.num);
ret.r=a.r,ret.c=b.c;
for (int i=;i<ret.r;i++)
for (int j=;j<ret.c;j++)
for (int k=;k<a.c;k++)
(ret.num[i][j]+=l*a.num[i][k]*b.num[k][j]%P)%=P;
return ret;
}
inline matrix operator^(matrix x,int y)
{
matrix ret=zero;
for (;y;y>>=,x=x*x) if (y&) ret=ret*x;
return ret;
}
int A[N],B[N],trs[N];
map<V,int> id;
V state[S];
int n,m,cnt;
int quick_power(int x,int y)
{
int ret=;
for (;y;y>>=,x=l*x*x%P) if (y&) ret=l*ret*x%P;
return ret;
}
V tmp,nw;
void dfs(int lst,int sum)
{
if (sum==n)
{
state[++cnt]=tmp,id[tmp]=cnt;
return;
}
for (int i=lst;i<=n-sum;i++)
{
tmp.push_back(i);
dfs(i,sum+i);
tmp.pop_back();
}
}
bool vis[N];
int con[N],tot[N];
void makesta()
{
tot[]=;
for (int i=;i<=n;i++) vis[i]=;
for (int i=;i<=n;i++)
{
if (vis[i]) continue;
tot[++tot[]]=;
for (int x=i;!vis[x];x=con[x]) vis[x]=,tot[tot[]]++;
}
sort(tot+,tot++tot[]);
nw.clear();
for (int i=;i<=tot[];i++) nw.push_back(tot[i]);
}
void trans()
{
TRS.r=TRS.c=cnt;
for (int x=;x<=cnt;x++)
{
tmp=state[x];
int l=,r;
for (vector<int>::iterator it=tmp.begin();it!=tmp.end();it++,l=r)
{
r=l+*it;
for (int i=l+;i<=r;i++) con[i]=i!=r?i+:l+;
}
for (int i=;i<=n;i++)
for (int j=;j<=n;j++)
if (i!=j)
for (int k=;k<=n;k++)
if (i!=k&&k!=j)
{
swap(con[i],con[j]),swap(con[i],con[k]);
makesta();
TRS.num[x-][id[nw]-]++;
swap(con[i],con[k]),swap(con[i],con[j]);
}
}
memset(TRS.num[],,sizeof TRS.num[]);
TRS.num[][]=n*(n-)*(n-);
zero.r=zero.c=cnt;
for (int i=;i<cnt;i++) zero.num[i][i]=;
memcpy(con+,A+,n*sizeof(int));
makesta();
f.r=,f.c=cnt,f.num[][id[nw]-]=;
}
void calc(){f=f*(TRS^m);}
int main()
{
freopen("goodbye.in","r",stdin),freopen("goodbye.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=;i<=n;i++) scanf("%d",&A[i]);
for (int i=;i<=n;i++) scanf("%d",&B[i]),trs[B[i]]=i;
for (int i=;i<=n;i++) A[i]=trs[A[i]];
dfs(,),trans(),calc();
printf("%d\n",l*f.num[][]*quick_power(quick_power(n*(n-)*(n-),m),P-)%P);
fclose(stdin),fclose(stdout);
return ;
}