天天看點

洛谷 P2486 [SDOI2011]染色 樹鏈剖分

洛谷 P2486 [SDOI2011]染色

題意

有一棵 $n$ 個結點的無根樹，每個結點有一個顏色。共有 $m$ 個操作，操作分兩種：

  • 將結點 $a$ 到結點 $b$ 路徑上的所有結點顏色都染成 $c$。
  • 詢問結點 $a$ 到結點 $b$ 路徑上顏色段的數量。
  • 顏色段的定義：極長的一段連續相同顏色被視為一段。例如

    112221

    由三段組成:

    11

    222

    1

解法

樹鏈剖分

  • 樹上路徑的修改與詢問，考慮使用樹鏈剖分。
  • 維護區間內顏色段的數量，考慮使用線段樹：每個結點維護區間內顏色段的數量以及兩端的顏色。合併兩個區間時，若左區間右端與右區間左端的顏色相同，則段數減一。
  • 樹上的路徑被分成若干條鏈，修改時對每條鏈直接做線段樹區間賦值即可。
  • 詢問時需要注意相鄰兩條鏈的合併：若兩條鏈相接處的顏色相同，則答案減一。

代碼

#pragma region
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <vector>
using namespace std;
// Shorthand type alias.
typedef long long ll;
// NOTE(review): `IT` names set<node>::iterator, but no type `node` is declared in this
// file and IT is never used; it only compiles because macros are expanded on use.
#define IT set<node>::iterator
// Segment-tree accessors. They expand textually and need a variable named `root` in
// scope: tr is node `root`, lson is node 2*root, rson is node 2*root+1 of array t.
#define tr t[root]
#define lson t[root << 1]
#define rson t[root << 1 | 1]
// Inclusive loops: rep runs i = a, a+1, ..., n; per runs i = n, n-1, ..., a.
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#define per(i, a, n) for (int i = n; i >= a; --i)
namespace fastIO {
// Number of bytes requested from stdin by each fread call.
const int kReadChunk = 100000;

// Sticky end-of-input flag: set the first time fread yields no bytes, never cleared.
bool IOerror = 0;

// Next byte of stdin, served from a static buffer refilled in kReadChunk-byte blocks.
// At end of input it sets IOerror and returns (char)-1.
inline char nc() {
    static char chunk[kReadChunk];
    static char *cursor = chunk + kReadChunk;
    static char *limit = chunk + kReadChunk;
    if (cursor != limit) return *cursor++;
    cursor = chunk;
    limit = chunk + fread(chunk, 1, kReadChunk, stdin);
    if (cursor == limit) {
        IOerror = 1;
        return -1;
    }
    return *cursor++;
}

// True for the four whitespace characters the readers skip.
inline bool blank(char ch) {
    switch (ch) {
    case ' ':
    case '\n':
    case '\r':
    case '\t':
        return true;
    default:
        return false;
    }
}

// Read an optionally '-'-prefixed decimal integer into x.
// Returns false if input was already exhausted after skipping whitespace.
template <class T>
inline bool R(T &x) {
    bool negative = false;
    char ch = nc();
    x = 0;
    while (blank(ch)) ch = nc();
    if (IOerror) return false;
    if (ch == '-') {
        negative = true;
        ch = nc();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nc();
    }
    if (negative) x = -x;
    return true;
}

// Read a decimal number with an optional sign and optional fractional part
// (no exponent notation).
inline bool R(double &x) {
    bool negative = false;
    char ch = nc();
    x = 0;
    while (blank(ch)) ch = nc();
    if (IOerror) return false;
    if (ch == '-') {
        negative = true;
        ch = nc();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = nc();
    }
    if (ch == '.') {
        double scale = 1;
        for (ch = nc(); ch >= '0' && ch <= '9'; ch = nc()) {
            scale /= 10.0;
            x += scale * (ch - '0');
        }
    }
    if (negative) x = -x;
    return true;
}

// Read one whitespace-delimited token into s and NUL-terminate it.
// The caller must supply a buffer large enough for the token.
inline bool R(char *s) {
    char ch = nc();
    while (blank(ch)) ch = nc();
    if (IOerror) return false;
    while (!blank(ch) && !IOerror) {
        *s++ = ch;
        ch = nc();
    }
    *s = 0;
    return true;
}

// Read the very next raw byte (whitespace is NOT skipped); -1 at end of input.
inline bool R(char &c) {
    c = nc();
    if (!IOerror) return true;
    c = -1;
    return false;
}

// Read several values in order; stops at the first failed read.
template <class T, class... U>
bool R(T &h, U &... tmp) {
    if (!R(h)) return false;
    return R(tmp...);
}
}  // namespace fastIO
using namespace fastIO;
// Output helpers: _W prints a single value; W prints its arguments separated by single
// spaces and terminated by '\n'.

