天天看点

【算法竞赛进阶指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 归并排序(详解)天才ACM

天才ACM

题目描述

给定一个整数 M,对于任意一个整数集合 S,定义“校验值”如下:

从集合 S 中取出 M 对数(即 2∗M 个数,不能重复使用集合中的数,如果 S 中的整数不够 M 对,则取到不能取为止),使得“每对数的差的平方”之和最大,这个最大值就称为集合 S 的“校验值”。

现在给定一个长度为 N 的数列 A 以及一个整数 T。

我们要把 A 分成若干段,使得每一段的“校验值”都不超过 T。

求最少需要分成几段。

输入格式

第一行输入整数 K,代表有 K 组测试数据。

对于每组测试数据,第一行包含三个整数 N,M,T 。

第二行包含 N 个整数,表示数列A1,A2…AN。

输出格式

对于每组测试数据,输出其答案,每个答案占一行。

数据范围

【算法竞赛进阶指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 归并排序(详解)天才ACM

输入样例:

2
5 1 49
8 2 1 7 9
5 1 64
8 2 1 7 9
           

输出样例:

2
1
           
难度:困难
时/空限制:10s / 64MB
来源:《算法竞赛进阶指南》

算法标签

倍增

思路

显然,对于一个集合S,取最大的M个数和最小的M个数,依次构成一对,取得的校验值最大。

那么为了使分成的段数尽可能少,我们需要使每段尽可能长。

在求每一段的校验值时我们需要将区间排序,时间复杂度为O(nlogn),对于一个左端点L我们需要知道在校验值不大于T的情况下,R最大是多少,如果直接遍历复杂度为O(n),总的时间复杂度就为

【算法竞赛进阶指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 归并排序(详解)天才ACM

,即使优化掉一个log也显然超时。

这时候我们就需要用到倍增。

倍增,字面上来讲就是成倍增长,我们知道任意整数都可以表示成若干个2的次幂项的和,倍增就是利用若干个长度为2的次幂的区间来拼出长度为k的区间,在满足条件时,尝试加入的区间长度p会以2的幂次进行增长,如果不满足条件,p每次会变成原来的一半,直到p变为0,因此时间复杂度为O(logn)。

为什么可以使用倍增?我们计算校验值是通过选m个最大和最小的数进行计算,那么[L,R + 1]的校验值一定比[L, R]的校验值大或者相等,因为前一个区间包含后一个区间所有的数,所以也就是说枚举R时存在单调性。

那为什么不用二分呢?二分端点的缺点是如果T的上限很小,每次右端点相比左端点只是向右移了一小段距离,那么这时就会退化成nlogn,还不如从前向后枚举更优,而倍增可以应对T的各种大小。另外,倍增出的右端点每次都会向右移动或者不变,这一性质可以用于优化排序。

我们在求校验值时可以用类似归并排序的方法,只需要对新加入的区间排序,然后与旧区间进行归并,这样就可以将排序复杂度降为O(n),那么总体复杂度就从

【算法竞赛进阶指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 归并排序(详解)天才ACM

降为

【算法竞赛进阶指南】CH0601 / Acwing 109 - Genius ACM - 倍增 + 归并排序(详解)天才ACM

具体实现看归并排序代码及注释

Accepted 12575 ms C++ 快速排序
Accepted 2546 ms C++ 归并排序

代码

快速排序

#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<ll, ll> pll;
const int mod = 1e9 + 7;
const int N = 5e5 + 10;
const int INF = 0x3f3f3f3f;
ll qpow(ll base, ll n){ll ans = 1; while (n){if (n & 1) ans = ans * base % mod; base = base * base % mod; n >>= 1;} return ans;}
ll gcd(ll a, ll b){return b ? gcd(b, a % b) : a;}
ll a[N], b[N], c[N], t;
int n, m, old_r;

ll cal(int l, int r){
	r = min(r, n);
	int num = min(m, (r - l + 1) >> 1);
	for (int i = l; i <= r; ++ i) b[i] = a[i];
	sort(b + l, b + r + 1);
	ll ans = 0;
	for (int i = 0; i < num; ++ i) {
		ans += (b[r - i] - b[l + i]) * (b[r - i] - b[l + i]);
	}
	return ans;
}
int main()
{
	int k;
	cin >> k;
	while (k --){
		cin >> n >> m >> t;
		for (int i = 1; i <= n; ++ i){
			scanf("%lld", &a[i]);
		}
		int l, r, p;
		l = r = 1;
		b[l] = a[l];
		int ans = 0;
		while (l <= n){
			p = 1;
			while (p){
				ll num = cal(l, r + p);
				if (num <= t){
					r = min(r + p, n);
					if (r == n) break;
					p <<= 1;
				}else p >>= 1;
			}
			++ ans;
			l = r + 1;
		    r = l;
		}
		cout << ans << '\n';
	}
	return 0;
}
           

归并排序 

#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<ll, ll> pll;
const int mod = 1e9 + 7;
const int N = 5e5 + 10;
const int INF = 0x3f3f3f3f;
ll qpow(ll base, ll n){ll ans = 1; while (n){if (n & 1) ans = ans * base % mod; base = base * base % mod; n >>= 1;} return ans;}
ll gcd(ll a, ll b){return b ? gcd(b, a % b) : a;}
ll a[N], b[N], c[N], t;
int n, m, old_r;
void merge(int l, int x, int r){
	int i = l, j = x + 1;
	for (int k = l; k <= r; ++ k){
		if (j > r || (i <= x && b[i] <= b[j])) c[k] = b[i ++];
		//新区间指针移到边界,或者旧区间指针还未移到边界并且所指元素小于新区间指针所指元素时
		else c[k] = b[j ++];
	}
}
ll cal(int l, int r){
	r = min(r, n);
	int num = min(m, (r - l + 1) >> 1);//元素不能重复利用
	for (int i = old_r + 1; i <= r; ++ i) b[i] = a[i];//加进新元素
	sort(b + old_r + 1, b + r + 1);//对新加进来的元素排序
	merge(l, old_r, r);//将新元素与已排好的一段序列合并
	ll ans = 0;
	for (int i = 0; i < num; ++ i) {
		ans += (c[r - i] - c[l + i]) * (c[r - i] - c[l + i]);//计算校验值
	}
	return ans;
}
int main()
{
	int k;
	cin >> k;
	while (k --){
		cin >> n >> m >> t;
		for (int i = 1; i <= n; ++ i){
			scanf("%lld", &a[i]);
		}
		int l, r, p;
		l = r = old_r = 1;
		b[l] = a[l];
		int ans = 0;
		while (l <= n){
			p = 1;
			while (p){
				ll num = cal(l, r + p);
				if (num <= t){//校验值小于t,说明范围可再扩大
					old_r = r = min(r + p, n);
					for (int i = l; i <= r; ++ i) b[i] = c[i];//b记录排好的序列,供下次merge使用
					if (r == n) break;
					p <<= 1;//p倍增
				}else p >>= 1;//校验值大于t时,不断右移p,直到p为0,说明范围不可扩
			}
			++ ans;//段数++
			l = r + 1;//下一点开始继续上述步骤
		    old_r = r = l;
		}
		cout << ans << '\n';
	}
	return 0;
}
           

继续阅读