天天看点

BZOJ 1500[NOI2005] 维修数列

Description

BZOJ 1500[NOI2005] 维修数列

Input

输入的第1 行包含两个数N 和M(M ≤20 000),N 表示初始时数列中数的个数,M表示要进行的操作数目。

第2行包含N个数字,描述初始时的数列。

以下M行,每行一条命令,格式参见问题描述中的表格。

任何时刻数列中最多含有500 000个数,数列中任何一个数字均在[-1 000, 1 000]内。

插入的数字总数不超过4 000 000个,输入文件大小不超过20MBytes。

Output

对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。

本题可用Splay来维护这个数列,而每次插入操作都将插入的数列建一棵Splay然后并入原树即可。而查询区间以及各种修改操作则依赖于伸展操作。查询区间[l,r]时,将节点l-1转至树根,将节点r+1转至根节点的右子节点,然后root->ch[1]->ch[0]就是要操作的区间,将标记打在该节点即可,在旋转和查询的时候要注意标记的下传和节点的更新。

特殊的,要查询区间的最大连续子序列和:该操作需要对每个节点维护几个值,该节点及其子树中的最大前缀和,最大后缀和,以及最大子序列和;则在更新的时候该节点的最大子序列和=max(左子树的最大后缀+该节点的值,右子树的最大前缀+该节点的值,左子树的最大子序列,右子树的最大子序列,左子树最大后缀+节点权值+右子树最大前缀,该节点权值);这样进行更新,则可以保证该节点维护值的正确性,最大前缀与后缀的维护与其相似。注意在更新前要先将该节点及其子节点的标记下传;

还有一个问题,那就是怎样保证这样做的正确性,我们在修改区间值和翻转时只是将标记打在了子树的根节点上,那么该节点的信息还是正确的么?我们在建树时是递归进行的,显然每个点的信息都是最新的,并且是从底向上进行更新,正确性显然可以保证;在我们打翻转标记的时候,该子树的最大子序列显然是不变的(子序列翻转后相邻的数字是不变的),最大前缀与后缀只是交换了位置而已,最大前缀变成了后缀,后缀变成了前缀,将两个值交换即可;在修改值时,将标记打在节点上,当需要用到该节点或其子树的信息时,将标记下传,同时更新节点的信息,如果修改后的值为正数,则很显然max_sum=max_pre=max_sur=该子树节点数x修改后的值,若为负数,则等于该值。所以,只要将标记打在要修改的子树,在调用时更新即可;还有,在每次操作完成后,将该节点到根的路径更新,保证信息的正确性;

在操作时还有一个问题:如果l==1或r==n怎么办;显然节点l-1与r+1是不存在的。于是我们另设两个节点,分别加在原树的前面与后面,设成一个极小的负数,这样它就不会对其他节点的信息有影响,因为当查询区间时,只有在区间内的值才有效,很明显这两个新节点是无论如何都不会在区间内的;;而唯一查询全部节点的操作时最大子序列,由于新节点的值为极小值,所以显然不会将它们选中。(注意极小值不要太小,不然更新时相加可能会爆int!)

嗯。。就是这样。

还有注意加个特判,,,如果tot==0直接输出0或直接return,,数据有毒。。

代码如下:

