天天看點

牛客網 檸檬樹

https://ac.nowcoder.com/acm/problem/212478

線段樹 維護區間最近公共祖先

樹狀數組維護各顏色節點數的前綴和

LCT 動態加點：access(i) 把根到 i 的路徑染成顏色 i，以維護集合資訊

然後對查詢按右端點排序，離線處理查詢，最後按輸入順序輸出答案

#include <bits/stdc++.h>
using namespace std;
// Capacity shared by every node, edge and query array below.
const int N = 1e6 + 10;

// Linked-list adjacency: h[u] is u's first edge id (main fills h with -1),
// e[i] is edge i's target, nxt[i] the next edge from the same source and
// idx the next free edge id.
int h[N];
int e[N];
int nxt[N];
int idx;
// Segment-tree node values (LCA of a node range), BFS depth and the
// binary-lifting ancestor table.
int t[N];
int dep[N];
int fa[N][20];
// Fenwick tree indexed by colour, and per-query answers in input order.
int c[N];
int ans[N];
// Number of tree nodes and number of queries.
int n;
int m;

// One offline query over node range [l, r]; id is its input position.
struct Query{
    int id, l, r;
    // Queries are answered in increasing order of their right endpoint.
    bool operator<(const Query &other) const { return other.r > r; }
} q[N];
// Lowest set bit of x (the Fenwick tree step size).
int lowbit(int x){ return x & (-x); }
// Fenwick point update: add v to the count stored for key x. Keys start at
// 0, so they are shifted to 1-based Fenwick positions.
void ta_add(int x, int v){
    for(int pos = x + 1; pos < N; pos += lowbit(pos))
        c[pos] += v;
}
// Fenwick prefix query: total count over keys 0..x (x == -1 yields 0).
int sum(int x){
    int total = 0;
    for(int pos = x + 1; pos != 0; pos -= lowbit(pos))
        total += c[pos];
    return total;
}
void add(int a, int b){
    e[idx] = b, nxt[idx] = h[a], h[a] = idx ++;
}
// Breadth-first search from root 1. Fills dep[] (root depth 1, sentinel
// node 0 depth 0) and fa[v][k], the 2^k-th ancestor of v (0 once the jump
// passes the root).
void lca_bfs(){
    memset(dep, 0x3f, sizeof dep);   // large value marks "not reached yet"
    dep[0] = 0;
    dep[1] = 1;
    std::queue<int> frontier;
    frontier.push(1);
    while(!frontier.empty()){
        const int cur = frontier.front();
        frontier.pop();
        const int childDepth = dep[cur] + 1;
        for(int edge = h[cur]; edge != -1; edge = nxt[edge]){
            const int nb = e[edge];
            if(dep[nb] <= childDepth) continue;   // already reached
            dep[nb] = childDepth;
            frontier.push(nb);
            fa[nb][0] = cur;
            // nb's ancestors were discovered earlier, so their rows are filled.
            for(int k = 1; k < 20; ++k)
                fa[nb][k] = fa[fa[nb][k - 1]][k - 1];
        }
    }
}
// Lowest common ancestor of a and b by binary lifting (requires lca_bfs()).
int lca(int a, int b){
    if(dep[a] < dep[b]) std::swap(a, b);
    // Raise the deeper node a to b's depth.
    for(int k = 19; k >= 0; --k){
        const int up = fa[a][k];
        if(dep[up] >= dep[b]) a = up;
    }
    if(a == b) return a;
    // Raise both to just below the point where their ancestors coincide.
    for(int k = 19; k >= 0; --k){
        if(fa[a][k] == fa[b][k]) continue;
        a = fa[a][k];
        b = fa[b][k];
    }
    return fa[a][0];
}
// LCA of nodes l..r, read from segment-tree node u covering [L, R].
int query(int l, int r, int u = 1, int L = 1, int R = n){
    // Fully covered node: its stored LCA answers this piece.
    if(l <= L && R <= r) return t[u];
    const int mid = (L + R) / 2;
    const bool needLeft = l <= mid;
    const bool needRight = r > mid;
    if(!needRight) return query(l, r, 2 * u, L, mid);
    if(!needLeft) return query(l, r, 2 * u + 1, mid + 1, R);
    const int leftLca = query(l, r, 2 * u, L, mid);
    const int rightLca = query(l, r, 2 * u + 1, mid + 1, R);
    return lca(leftLca, rightLca);
}
// Build the range-LCA segment tree: node u over [l, r] stores the LCA of
// nodes l..r (a leaf stores its own node index).
void build(int l, int r, int u){
    if(l == r){
        t[u] = l;
        return;
    }
    const int mid = (l + r) / 2;
    build(l, mid, 2 * u);
    build(mid + 1, r, 2 * u + 1);
    t[u] = lca(t[2 * u], t[2 * u + 1]);
}

