天天看點

洛谷 P3799 妖夢拼木棒 加強版

題目描述

有n根木棒,現在從中選 4 根,想要組成一個正三角形,問有幾種選法?

答案對 998244353取模。

輸入格式

第一行一個整數 n。

第二行 n 個整數,第 i i i個整數 a i a_i ai​ 代表第 i 根木棒的長度。

輸出格式

一行一個整數代表答案。

輸入輸出樣例

輸入

4

1 1 2 2

輸出

1

說明/提示

資料規模與約定 0 ≤ a i ≤ 5 × 1 0 6 , 1 ≤ n ≤ 1 0 5 0≤a_i≤5×10^{6}, 1≤n≤10^5 0≤ai​≤5×106,1≤n≤105

解題思路

對于建構邊長為 x x x的正三角形,邊長為 x x x的木棍有 k k k根,兩根組合成一根長為 x x x 的木棍有 y y y 種,那麼建構邊長為x的正三角形有 y × C k 2 y\times C{}^{2}_{k} y×Ck2​

如何計算 y y y呢?

設 a 1 , a 2 , . . . , a m a_1, a_2, ..., a_m a1​,a2​,...,am​表示長為 i i i的木棍有多少根

對于兩根木棍拼起來長為 x x x的種類數分為以下兩種情況

當 x x x為奇數時

y = a 1 × a x − 1 + a 2 × a x − 2 + . . . + a x / 2 × a x / 2 + 1 y = a_1\times a_{x-1} + a_2\times a_{x-2} + ... + a_{x/2} \times a_{x/2+1} y=a1​×ax−1​+a2​×ax−2​+...+ax/2​×ax/2+1​

當 x x x為偶數時

y = a 1 × a x − 1 + a 2 × a x − 2 + . . . + a x / 2 × ( a x / 2 − 1 ) / 2 y = a_1\times a_{x-1} + a_2\times a_{x-2} + ... + a_{x/2} \times (a_{x/2} - 1)/ 2 y=a1​×ax−1​+a2​×ax−2​+...+ax/2​×(ax/2​−1)/2

用以上方法可以在 O ( n ∗ m ) O(n*m) O(n∗m)的時間内算出結果,這是基礎版的解法,我們進行了資料的加強,我們需要對y的計算進行一個簡化,于是我們想到了多項式乘法,由于多項式乘法可以用NTT優化到 O ( m l o g m ) O(mlogm) O(mlogm)

我們定義多項式

f ( x ) = a 1 x + a 2 x 2 + . . . + a i x i + . . . + a m x m f(x) = a_1x + a_2x^2 + ...+a_ix^i + ... + a_mx^m f(x)=a1​x+a2​x2+...+ai​xi+...+am​xm

g ( x ) = f ( x ) g(x) = f(x) g(x)=f(x)

h ( x ) = f ( x ) × g ( x ) h(x) = f(x) \times g(x) h(x)=f(x)×g(x)

我們發現,對 h ( x ) h(x) h(x)中的任意一項 b i x i b_ix^i bi​xi進行研究

其中 b i = a 1 × a i − 1 + a 2 × a i − 2 + . . . + a i − 1 × a 1 b_i = a_1 \times a_{i-1} + a_2\times a_{i-2} + ...+ a_{i-1} \times a_1 bi​=a1​×ai−1​+a2​×ai−2​+...+ai−1​×a1​

這樣一看,我們就可以将兩根木棍拼起來長為 x x x 的種類數 y y y 與 b i b_i bi​ 聯系起來。

當 x x x 為奇數時

y = b i / 2 y = b_i / 2 y=bi​/2

當 x x x 為偶數時

y = ( b i − a x / 2 × a x / 2 ) / 2 + a x / 2 × ( a x / 2 − 1 ) / 2 y = (b_i - a_{x/2} \times a_{x/2}) / 2 + a_{x/2} \times (a_{x/2} - 1)/ 2 y=(bi​−ax/2​×ax/2​)/2+ax/2​×(ax/2​−1)/2

是以我們隻需要以 O ( m l o g m ) O(mlogm) O(mlogm)的複雜度進行一個預處理就可以求出所有兩根木棍拼起來長為 x x x 的種類數 y y y

然後跑一遍就可以出結果了

#include <bits/stdc++.h>
#define ll long long
#define qc ios::sync_with_stdio(false); cin.tie(0);cout.tie(0)
#define fi first
#define se second
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define pb push_back
using namespace std;
const int MAXN = 2e6 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
    return x*f;
}
const int G = 3,mod = 998244353;
int n,m,L,R[MAXN];
int A[MAXN],B[MAXN];
ll b[MAXN];
int qpow(int a,int b){
    int ans = 1;
    while(b){
        if(b&1)
            ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}
void NTT(int* a,int f){
    for (int i = 0; i < n; i++) if (i < R[i]) swap(a[i],a[R[i]]);
    for (int i = 1; i < n; i <<= 1){
        int gn = qpow(G,(mod - 1) / (i << 1));
        for (int j = 0; j < n; j += (i << 1)){
            int g = 1;
            for (int k = 0; k < i; k++,g = 1ll * g * gn % mod){
                int x = a[j + k],y = 1ll * g * a[j + k + i] % mod;
                a[j + k] = (x + y) % mod; a[j + k + i] = (x - y + mod) % mod;
            }
        }
    }
    if (f == 1) return;
    int nv = qpow(n,mod - 2); reverse(a + 1,a + n);
    for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * nv % mod;
}

void solve(){
    cin >> n;
    int maxx = -1;
    m = n;
    for(int i = 1; i <= n; i++){
        int x;
        cin >> x;
        maxx = max(maxx, x);
        b[x]++;
    }
    for(int i = 0; i <= maxx; i++)
        A[i] = B[i] = b[i];
    m = n + m; for (n = 1; n <= m; n <<= 1) L++;
    for (int i = 0; i < n; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    NTT(A,1); NTT(B,1);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * B[i] % mod;
    NTT(A,-1);

    for(int i = 1; i < MAXN; i++){
        if(i % 2 == 0){
            A[i] = 1ll * ((A[i] - (b[i/2] * b[i/2] % mod) + mod) % mod * qpow(2, mod - 2) % mod + b[i/2] * ((b[i/2] - 1 + mod) % mod) % mod * qpow(2, mod - 2) % mod) % mod;
        }
        else
            A[i] = 1ll * A[i] * qpow(2, mod - 2) % mod;
    }

    ll ans = 0;
    for(int i = 1; i <= maxx; i++){
        ans += 1ll * b[i] * ((b[i] - 1 + mod) % mod) % mod * qpow(2, mod - 2) % mod * A[i] % mod;
		ans %= mod;
    }
    cout << ans << endl;
}

int main()
{
    #ifdef ONLINE_JUDGE
    #else
       freopen("in.txt", "r", stdin);
       freopen("out.txt", "w", stdout);
    #endif

    qc;
    int T;
    // cin >> T;
    T = 1;
    while(T--){

        solve();
    }
    return 0;
}

           

繼續閱讀