天天看點

【算法競賽進階指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 歸并排序(詳解)天才ACM

天才ACM

題目描述

給定一個整數 M,對于任意一個整數集合 S,定義“校驗值”如下:

從集合 S 中取出 M 對數(即 2∗M 個數,不能重複使用集合中的數,如果 S 中的整數不夠 M 對,則取到不能取為止),使得“每對數的差的平方”之和最大,這個最大值就稱為集合 S 的“校驗值”。

現在給定一個長度為 N 的數列 A 以及一個整數 T。

我們要把 A 分成若幹段,使得每一段的“校驗值”都不超過 T。

求最少需要分成幾段。

輸入格式

第一行輸入整數 K,代表有 K 組測試資料。

對于每組測試資料,第一行包含三個整數 N,M,T 。

第二行包含 N 個整數,表示數列A1,A2…AN。

輸出格式

對于每組測試資料,輸出其答案,每個答案占一行。

資料範圍

【算法競賽進階指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 歸并排序(詳解)天才ACM

輸入樣例:

2
5 1 49
8 2 1 7 9
5 1 64
8 2 1 7 9
           

輸出樣例:

2
1
           
難度:困難
時/空限制:10s / 64MB
來源:《算法競賽進階指南》

算法标簽

倍增

思路

顯然,對于一個集合S,取最大的M個數和最小的M個數,依次構成一對,取得的校驗值最大。

那麼為了使分成的段數盡可能少,我們需要使每段盡可能長。

在求每一段的校驗值時我們需要将區間排序,時間複雜度為O(nlogn),對于一個左端點L我們需要知道在校驗值不大于T的情況下,R最大是多少,如果直接周遊複雜度為O(n),總的時間複雜度就為

【算法競賽進階指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 歸并排序(詳解)天才ACM

,即使優化掉一個log也顯然逾時。

這時候我們就需要用到倍增。

倍增,字面上來講就是成倍增長,我們知道任意整數都可以表示成若幹個2的次幂項的和,倍增就是利用若幹個長度為2的次幂的區間來拼出長度為k的區間,在滿足條件時,嘗試加入的區間長度p會以2的幂次進行增長,如果不滿足條件,p每次會變成原來的一半,直到p變為0,是以時間複雜度為O(logn)。

為什麼可以使用倍增?我們計算校驗值是通過選m個最大和最小的數進行計算,那麼[L,R + 1]的校驗值一定比[L, R]的校驗值大或者相等,因為前一個區間包含後一個區間所有的數,是以也就是說枚舉R時存在單調性。

那為什麼不用二分呢?二分端點的缺點是如果T的上限很小,每次右端點相比左端點隻是向右移了一小段距離,那麼這時就會退化成nlogn,還不如從前向後枚舉更優,而倍增可以應對T的各種大小。另外,倍增出的右端點每次都會向右移動或者不變,這一性質可以用于優化排序。

我們在求校驗值時可以用類似歸并排序的方法,隻需要對新加入的區間排序,然後與舊區間進行歸并,這樣就可以将排序複雜度降為O(n),那麼總體複雜度就從

【算法競賽進階指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 歸并排序(詳解)天才ACM

降為

【算法競賽進階指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 歸并排序(詳解)天才ACM

具體實作看歸并排序代碼及注釋

Accepted 12575 ms C++ 快速排序
Accepted 2546 ms C++ 歸并排序

代碼

快速排序

#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<ll, ll> pll;
const int mod = 1e9 + 7;
const int N = 5e5 + 10;
const int INF = 0x3f3f3f3f;
ll qpow(ll base, ll n){ll ans = 1; while (n){if (n & 1) ans = ans * base % mod; base = base * base % mod; n >>= 1;} return ans;}
ll gcd(ll a, ll b){return b ? gcd(b, a % b) : a;}
ll a[N], b[N], c[N], t;
int n, m, old_r;

ll cal(int l, int r){
	r = min(r, n);
	int num = min(m, (r - l + 1) >> 1);
	for (int i = l; i <= r; ++ i) b[i] = a[i];
	sort(b + l, b + r + 1);
	ll ans = 0;
	for (int i = 0; i < num; ++ i) {
		ans += (b[r - i] - b[l + i]) * (b[r - i] - b[l + i]);
	}
	return ans;
}
int main()
{
	int k;
	cin >> k;
	while (k --){
		cin >> n >> m >> t;
		for (int i = 1; i <= n; ++ i){
			scanf("%lld", &a[i]);
		}
		int l, r, p;
		l = r = 1;
		b[l] = a[l];
		int ans = 0;
		while (l <= n){
			p = 1;
			while (p){
				ll num = cal(l, r + p);
				if (num <= t){
					r = min(r + p, n);
					if (r == n) break;
					p <<= 1;
				}else p >>= 1;
			}
			++ ans;
			l = r + 1;
		    r = l;
		}
		cout << ans << '\n';
	}
	return 0;
}
           

歸并排序 

#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<ll, ll> pll;
const int mod = 1e9 + 7;
const int N = 5e5 + 10;
const int INF = 0x3f3f3f3f;
ll qpow(ll base, ll n){ll ans = 1; while (n){if (n & 1) ans = ans * base % mod; base = base * base % mod; n >>= 1;} return ans;}
ll gcd(ll a, ll b){return b ? gcd(b, a % b) : a;}
ll a[N], b[N], c[N], t;
int n, m, old_r;
void merge(int l, int x, int r){
	int i = l, j = x + 1;
	for (int k = l; k <= r; ++ k){
		if (j > r || (i <= x && b[i] <= b[j])) c[k] = b[i ++];
		//新區間指針移到邊界,或者舊區間指針還未移到邊界并且所指元素小于新區間指針所指元素時
		else c[k] = b[j ++];
	}
}
ll cal(int l, int r){
	r = min(r, n);
	int num = min(m, (r - l + 1) >> 1);//元素不能重複利用
	for (int i = old_r + 1; i <= r; ++ i) b[i] = a[i];//加進新元素
	sort(b + old_r + 1, b + r + 1);//對新加進來的元素排序
	merge(l, old_r, r);//将新元素與已排好的一段序列合并
	ll ans = 0;
	for (int i = 0; i < num; ++ i) {
		ans += (c[r - i] - c[l + i]) * (c[r - i] - c[l + i]);//計算校驗值
	}
	return ans;
}
int main()
{
	int k;
	cin >> k;
	while (k --){
		cin >> n >> m >> t;
		for (int i = 1; i <= n; ++ i){
			scanf("%lld", &a[i]);
		}
		int l, r, p;
		l = r = old_r = 1;
		b[l] = a[l];
		int ans = 0;
		while (l <= n){
			p = 1;
			while (p){
				ll num = cal(l, r + p);
				if (num <= t){//校驗值小于t,說明範圍可再擴大
					old_r = r = min(r + p, n);
					for (int i = l; i <= r; ++ i) b[i] = c[i];//b記錄排好的序列,供下次merge使用
					if (r == n) break;
					p <<= 1;//p倍增
				}else p >>= 1;//校驗值大于t時,不斷右移p,直到p為0,說明範圍不可擴
			}
			++ ans;//段數++
			l = r + 1;//下一點開始繼續上述步驟
		    old_r = r = l;
		}
		cout << ans << '\n';
	}
	return 0;
}
           

繼續閱讀