// Forward declarations, so that nested containers (e.g. a pair of vectors) resolve to
// the container overloads below rather than the generic operator<< fallback.
template <class T, class U>
void _W(const std::pair<T, U> &x);
template <class T>
void _W(const std::vector<T> &x);

// Generic fallback: anything printable with operator<<.
template <class T>
void _W(const T &x) { std::cout << x; }
void _W(const int &x) { printf("%d", x); }
// int64_t is `long` on LP64 targets; cast so the argument matches %lld exactly.
void _W(const int64_t &x) { printf("%lld", static_cast<long long>(x)); }
void _W(const double &x) { printf("%.16f", x); }
void _W(const char &x) { putchar(x); }
void _W(const char *x) { printf("%s", x); }

// Pair: "first second". The previous version referenced undefined members x.F / x.S
// and failed to compile as soon as it was instantiated.
template <class T, class U>
void _W(const std::pair<T, U> &x) {
    _W(x.first);
    putchar(' ');
    _W(x.second);
}

// Vector: elements separated by single spaces, with no leading or trailing space.
template <class T>
void _W(const std::vector<T> &x) {
    for (auto it = x.begin(); it != x.end(); ++it) {
        if (it != x.begin()) putchar(' ');
        _W(*it);
    }
}

// Recursion terminator for W.
void W() {}

// Print head, then ' ' if more arguments follow, otherwise '\n'.
template <class T, class... U>
void W(const T &head, const U &... tail) {
    _W(head);
    putchar(sizeof...(tail) ? ' ' : '\n');
    W(tail...);
}
#pragma endregion
// Capacity of every per-node array (nodes are numbered from 1).
const int maxn = 100005;
// n: number of nodes, m: number of operations (both read in main).
int n, m;
// a[u]: initial colour of node u.
int a[maxn];
// g[u]: neighbours of u in the undirected tree.
std::vector<int> g[maxn];
// Heavy-light data filled by dfs1: heavy child (0 if none), parent, depth, subtree size.
int son[maxn];
int fa[maxn];
int dep[maxn];
int sz[maxn];
// Filled by dfs2: chain head, position in HLD order (cnt = last position used),
// and wt[pos] = colour of the node placed at position pos.
int top[maxn];
int cnt;
int id[maxn];
int wt[maxn];
// First HLD pass: record parent and depth of u, accumulate subtree sizes, and pick
// the heavy child (the child with the largest subtree; the first one wins on ties).
void dfs1(int u, int f, int deep) {
    dep[u] = deep;
    fa[u] = f;
    sz[u] = 1;
    for (const int child : g[u]) {
        if (child != f) {
            dfs1(child, u, deep + 1);
            sz[u] += sz[child];
            if (sz[son[u]] < sz[child]) son[u] = child;
        }
    }
}
// Second HLD pass: assign positions so that each heavy chain is contiguous, and copy
// node colours into wt[] in that order. topf is the head of u's chain.
void dfs2(int u, int topf) {
    id[u] = ++cnt;
    wt[id[u]] = a[u];
    top[u] = topf;
    if (son[u] == 0) return;  // leaf: no heavy child to continue the chain
    dfs2(son[u], topf);       // heavy child extends the current chain
    for (const int child : g[u])
        if (child != fa[u] && child != son[u]) dfs2(child, child);  // light child starts a new chain
}
// One segment-tree node over positions [l, r] of the HLD order.
struct segtree {
    int l, r;    // covered interval
    int lc, rc;  // colour at the left / right end of the interval
    int val;     // number of maximal same-colour runs in the interval
    int tag;     // pending whole-interval colour assignment not yet pushed to children
} t[maxn << 2];
// Colours at the left / right end of the range passed to the latest query() call.
int lcol, rcol;
// Recompute node `root` from its children. The run count is the sum of the children's
// runs, minus one when the runs meeting at the midpoint share a colour. A right child
// with val == 0 is treated as empty, and the parent then mirrors the left child.
inline void merge(int root) {
    tr.lc = lson.lc;
    if (!rson.val) {
        tr.rc = lson.rc;
        tr.val = lson.val;
        return;
    }
    tr.rc = rson.rc;
    const bool joined = (lson.rc == rson.lc);
    tr.val = lson.val + rson.val - (joined ? 1 : 0);
}
// Sentinel meaning "no colour". It is the "no pending assignment" lazy tag in the
// segment tree and the "no chain climbed yet" marker in qRange. Input colours are
// assumed non-negative, so -1 never collides with a real colour. (Using 0 for both roles,
// as before, breaks as soon as colour 0 appears in the input: a tag of 0 would never be
// pushed down, and a path end of colour 0 would be wrongly merged with a non-existent chain.)
const int kNoColor = -1;

