天天看点

BZOJ 3757 苹果树 树上莫队

题目大意:给出一棵树,问任意两点之间有多少种不同的颜色,一个人可能会有色盲,会将A和B当成一种颜色。

思路:比较裸的树上莫队,写出来之后,很慢,怀疑是分块的缘故,然后果断找了当年比赛的标称交上去,瞬间rk1,大概看了一眼,他好像是直接用DFS序+曼哈顿距离最小生成树搞的,为什么会比分块快?

昨天下午看到这个题之后就一直在研究树上莫队的正确姿势,然后先写了树分块,后来看了很多牛人的SPOJ COT2的题解,后来又和同学探讨了好久才弄明白。

首先先将树分块,然后把区间排序,按照第一权值为左端点所在块的编号,右端点在DFS序中的位置排序,关键是转移。有一种vfk的靠谱一点的方法。对于任意一个状态,在树上表示[l,r]的路径,目前的状态只存{x|x∈[l,r],x != LCA(l,r)}这些点的颜色,这样就大概有两种情况,一种是两条链,没有中间的LCA,或者是一条链,没有顶端的LCA。然后一直这样转移,例如从[l,r]转移到[x,y]的时候,我们只需要暴力从l->x,y->r,注意记录一个标记数组,在转移的时候把路径上的所有点取反。这样转移之后还是{x|x∈[l,r],x != LCA(l,r)}这些点。统计答案的时候将LCA加回来,然后再删掉。

注意:找LCA要倍增!千万别像我以为写不写倍增都是O(n)然后T一晚上。。。

CODE:

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define MAX 100100
using namespace std;
 
inline char GetChar()
{
    static const int L = (1 << 15);
    static char buffer[L],*S = buffer,*T = buffer;
    if(S == T) {
        T = (S = buffer) + fread(buffer,1,L,stdin);
        if(S == T)  return EOF;
    }
    return *S++;
}
  
inline int GetInt()
{
    int c;
    while(!isdigit(c = GetChar()));
    int x = c - '0';
    while(isdigit(c = GetChar()))   
        x = (x << 1) + (x << 3) + c - '0';
    return x;
}
 
int block_size,belong[MAX],blocks;
int pos[MAX],cnt;
int size[MAX],root;
 
struct Ask{
    int x,y,_id;
    int mixed,_mixed;
     
    bool operator <(const Ask &a)const {
        if(belong[x] == belong[a.x])    return pos[y] < pos[a.y];
        return belong[x] < belong[a.x];
    }
    void Read(int p) {
        x = GetInt(),y = GetInt();
        mixed = GetInt(),_mixed = GetInt();
        if(pos[x] > pos[y])  swap(x,y);
        _id = p;
    }
}ask[MAX << 1];
 
int points,asks;
int head[MAX],total;
int _next[MAX << 1],aim[MAX << 1];
 
int deep[MAX],father[MAX][20];
 
int src[MAX];
 
int num[MAX],colors;
bool v[MAX];
int ans[MAX];
 
inline void Add(int x,int y)
{
    _next[++total] = head[x];
    aim[total] = y;
    head[x] = total;
}
 
void DFS(int x,int last)
{
    father[x][0] = last;
    pos[x] = ++cnt;
    deep[x] = deep[last] + 1;
    for(int i = head[x]; i; i = _next[i]) {
        if(aim[i] == last)  continue;
        if(size[belong[x]] < block_size)
            ++size[belong[x]],belong[aim[i]] = belong[x];
        else    ++size[++blocks],belong[aim[i]] = blocks;
        DFS(aim[i],x);
    }
}
 
inline void Change(int x,int c)
{
    if(!num[x]) ++num[x],++colors;
    else if(num[x] == 1 && c == -1) --num[x],--colors;
    else    num[x] += c;
}
 
inline int GetLCA(int x,int y)
{
    if(deep[x] < deep[y])    swap(x,y);
    for(int i = 19; ~i; --i)
        if(deep[father[x][i]] >= deep[y])
            x = father[x][i];
    if(x == y)  return x;
    for(int i = 19; ~i; --i)
        if(father[x][i] != father[y][i])
            x = father[x][i],y = father[y][i];
    return father[x][0];
}
 
inline void Work(int x,int y,int lca)
{
    for(; x != lca; x = father[x][0]) {
        Change(src[x],v[x] ? -1:1);
        v[x] ^= 1;
    }
    for(; y != lca; y = father[y][0]) {
        Change(src[y],v[y] ? -1:1);
        v[y] ^= 1;
    }
}
 
inline void Solve(int p)
{
    static int l = root,r = root,lca;
     
    Work(l,ask[p].x,GetLCA(l,ask[p].x));
    Work(r,ask[p].y,GetLCA(r,ask[p].y));
    l = ask[p].x,r = ask[p].y;
    lca = GetLCA(l,r);
    Change(src[lca],1);
    ans[ask[p]._id] = colors;
    if(ask[p].mixed != ask[p]._mixed)
        ans[ask[p]._id] -= num[ask[p].mixed] && num[ask[p]._mixed];
    Change(src[lca],-1);
}
 
inline void SparseTable()
{
    for(int j = 1; j < 20; ++j)
        for(int i = 1; i <= points; ++i)
            father[i][j] = father[father[i][j - 1]][j - 1];
}
 
int main()
{
    //freopen("apple.in","r",stdin);
    //freopen("apple.out","w",stdout);
    cin >> points >> asks;
    for(int i = 1; i <= points; ++i)
        src[i] = GetInt();
    for(int x,y,i = 1; i <= points; ++i) {
        x = GetInt(),y = GetInt();
        Add(x,y),Add(y,x);
    }
    block_size = sqrt(points + 1);
    size[1] = 1;
    belong[0] = 1;
    blocks = 1;
    DFS(0,MAX - 1);
    root = aim[head[0]];
    SparseTable();
 
    for(int i = 1; i <= asks; ++i)
        ask[i].Read(i);
    sort(ask + 1,ask + asks + 1);
     
    for(int i = 1; i <= asks; ++i)
        Solve(i);
    for(int i = 1; i <= asks; ++i)
        printf("%d\n",ans[i]);
    return 0;
}
           

继续阅读