天天看点

牛客网 小白赛5 I区间 线段树,差分思想,树状数组

题目链接:​​https://www.nowcoder.com/acm/contest/135/I​​

好像是区间修改查询,想到了线段树,但是空间要开四倍,然后只过了百分之三十的数据,呜呜呜,线段树也是可以空间优化的

有兴趣的同学可以研究下

空间优化

父节点k,左二子2*k,右儿子2*k+1,需要4*n的空间

但并不是所有的叶子节点占用到2n+1——4n

这就造成大量空间浪费

2*n空间表示法:推荐博客:​​http://www.cppblog.com/MatoNo1/archive/2015/05/05/195857.html​​

用dfs序表示做节点下标

父节点k,左儿子k+1,右儿子:k+左儿子区间长度*2,不是父节点下标+父节点区间长度。因为当树不满时,两者不相等

具体实现这里就不再写模板了,就是改改左右儿子的下标

可参考代码: 题目:楼房重建

里面的建树用的2*n空间

线段树代码是这样的

#include <iostream>
#include <cstdio>
#include <cstring>
#define lson l, mid, o << 1
#define rson mid + 1, r, o << 1 | 1
#define maxn 1000010
#define ll long long
using namespace std;
int n,m;
int p;
ll pre[maxn << 2],add[maxn << 2];
 
void Pushup(int o){
  pre[o] = pre[o << 1] + pre[o << 1 | 1];
}
 
void Pushdown(int o, int ans){
  if(add[o]){
    add[o << 1] += add[o];
    add[o << 1 | 1] += add[o];
    pre[o << 1] += add[o] * (ans - (ans >> 1));
    pre[o << 1 | 1] += add[o] * (ans >> 1);
    add[o] = 0;
  }
}
 
void Build(int l,int r,int o){
  if(l == r){
    scanf("%lld",&pre[o]);
    return ;
  }
  ll mid = (l + r) >> 1;
  Build(lson);
  Build(rson);
  Pushup(o);
}
 
void Update(int L,int R,int ans,int l,int r,int o){
  if(L <= l && r <= R){
    pre[o] += (ll)ans * (r - l + 1);    // 注意这里要强制转换成ll
    add[o] += ans;
    return ;
  }
  Pushdown(o,r-l+1);
  int mid = (l + r) >> 1;
  if(L <= mid)Update(L,R,ans,lson);
  if(R > mid)Update(L,R,ans,rson);
  Pushup(o);
}
 
ll Query(int L,int R,int l,int r,int o){
  if(L <= l && r <= R){
    return pre[o];
  }
  Pushdown(o,r-l+1);
  int mid = (l + r) >> 1;
  ll ans = 0;
  if(L <= mid)ans += Query(L,R,lson);
  if(R > mid)ans += Query(L,R,rson);
  return ans;
}
 
int main()
{
  scanf("%d%d",&n,&m);
  Build(1,n,1);
  while(m--){
    scanf("%d",&p);
    if(p == 1){
      int x,y,z;
      scanf("%d%d%d",&x,&y,&z);
      Update(x,y,-z,1,n,1);
    }
    else{
      int x,y,z;
      scanf("%d%d%d",&x,&y,&z);
      Update(x,y,z,1,n,1);
    }
  }
  int l,r;
  scanf("%d%d",&l,&r);
  printf("%lld\n",Query(l,r,1,n,1));
  return 0;
}      

然后就想用树状数组:

从同学那里借鉴的代码是这样的:

#include<bits/stdc++.h>
using namespace std;
  
typedef long long ll;
const int maxn=1000000+100;
  
ll bit[maxn][2];
int ans[maxn];
  
int n,m;
 
template<typename Q>
 
inline void inin(Q &x)
{
    x=0;
    int f=0;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
         
       if(ch=='-')f=1;ch=getchar();
    }
    while(ch>='0'&&ch<='9') x=x*10+(ch-'0'),ch=getchar();
    x=f?-x:x;
}
  
inline int lowbit(int x){
      
    return x&(-x);
}
  
inline void Add(int x,ll v,int ind){
      
    while(x<=n){
          
        bit[x][ind]+=v;
        x+=lowbit(x);
    }
}
  
inline ll getSum(int x,int ind){
      
    ll sum=0;
    while(x>0){
          
        sum+=bit[x][ind];
        x-=lowbit(x);
    }
    return sum;
}
  
int main(){
      
    inin(n);
    inin(m);
    for(int i=1;i<=n;i++){
          
        inin(ans[i]);
    }
    for(int i=1;i<=m;i++){
          
        int q,l,r,p;
        inin(q);
        inin(l);
        inin(r);
        inin(p);
        if(q==1){
              
            p*=-1;
            Add(l,p,0);
            Add(r+1,-p,0);
            Add(l,(ll)l*p,1);
            Add(r+1,(ll)(-r-1)*p,1);
        }
        else{
              
            Add(l,p,0);
            Add(r+1,-p,0);
            Add(l,(ll)l*p,1);
            Add(r+1,(ll)(-r-1)*p,1);
        }
    }
    int l,r;
    inin(l);
    inin(r);
    ll tmp=(r+1)*getSum(r,0)-getSum(r,1);
    tmp-=l*getSum(l-1,0)-getSum(l-1,1);
    for(int i=l;i<=r;i++) tmp+=(ll)ans[i];
    printf("%lld\n",tmp);
}      

最后是差分:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
template<typename Q>
void inin(Q &x)
{
    x=0;int f=0;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
    x=f?-x:x;
}
const ll MAXN=1004567;
ll a[MAXN],v[MAXN];
int main(){
    int n, m, x;
    inin(n),inin(m);
    for (int i = 1; i <= n; ++i) {
        inin(a[i]);
    }
    int l, r, q, i;
    memset(v, 0, sizeof(v));
    for (i = 0; i < m; ++i) {
        inin(q);inin(l);inin(r);inin(x);
        if(q == 1){
            v[l] -= x;
            v[r+1] += x;
        } else{
            v[l] += x;
            v[r+1] -= x;
        }
    }
    inin(l),inin(r);
    ll t = 0;
    ll sum = 0;
    for (i = 1; i <= r; ++i) {
        t += v[i];
        if(i >= l)sum += t + a[i];
    }
    printf("%lld\n", sum);
}