天天看點

HDU 5029 Relief grain 樹鍊剖分 好題

題目:http://acm.hdu.edu.cn/showproblem.php?pid=5029

題意:給定一棵樹,有n個點,有m次操作,先有n - 1行形式如 a b,代表a b之間有邊,然後是m行形式如a b c,代表将從點a到點b路徑上的點都給予一種糧食c。最後輸出每個點的個數最多的糧食類型。

思路:很容易想到樹鍊剖分,但是糧食種類達100000,怎麼用線段樹維護每個點的糧食種類及其數量是個難題,肯定不能為每個點開一個數組去維護。我們可以換一種方式,用線段樹去維護糧食,每個點代表一種糧食,區間[l, r]維護的是區間内某種糧食的最大值。對于m次操作,先把操作處理成線性的,然後對于區間[l, r]加c,可以在l處标記1,代表此點(這個點不是線段樹上的點,而是題中所給樹的點)c糧食數量+1, r + 1處标記-1,代表此點c糧食數量-1,這是因為我們按題目中給定的樹中的點去更新線段樹,走到某個點時,線段樹中維護的就是目前這個點的糧食色種類和數量,當更新到區間[l, r]外時,要把對應的c減去1。然後線段樹根節點中儲存的就是某點的某種糧食的最大數量和其種類,對應輸出即可

總結:真是好題啊。。。期間一直莫名MLE,一臉懵逼,後來我重寫了dfs1和dfs2,就A了,仔細對比重寫之前和之後的代碼,并沒有什麼不一樣啊。。。

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;

const int N = 100010;
typedef pair<int, int> P;
struct edge
{
    int to, next;
}g[N*2];
struct node
{
    int l, r, maxx, id;
}s[N*4];
int head[N], top[N], siz[N], son[N], dep[N], id[N], fat[N], ans[N];
int n, m, cnt, num;
vector<P> arr[N];
void add_edge(int v, int u)
{
    g[cnt].to = u;
    g[cnt].next = head[v];
    head[v] = cnt++;
}
void dfs1(int v, int d, int fa)
{
    dep[v] = d, siz[v] = 1, son[v] = 0, fat[v] = fa;
    for(int i = head[v]; i != -1; i = g[i].next)
    {
        int u = g[i].to;
        if(u != fa)
        {
            dfs1(u, d + 1, v);
            siz[v] += siz[u];
            if(siz[son[v]] < siz[u]) son[v] = u;
        }
    }
}
void dfs2(int v, int tp)
{
    top[v] = tp, id[v] = ++num;
    if(son[v]) dfs2(son[v], tp);
    for(int i = head[v]; i != -1; i = g[i].next)
    {
        int u = g[i].to;
        if(u != fat[v] && u != son[v]) dfs2(u, u);
    }
}
void renew(int v, int u, int c)
{
    int t1 = top[v], t2 = top[u];
    while(t1 != t2)
    {
        if(dep[t1] < dep[t2])
            swap(t1, t2), swap(v, u);
        arr[id[t1]].push_back(P(c, 1));
        arr[id[v]+1].push_back(P(c, -1));
        v = fat[t1], t1 = top[v];
    }
    if(dep[v] > dep[u]) swap(v, u);
    arr[id[v]].push_back(P(c, 1));
    arr[id[u]+1].push_back(P(c, -1));
}
void build(int l, int r, int k)
{
    s[k].l = l, s[k].r = r, s[k].maxx = 0, s[k].id = l;
    if(l == r) return;
    int mid = (l + r) >> 1;
    build(l, mid, k << 1);
    build(mid + 1, r, k << 1|1);
}
void push_up(int k)
{
    if(s[k<<1].maxx < s[k<<1|1].maxx)
        s[k].maxx = s[k<<1|1].maxx, s[k].id = s[k<<1|1].id;
    else
        s[k].maxx = s[k<<1].maxx, s[k].id = s[k<<1].id;
}
void update(int x, int c, int k)
{
    if(s[k].l == s[k].r)
    {
        s[k].maxx += c;
        return;
    }
    int mid = (s[k].l + s[k].r) >> 1;
    if(x <= mid) update(x, c, k << 1);
    else update(x, c, k << 1|1);
    push_up(k);
}
int main ()
{
    int a, b, c;
    while(scanf("%d%d", &n, &m), n || m)
    {
        cnt = num = 0;
        memset(head, -1, sizeof head);
        for(int i = 1; i <= n - 1; i++)
        {
            scanf("%d%d", &a, &b);
            add_edge(a, b);
            add_edge(b, a);
        }
        dfs1(1, 1, 0);
        dfs2(1, 1);
        int tmp = 1;
        for(int i = 0; i < m; i++)
        {
            scanf("%d%d%d", &a, &b, &c);
            renew(a, b, c);
            tmp = max(tmp, c);
        }
        build(1, tmp, 1);
        for(int i = 1; i <= n; i++)
        {
            int len = arr[i].size();
            for(int j = 0; j < len; j++)
                update(arr[i][j].first, arr[i][j].second, 1);
            if(s[1].maxx == 0) ans[i] = 0;
            else ans[i] = s[1].id;
        }
        for(int i = 1; i <= n; i++)
            printf("%d\n", ans[id[i]]);
        for(int i = 1; i <= num + 1; i++)
            arr[i].clear();
    }

    return 0;
}