#include<iostream>
#include<cstdio>
#include<climits>
#include<queue>
#define N 
#define INF 
using namespace std;
struct node { 
    int size,num,sum,pre,sur,max_sum,to; bool sa,re; node *ch[],*fa;
    void clean(int s) { fa=ch[]=ch[]=NULL;size=;re=sa=;sum=pre=sur=max_sum=num=s; }
}*head,*null,o[N];
queue<node*> q;int a[],tot,top;
int in() {
    int s=,v=;char c;
    while((c=getchar())<'0'||c>'9') if(c=='-') v=1;s=c-'0';
    while((c=getchar())>='0'&&c<='9') s=s*10+c-'0';
    return v?-s:s;
}
int check(node *now) { return now->fa->ch[]==now?:; }
void push(node *now) {
    if(now==NULL) return;
    push(now->ch[]);push(now->ch[]); q.push(now);now->clean();
}
node *out(node *fa,int num) {
    node *x;if(!q.empty()) x=q.front(),q.pop(); else x=o+tot++;x->clean(num); x->fa=fa; return x;
}
void mark_down(node *now) {
    if(now==NULL) return;
    if(now->re) {
      swap(now->ch[],now->ch[]);now->re=;swap(now->pre,now->sur);
      if(now->ch[]!=NULL) now->ch[]->re^=;
      if(now->ch[]!=NULL) now->ch[]->re^=;
    }
    if(now->sa) {
      now->sa=;
      if(now->ch[]!=NULL) now->ch[]->sa=,now->ch[]->to=now->to;
      if(now->ch[]!=NULL) now->ch[]->sa=,now->ch[]->to=now->to;
      now->num=now->to; now->sum=now->size*now->to;
      now->max_sum=now->sur=now->pre=max(now->to,now->to*now->size);
    }
}
void update(node *now) { 
    if(now==NULL) return;
    mark_down(now);
    if(now->ch[]!=NULL) mark_down(now->ch[]);
    if(now->ch[]!=NULL) mark_down(now->ch[]);
    now->sum=now->num;
    if(now->ch[]!=NULL) now->sum+=now->ch[]->sum;
    if(now->ch[]!=NULL) now->sum+=now->ch[]->sum;
    now->size=;
    if(now->ch[]!=NULL) now->size+=now->ch[]->size;
    if(now->ch[]!=NULL) now->size+=now->ch[]->size;
    now->pre=now->sur=now->max_sum=-INF;
    if(now->ch[]!=NULL) {
      now->pre=max(now->ch[]->pre,now->ch[]->sum+now->num);
      if(now->ch[]!=NULL) now->pre=max(now->pre,now->ch[]->sum+now->num+now->ch[]->pre);
    }
    else {
      now->pre=max(now->pre,now->num);
      if(now->ch[]!=NULL) now->pre=max(now->pre,now->ch[]->pre+now->num);
    }
    if(now->ch[]!=NULL) {
      now->sur=max(now->ch[]->sur,now->ch[]->sum+now->num);
      if(now->ch[]!=NULL) now->sur=max(now->sur,now->ch[]->sum+now->num+now->ch[]->sur);
    }
    else {
      now->sur=max(now->sur,now->num);
      if(now->ch[]!=NULL) now->sur=max(now->sur,now->num+now->ch[]->sur);
    }
    now->max_sum=max(now->max_sum,now->num);
    if(now->ch[]!=NULL&&now->ch[]!=NULL) {
      now->max_sum=max(max(now->ch[]->max_sum,now->ch[]->max_sum),now->max_sum);
      now->max_sum=max(now->max_sum,now->ch[]->pre+now->num+now->ch[]->sur);
    }
    if(now->ch[]!=NULL) now->max_sum=max(max(now->max_sum,now->ch[]->max_sum),now->ch[]->sur+now->num);
    if(now->ch[]!=NULL) now->max_sum=max(max(now->max_sum,now->ch[]->max_sum),now->ch[]->pre+now->num);
}
void rorate(node *now) {
    node *fa=now->fa;int d=check(now),c=check(fa);
    now->fa=fa->fa;fa->fa->ch[c]=now; fa->ch[d]=NULL;
    if(now->ch[d^]!=NULL) now->ch[d^]->fa=fa,fa->ch[d]=now->ch[d^];
    now->ch[d^]=fa;fa->fa=now;
    update(fa);update(now); return;
}
void splay(node *now,node *fa) {
    for(;now->fa!=fa;) 
      if(now->fa->fa==fa) rorate(now); 
      else {
        node *x=now->fa;
        if(check(now)==check(x)) rorate(x),rorate(now);
        else rorate(now),rorate(now);
      } return;
}
node *build(int l,int r,node *fa,int d) {
    if(l>r) return NULL;
    int mid=l+r>>;
    node *now=out(fa,a[mid]);
    now->ch[]=build(l,mid-,now,);
    now->ch[]=build(mid+,r,now,);
    update(now); return now;
}
node *find(int k) {
    node *now=head;mark_down(now); int ti=;
    if(now->ch[]!=NULL) ti=now->ch[]->size+;
    for(;ti!=k;) {
      if(k>ti) now=now->ch[],k-=ti;
      else now=now->ch[];mark_down(now);
      ti=;if(now->ch[]!=NULL) ti=now->ch[]->size+;
    } return now;
}
void insert() {
    int point=in(),n=in();
    for(int i=;i<=n;i++) a[i]=in(); node *root=build(,n,null,);
    splay(head=find(point+),null),splay(find(point+),head);
    head->ch[]->ch[]=root;root->fa=head->ch[];
    update(head->ch[]);update(head); return;
}
void del() {
    int point=in(),tot=in();
    splay(head=find(point),null);splay(find(point+tot+),head);
    push(head->ch[]->ch[]);head->ch[]->ch[]=NULL;
    update(head->ch[]);update(head); return;
}
void make_same() {
    int point=in(),tot=in(),c=in();
    splay(head=find(point),null);splay(find(point+tot+),head);
    head->ch[]->ch[]->sa=;head->ch[]->ch[]->to=c;
    update(head->ch[]);update(head);return;
}
void max_sum() { printf("%d\n",head->max_sum);return; }
void reverse() {
    int point=in(),tot=in();
    splay(head=find(point),null);splay(find(point+tot+),head);
    head->ch[]->ch[]->re^=;update(head->ch[]);update(head);
}
void get_sum() {
    int point=in(),tot=in();if(tot==) { printf("0\n");return; }
    splay(head=find(point),null);splay(find(point+tot+),head);
    update(head->ch[]->ch[]);update(head->ch[]);update(head);
    printf("%d\n",head->ch[]->ch[]->sum);
}
int main() {
    null=out(NULL,-INF);null->ch[]=null->ch[]=null->fa=null;
    null->size=null->num=null->sum=;
    int n=in(),m=in();for(int i=;i<=n;i++) a[i]=in();
    a[]=a[n+]=-INF;head=build(,n+,null,);
    while(m--) {
      char ord[];
      scanf("%s",ord);
      if(ord[]=='I') insert();
      else if(ord[]=='D') del();
      else if(ord[]=='K') make_same();
      else if(ord[]=='X') max_sum();
      else if(ord[]=='R') reverse();
      else get_sum();
    }
}
           

继续阅读