天天看點

牛客國慶集訓派對Day4 E 乒乓球 組合數學 快速數論變換ntt

看了3天fft和ntt,蒟蒻最終還是有點迷,是以就不專門寫一篇關于fft和ntt的部落格了,不會的請自行百度。

題目:https://ac.nowcoder.com/acm/contest/204/E

看看題目,期望!這東西,對于數學渣的我簡直是噩夢,我在推導了整整一個小時,寫了整整兩張草稿紙後,總算是有了點思路,再配合題解(這才是重點好嗎),才大緻想出了這道題。

題目意思我就不複述了,我們對這個問題分析一下。按照我一開始的想法,我是從結果反推的(就是先枚舉最後剩下的兩個,再枚舉上一個,直到枚舉完整個序列),但我發現這個問題有幾大困難:

1.需要枚舉的情況太多,即使最後可以合并(可以發現這樣枚舉後,每兩個相距一定的點的貢獻被算了一定的次數,但非常複雜),複雜度也過不去。

2.對于枚舉時還要考慮目前點删去時所處位置,因為它的位置會影響删去的它的貢獻,是以這樣考慮既具有前效性又具有後效性。(我太蒻啦!)

之前我們考慮的是對于每一個點的情況,現在我們換個角度:考慮每一對i , j,w[i] * w[j]對答案的貢獻。

這時,我們可以發現,一對i , j,隻有當i和j是i ~ j中最後删去的點時,他們的答案才會對最終答案有貢獻。

這樣我們就可以很愉快地将整個問題分成一些子問題了:對于一對i , j,ans + w[i] * w[j] * (i , j最後删掉的機率)。

那麼隻差機率,我們就可以完美地做出這道題的n ^ 2暴力了。(emmm,證明不會的。。。這我真幫不了你)貼個圖233

牛客國慶集訓派對Day4 E 乒乓球 組合數學 快速數論變換ntt

仔細觀察,我們可以發現,這個式子好像隻和j - i有關啊。仔細一想,還真是,我們隻需要算得對于所有(j - i = k)的i , j的w[i] * w[j]的和再乘上上面的機率,這道題就完了呀。

這個時候,難題又擺在了面前,怎麼實作前面的求所有(j - i = k)的i , j的w[i] * w[j]的和呢?多項式乘法,想都不想就考慮卷積,再看資料範圍100000,有戲。

我們知道,卷積(兩個多項式的乘積,這裡我們不妨稱它為S)是S[k] = ∑w[i] * w[j] ,(i + j = k),那麼,怎麼轉化為(j - i)的形式呢?(我想這裡想了一個中午)

看了大佬代碼後我才發現,其實我們完全可以人為構造,讓加法變成減法的形式。我們設w’[i] = w[n - i - 1](保證w’也是0 ~ n),那麼對于w[i] * w[i + k](就是w[i] * w[j] , i與j的差為定值),它的值就為w[i] * w’[n - 1 - k + i],即為乘後的S[n - 1 - k](w與w’做ntt)。

這樣,我們就保證了對于i , j差為定值的所有w[i] * w[j] 的值的和在乘後S數組的一固定位置了。然後,就一個O(n)加法就完事。附代碼:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 998244353;
const int G = 3;
inline char get_char()
{
	static char buf[1000000] , *p1 = buf , *p2 = buf;
	if (p1 == p2)
	{
		p2 = (p1 = buf) + fread(buf , 1 , 1000000 , stdin);
		if (p1 == p2)
		{
			return EOF;
		}
	}
	return *p1++;
}
inline int read()
{
	int res;
	char ch;
	while (!isdigit(ch = get_char()));
	res = ch - '0';
	while (isdigit(ch = get_char()))
	{
		res = res * 10 + ch - '0';
	}
	return res;
}
inline int Pow(int x , int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1)
		{
			res = (ll)res * x % mod;
		}
		x = (ll)x * x % mod;
		y >>= 1;
	}
	return res;
}
int N , n;
int rev[300010];
void init()
{
	int len = 0;
	while ((1 << len) <= n * 2)
	{
		len++;
	}
	N = (1 << len);
	for (int i = 0; i < N; i++)
	{
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
	}
}
void ntt(int *a , int re)
{
	for (int i = 0; i < N; i++)
	{
		if (rev[i] > i)
		{
			swap(a[i] , a[rev[i]]);
		}
	}
	for (int i = 2; i <= N; i <<= 1)
	{
		int mid = i >> 1;
		int wn = Pow(G , (mod - 1) / i);
		if (re)
		{
			wn = Pow(wn , mod - 2);
		}
		for (int j = 0; j < N; j += i)
		{
			int w = 1;
			for (int k = 0; k < mid; k++)
			{
				int t1 = a[j + k] , t2 = (ll)a[j + k + mid] * w % mod;
				a[j + k] = (t1 + t2) % mod;
				a[j + k + mid] = (t1 - t2 + mod) % mod;
				w = (ll)w * wn % mod;
			}
		}
	}
	if (re)
	{
		int inv = Pow(N , mod - 2);
		for (int i = 0; i < N; i++)
		{
			a[i] = (ll)a[i] * inv % mod;
		}
	}
}
int A[300010] , B[300010];
int main()
{
	scanf("%d" , &n);
	for (int i = 0; i < n; i++)
	{
		scanf("%d" , &A[i]);
	}
	for (int i = 0; i < n; i++)
	{
		B[i] = A[n - i - 1];
	}
	init();
	ntt(A , 0);
	ntt(B , 0);
	for (int i = 0; i < N; i++)
	{
		A[i] = (ll)A[i] * B[i] % mod;
	}
	ntt(A , 1);
	int ans = 0;
	for (int i = 2; i < n; i++)
	{
		ans += 2ll * A[n + i - 1] * Pow((ll)i * (i + 1) % mod , mod - 2) % mod;
		ans %= mod;
	}
	printf("%d\n" , ans);
}
           

繼續閱讀