时间: 1000ms / 空间: 165536KiB
描述
N柱砖,希望有连续K柱的高度是一样的.
你可以选择以下两个动作
1:从某柱砖的顶端拿一块砖出来,丢掉不要了.
2:从仓库中拿出一块砖,放到另一柱.仓库无限大.
现在希望用最小次数的动作完成任务.
输入格式
第一行给出N,K. (1 ≤ k ≤ n ≤ 100000),
下面N行,每行代表这柱砖的高度.0 ≤ hi ≤ 1000000
输出格式
最小的动作次数
测试样例1
输入
5 3
3
9
2
3
1
输出
2
备注
原题还要求输出结束状态时,每柱砖的高度.本题略去.
题解
我的做法是平衡树。
首先“连续K柱” 不用考虑什么大于K(不知道我当时脑子里问什么想到了这个问题,虽然我瞬间就解决了),应为若满足">K"代价一定不会更少,只可能相同或更多。
此时第一个子问题是(可化为)“将K个数经过加减变化,变成相同的数,其中变化总量为代价,要求代价最小”。因为变化量以“绝对值”形式计入代价,所以就是要我们找“中位数”
第二个子问题是“如何在动态的K个数中找中位数”。其中所谓动态即每次改变一个数。综合以上两点我想到了平衡树。子问题一就是区间第k大,子问题二是平衡树上插入删除。至于计算代价,平衡树可以维护子树权值和,在找第K大时可以计算小于这个数的所有数的和。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<cmath>
#include<algorithm>
#define MAXN 100005
#define ll long long
using namespace std;
int n,m,po,h[MAXN],size,root;
ll smin=0;
struct tree{int l,r,rnd,wt,son;ll sum,v;} tr[MAXN];
void update(int w){
tr[w].son=tr[tr[w].r].son+tr[tr[w].l].son+tr[w].wt;
tr[w].sum=tr[tr[w].r].sum+tr[tr[w].l].sum+tr[w].wt*tr[w].v;
}
void lturn(int &w){
int t=tr[w].r; tr[w].r=tr[t].l; tr[t].l=w; update(w); update(t); w=t;
}
void rturn(int &w){
int t=tr[w].l; tr[w].l=tr[t].r; tr[t].r=w; update(w); update(t); w=t;
}
void insert(int &w,ll value){
if(w==0){
size++; w=size;
tr[w].son=1; tr[w].rnd=rand(); tr[w].v=tr[w].sum=value; tr[w].wt=1;
return;
}
if(tr[w].v==value) {tr[w].wt++; update(w);}
else if(tr[w].v<value) {
insert(tr[w].r,value); update(w);
if(tr[tr[w].r].rnd<tr[w].rnd) lturn(w);
}
else {
insert(tr[w].l,value); update(w);
if(tr[tr[w].l].rnd<tr[w].rnd) rturn(w);
}
}
void del(int &w,int value){
if(tr[w].v==value){
if(tr[w].wt>1) {tr[w].wt--; update(w); return;}
if(tr[w].l*tr[w].r==0) w=tr[w].l+tr[w].r;
else if(tr[tr[w].l].rnd<tr[tr[w].r].rnd) {rturn(w); del(tr[w].r,value);}
else {lturn(w); del(tr[w].l,value);}
update(w); return ;
}
if(tr[w].v>value) del(tr[w].l,value);
else del(tr[w].r,value);
update(w);
}
int find(int &w,int rank){
if(rank+tr[tr[w].l].son<po&&rank+tr[tr[w].l].son+tr[w].wt>=po){
smin=smin+tr[tr[w].l].sum+(po-rank-tr[tr[w].l].son)*tr[w].v;
return tr[w].v;
}
else if(rank+tr[tr[w].l].son>=po) return find(tr[w].l,rank);
else {
smin=smin+tr[tr[w].l].sum+tr[w].wt*tr[w].v;
return find(tr[w].r,rank+tr[tr[w].l].son+tr[w].wt);
}
}
void work(){
ll mid,ans;
int i;
po=(m+1)>>1;
mid=find(root,0); ans=(mid*po-smin)+(tr[root].sum-smin-mid*(m-po));
for(i=m+1;i<=n;i++){
del(root,h[i-m]); insert(root,h[i]);
smin=0; mid=find(root,0);
ans=min(ans,(mid*po-smin)+(tr[root].sum-smin-mid*(m-po)));
}
printf("%I64d\n",ans);
}
void init(){
scanf("%d%d",&n,&m);
int i;
for(i=1;i<=n;i++) scanf("%d",&h[i]);
for(i=1;i<=m;i++) insert(root,h[i]);
}
int main()
{
init(); work();
return 0;
}