// Push node `root`'s pending whole-interval colour (if any) down to both children.
// Each child becomes a single run of that colour and inherits the tag.
void spread(int root) {
    if (tr.tag == kNoColor) return;
    lson.lc = lson.rc = tr.tag;
    rson.lc = rson.rc = tr.tag;
    lson.tag = rson.tag = tr.tag;
    lson.val = rson.val = 1;
    tr.tag = kNoColor;
}

// Build the subtree rooted at `root` covering positions [l, r], with colours from wt[].
void build(int root, int l, int r) {
    tr.l = l, tr.r = r, tr.tag = kNoColor;
    if (l == r) {
        tr.lc = tr.rc = wt[l];
        tr.val = 1;
        return;
    }
    int mid = (l + r) >> 1;
    build(root << 1, l, mid);
    build(root << 1 | 1, mid + 1, r);
    merge(root);
}

// Paint every position in [l, r] with colour c (range assignment with a lazy tag).
void update(int root, int l, int r, int c) {
    if (l <= tr.l && tr.r <= r) {
        // Fully covered: the whole interval becomes a single run of colour c.
        tr.lc = tr.rc = c;
        tr.val = 1;
        tr.tag = c;
        return;
    }
    spread(root);
    int mid = (tr.l + tr.r) >> 1;
    if (l <= mid) update(root << 1, l, r, c);
    if (r > mid) update(root << 1 | 1, l, r, c);
    merge(root);
}

// Number of colour runs on positions [l, r].
// Side effect: lcol receives the colour at position l and rcol the colour at position r,
// taken from the fully covered nodes whose interval starts at l / ends at r.
int query(int root, int l, int r) {
    if (l <= tr.l && tr.r <= r) {
        if (l == tr.l) lcol = tr.lc;
        if (r == tr.r) rcol = tr.rc;
        return tr.val;
    }
    spread(root);
    int ans = 0;
    int mid = (tr.l + tr.r) >> 1;
    if (l <= mid) ans += query(root << 1, l, r);
    if (r > mid) ans += query(root << 1 | 1, l, r);
    // Both halves contributed: runs touching at the midpoint may be one run.
    if (l <= mid && r > mid) ans -= (lson.rc == rson.lc);
    return ans;
}

// Paint every node on the tree path x..y with colour c, one heavy chain at a time.
void updRange(int x, int y, int c) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update(1, id[top[x]], id[x], c);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    update(1, id[y], id[x], c);
}

// Number of colour runs on the tree path x..y.
int qRange(int x, int y) {
    int ans = 0;
    // d1 / d2: colour at the top of the chain most recently climbed on the x / y side
    // (kNoColor until that side has climbed a chain).
    int d1 = kNoColor, d2 = kNoColor;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y), swap(d1, d2);
        ans += query(1, id[top[x]], id[x]);
        // x is the parent of the previously climbed chain's top: merge if same colour.
        if (d1 == rcol) --ans;
        d1 = lcol;
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y), swap(d1, d2);
    ans += query(1, id[y], id[x]);
    // Deeper end x joins the x-side chain; shallower end y joins the y-side chain.
    if (d1 == rcol) --ans;
    if (d2 == lcol) --ans;
    return ans;
}
int main() {
    R(n, m);
    rep(i, 1, n) R(a[i]);
    rep(i, 1, n - 1) {
        int x, y;
        R(x, y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs1(1, 0, 1);
    dfs2(1, 1);
    build(1, 1, n);
    while (m--) {
        char op[20];
        R(op + 1);
        int x, y, c;
        R(x, y);
        if (op[1] == 'C') {
            R(c);
            updRange(x, y, c);
        } else
            W(qRange(x, y));
    }
}
           

繼續閱讀