天天看点

牛客国庆集训派对Day4 E 乒乓球 组合数学 快速数论变换ntt

看了3天fft和ntt,蒟蒻最终还是有点迷,所以就不专门写一篇关于fft和ntt的博客了,不会的请自行百度。

题目:https://ac.nowcoder.com/acm/contest/204/E

看看题目,期望!这东西,对于数学渣的我简直是噩梦,我在推导了整整一个小时,写了整整两张草稿纸后,总算是有了点思路,再配合题解(这才是重点好吗),才大致想出了这道题。

题目意思我就不复述了,我们对这个问题分析一下。按照我一开始的想法,我是从结果反推的(就是先枚举最后剩下的两个,再枚举上一个,直到枚举完整个序列),但我发现这个问题有几大困难:

1.需要枚举的情况太多,即使最后可以合并(可以发现这样枚举后,每两个相距一定的点的贡献被算了一定的次数,但非常复杂),复杂度也过不去。

2.对于枚举时还要考虑当前点删去时所处位置,因为它的位置会影响删去的它的贡献,所以这样考虑既具有前效性又具有后效性。(我太蒻啦!)

之前我们考虑的是对于每一个点的情况,现在我们换个角度:考虑每一对i , j,w[i] * w[j]对答案的贡献。

这时,我们可以发现,一对i , j,只有当i和j是i ~ j中最后删去的点时,他们的答案才会对最终答案有贡献。

这样我们就可以很愉快地将整个问题分成一些子问题了:对于一对i , j,ans + w[i] * w[j] * (i , j最后删掉的概率)。

那么只差概率,我们就可以完美地做出这道题的n ^ 2暴力了。(emmm,证明不会的。。。这我真帮不了你)贴个图233

牛客国庆集训派对Day4 E 乒乓球 组合数学 快速数论变换ntt

仔细观察,我们可以发现,这个式子好像只和j - i有关啊。仔细一想,还真是,我们只需要算得对于所有(j - i = k)的i , j的w[i] * w[j]的和再乘上上面的概率,这道题就完了呀。

这个时候,难题又摆在了面前,怎么实现前面的求所有(j - i = k)的i , j的w[i] * w[j]的和呢?多项式乘法,想都不想就考虑卷积,再看数据范围100000,有戏。

我们知道,卷积(两个多项式的乘积,这里我们不妨称它为S)是S[k] = ∑w[i] * w[j] ,(i + j = k),那么,怎么转化为(j - i)的形式呢?(我想这里想了一个中午)

看了大佬代码后我才发现,其实我们完全可以人为构造,让加法变成减法的形式。我们设w’[i] = w[n - i - 1](保证w’也是0 ~ n),那么对于w[i] * w[i + k](就是w[i] * w[j] , i与j的差为定值),它的值就为w[i] * w’[n - 1 - k + i],即为乘后的S[n - 1 - k](w与w’做ntt)。

这样,我们就保证了对于i , j差为定值的所有w[i] * w[j] 的值的和在乘后S数组的一固定位置了。然后,就一个O(n)加法就完事。附代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 998244353;
const int G = 3;
inline char get_char()
{
	static char buf[1000000] , *p1 = buf , *p2 = buf;
	if (p1 == p2)
	{
		p2 = (p1 = buf) + fread(buf , 1 , 1000000 , stdin);
		if (p1 == p2)
		{
			return EOF;
		}
	}
	return *p1++;
}
inline int read()
{
	int res;
	char ch;
	while (!isdigit(ch = get_char()));
	res = ch - '0';
	while (isdigit(ch = get_char()))
	{
		res = res * 10 + ch - '0';
	}
	return res;
}
inline int Pow(int x , int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1)
		{
			res = (ll)res * x % mod;
		}
		x = (ll)x * x % mod;
		y >>= 1;
	}
	return res;
}
int N , n;
int rev[300010];
void init()
{
	int len = 0;
	while ((1 << len) <= n * 2)
	{
		len++;
	}
	N = (1 << len);
	for (int i = 0; i < N; i++)
	{
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
	}
}
void ntt(int *a , int re)
{
	for (int i = 0; i < N; i++)
	{
		if (rev[i] > i)
		{
			swap(a[i] , a[rev[i]]);
		}
	}
	for (int i = 2; i <= N; i <<= 1)
	{
		int mid = i >> 1;
		int wn = Pow(G , (mod - 1) / i);
		if (re)
		{
			wn = Pow(wn , mod - 2);
		}
		for (int j = 0; j < N; j += i)
		{
			int w = 1;
			for (int k = 0; k < mid; k++)
			{
				int t1 = a[j + k] , t2 = (ll)a[j + k + mid] * w % mod;
				a[j + k] = (t1 + t2) % mod;
				a[j + k + mid] = (t1 - t2 + mod) % mod;
				w = (ll)w * wn % mod;
			}
		}
	}
	if (re)
	{
		int inv = Pow(N , mod - 2);
		for (int i = 0; i < N; i++)
		{
			a[i] = (ll)a[i] * inv % mod;
		}
	}
}
int A[300010] , B[300010];
int main()
{
	scanf("%d" , &n);
	for (int i = 0; i < n; i++)
	{
		scanf("%d" , &A[i]);
	}
	for (int i = 0; i < n; i++)
	{
		B[i] = A[n - i - 1];
	}
	init();
	ntt(A , 0);
	ntt(B , 0);
	for (int i = 0; i < N; i++)
	{
		A[i] = (ll)A[i] * B[i] % mod;
	}
	ntt(A , 1);
	int ans = 0;
	for (int i = 2; i < n; i++)
	{
		ans += 2ll * A[n + i - 1] * Pow((ll)i * (i + 1) % mod , mod - 2) % mod;
		ans %= mod;
	}
	printf("%d\n" , ans);
}
           

继续阅读