namespace LCT{
    // Link-cut tree over the rooted input tree. Each splay tree holds one
    // preferred path.
    // s[0]/s[1]: splay children; p: splay parent or path-parent pointer;
    // siz: splay subtree size; v: node colour; tag: pending colour for the
    // subtree (0 = none).
    // access() recolours a whole splay tree at once and only ever cuts
    // splay trees, so each splay tree stays single-coloured.
    struct Node{int s[2], v, p, siz, tag;}t[N];
    // Scratch stack used by splay() to push tags down from the splay root.
    int stk[N];

    // Make every node in the subtree of u a singleton splay tree whose
    // path-parent pointer is its tree parent (`father` for u itself).
    // Iterative: the recursive version used one stack frame per tree level
    // and could overflow the call stack on deep, path-like trees.
    void dfs(int u, int father){
        t[u].p = father;
        std::vector<int> pending(1, u);
        while(!pending.empty()){
            const int v = pending.back();
            pending.pop_back();
            for(int i = h[v]; ~i; i = nxt[i]){
                const int w = e[i];
                if(w == t[v].p) continue;   // skip the edge back to the parent
                t[w].p = v;
                pending.push_back(w);
            }
        }
    }
    // Lazily assign colour `col` to the splay subtree rooted at u.
    void pushtag(int u, int col){
        t[u].v = t[u].tag = col;
    }
    // Recompute u's splay subtree size from its children.
    void pushup(int u){
        t[u].siz = t[t[u].s[0]].siz + t[t[u].s[1]].siz + 1;
    }
    // True when x is the root of its splay tree, i.e. its parent pointer
    // (if any) is not a child link.
    bool isroot(int x){
        return t[t[x].p].s[0] != x && t[t[x].p].s[1] != x;
    }
    // Hand a pending colour down to both children.
    void pushdown(int u){
        if(t[u].tag){
            pushtag(t[u].s[0], t[u].tag);
            pushtag(t[u].s[1], t[u].tag);
            t[u].tag = 0;
        }
    }
    // Rotate x above its splay parent y, keeping z's link (or path-parent).
    void rotate(int x){
        int y = t[x].p, z = t[y].p, k = t[y].s[1] == x;
        if(!isroot(y)) t[z].s[t[z].s[1] == y] = x;
        t[x].p = z;
        t[y].s[k] = t[x].s[k ^ 1]; t[t[x].s[k ^ 1]].p = y;
        t[x].s[k ^ 1] = y; t[y].p = x;
        pushup(y); pushup(x);
    }
    // Bring x to the root of its splay tree.
    void splay(int x){
        // Push pending tags top-down along the splay path to x first.
        int top = 0;
        stk[++top] = x;
        for(int y = x; !isroot(y); ) stk[++top] = y = t[y].p;
        while(top) pushdown(stk[top--]);

        while(!isroot(x)){
            int y = t[x].p, z = t[y].p;
            if(!isroot(y)){
                // Zig-zag (different directions) rotates x twice;
                // zig-zig rotates the parent first.
                if((t[y].s[1] == x) ^ (t[z].s[1] == y)) rotate(x);
                else rotate(y);
            }
            rotate(x);
        }
    }
    // Expose the root-to-x path as one splay tree and recolour it with
    // colour x, keeping the Fenwick per-colour node counts in sync.
    void access(int x){
        int col = x, y = 0;
        for(; x; y = x, x = t[x].p){
            splay(x);
            t[x].s[1] = y; pushup(x);
            // x and its left subtree are the nodes of x's single-coloured
            // path from its top down to x; they move from colour t[x].v to col.
            ta_add(t[x].v, -(t[t[x].s[0]].siz + 1));
            ta_add(col, t[t[x].s[0]].siz + 1);
        }
        pushtag(y, col);
    }

}
int main(){
    memset(h, -1, sizeof h);
    scanf("%d%d", &n, &m);
    for(int i = 0; i < n - 1; ++ i){
        int a, b; scanf("%d%d", &a, &b);
        add(a, b); add(b, a);
    }
    for(int i = 0, l, r; i < m; ++ i){
        scanf("%d%d", &l, &r);
        q[i] = {i, l, r};
    }
    sort(q, q + m);
    lca_bfs();
    build(1, n, 1);
    LCT::dfs(1, 0);
    ta_add(0, n);
    for(int i = 1, j = 0; i <= n; ++ i){
        LCT::access(i);
        for(; j < m && q[j].r == i; j ++)
            ans[q[j].id] = n - sum(q[j].l - 1) - dep[query(q[j].l, q[j].r)];
    }
    for(int i = 0; i < m; ++ i) printf("%d\n", ans[i]);
}