天天看點

bzoj 1036--樹的統計Count 樹鍊剖分+線段樹

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;

const int maxn=30000+1000;
int fa[maxn],son[maxn],size[maxn],du[maxn],top[maxn],Id[maxn],rev[maxn];
int id;
int n;
int val[maxn];
vector<int>G[maxn];
struct Node{
    
    int Max;
    int Sum;
}node[maxn<<2];

void dfs1(int u,int f,int d){
    
    fa[u]=f;
    du[u]=d;
    size[u]=1;
    int len=G[u].size();
    for(int i=0;i<len;i++){
        
        int v=G[u][i];
        if(v==f) continue;
        dfs1(v,u,d+1);
        size[u]+=size[v];
        if(size[son[u]]<size[v]) son[u]=v; 
    }
}

void dfs2(int u,int t){
    
    top[u]=t;
    Id[u]=++id;
    rev[id]=u;
    if(!son[u]) return ;
    dfs2(son[u],t);
    int len=G[u].size();
    for(int i=0;i<len;i++){
        
        int v=G[u][i];
        if(v==fa[u] || v==son[u]) continue;
        dfs2(v,v);
    }
}

void build(int l,int r,int root){
    
    if(l==r){
        
        node[root].Max=val[rev[l]];
        node[root].Sum=val[rev[l]];
        return ;
    }
    int mid=(l+r)>>1;
    build(l,mid,root<<1);
    build(mid+1,r,root<<1|1);
    node[root].Max=max(node[root<<1].Max,node[root<<1|1].Max);
    node[root].Sum=node[root<<1].Sum+node[root<<1|1].Sum;
}

void Update(int l,int r,int ind,int root,int v){
    
    if(l==r){
        
        node[root].Max=v;
        node[root].Sum=v;
        return ;
    }
    int mid=(l+r)>>1;
    if(mid>=ind) Update(l,mid,ind,root<<1,v);
    else Update(mid+1,r,ind,root<<1|1,v);
    node[root].Max=max(node[root<<1].Max,node[root<<1|1].Max);
    node[root].Sum=node[root<<1].Sum+node[root<<1|1].Sum;
}

int queryMax(int l,int r,int L,int R,int root){
    
    if(L<=l&&R>=r) return node[root].Max;
    int mid=(l+r)>>1;
    if(mid>=R) return queryMax(l,mid,L,R,root<<1);
    else if(mid<L) return queryMax(mid+1,r,L,R,root<<1|1);
    else return max(queryMax(l,mid,L,R,root<<1),queryMax(mid+1,r,L,R,root<<1|1));
}

int querySum(int l,int r,int L,int R,int root){
    
    if(L<=l&&R>=r) return node[root].Sum;
    int mid=(l+r)>>1;
    if(mid>=R) return querySum(l,mid,L,R,root<<1);
    else if(mid<L) return querySum(mid+1,r,L,R,root<<1|1);
    else return querySum(l,mid,L,R,root<<1)+querySum(mid+1,r,L,R,root<<1|1);
}

int getMax(int u,int v){
    
    int Max=-1e9;
    while(top[u]!=top[v]){
        
        if(du[top[u]]<du[top[v]]) swap(u,v);
        Max=max(Max,queryMax(1,n,Id[top[u]],Id[u],1));
        u=fa[top[u]];
    }
    if(du[u]<du[v]) swap(u,v);
    Max=max(Max,queryMax(1,n,Id[v],Id[u],1));
    return Max;
}

int getSum(int u,int v){
    
    int Sum=0;
    while(top[u]!=top[v]){
        
        if(du[top[u]]<du[top[v]]) swap(u,v);
        Sum+=querySum(1,n,Id[top[u]],Id[u],1);
        u=fa[top[u]];
    }
    if(du[u]<du[v]) swap(u,v);
    Sum+=querySum(1,n,Id[v],Id[u],1);
    return Sum;
}

int main(){
    
    scanf("%d",&n);
    for(int i=0;i<n-1;i++){
        
        int u,v;
        scanf("%d%d",&u,&v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for(int i=1;i<=n;i++) scanf("%d",&val[i]);
    dfs1(1,-1,0);
    dfs2(1,1);
    build(1,n,1);
    int q;
    scanf("%d",&q);
    while(q--){
        
        char tmp[100];
        int x,y;
        scanf("%s%d%d",tmp,&x,&y);
        if(tmp[1]=='H') Update(1,n,Id[x],1,y);
        else if(tmp[1]=='M') printf("%d\n",getMax(x,y));
        else printf("%d\n",getSum(x,y));
    }
}