題目連結
題意:
給定一棵 n 個節點的樹,1 為根。現要将節點 2 ~ n 劃分為 k 塊,使得每一塊與 根節點 形成的最小斯坦納樹的 邊權值 總和最大。
看了題解之後的思考:
題解是:記有向邊 (u, v) 長度為 w[v],以點 v 為根的子樹的節點總數為 sz[v],那麼答案就是 w[v] * min(sz[v], k) 對每個點求個和。
題解的說法是:可以通過構造使得 balabala...
http://bestcoder.hdu.edu.cn/blog/2017-multi-university-training-contest-3-solutions-by-%E6%B4%AA%E5%8D%8E%E6%95%A6/
一開始覺得...這簡直是在開玩笑吧。後來仔細畫畫圖,好像真是這麼回事。原來之前思考的時候一直陷入了一個誤區,至于是什麼,我們之後再說。
現在,先來談一談,究竟如何構造。
首先,入手點是 每條邊 被算了多少次,即每條邊的權值對最終答案的貢獻。(這種看貢獻的想法真的很重要(敲黑闆))
其次,從題解中我們可以發現,這個算了多少次 好像就 簡簡單單地 取了 可能取到的最大值。
然而,這個盡可能的能被允許的最大值 min(sz[v], k) 能保證在每個點都一定能取到嗎?
答案是可以的。
首先,要明确的是,對于以 v 為根的子樹而言,将其劃分的塊數越多,對 w[v] 的貢獻就越大。
下面開始分類讨論。
1.
當除根以外所有節點的 sz 值 都 <= k 時,對于任意一個點 v,可将以 v 為根的子樹劃分為 sz 塊,此時貢獻最大。
對于任意兩個點 v1,v2 而言:
1)若 v2 在 以 v1 為根的子樹内(v1 在以 v2 為根的子樹内 同理),可将 v1 看成一個整體,其中 sz1 個點劃分為 sz1 塊即可
2)若 v2 不在以 v1 為根的子樹内,且 v1 不在以 v2 為根的子樹内,那麼 v1 的 sz1 塊 與 v2 的 sz2 塊就是完全獨立的,可以簡單地将 sz1 中的任一個集合 與 sz2 中的任一個集合合并,隻要保證 sz1 中不同的集合 不與 sz2 中的同一個集合合并, sz2 中不同的集合 也不與 sz1 中的同一個集合合并 即可(其他合并方式也可以,如使 其中一塊 自成一個集合等等,隻要滿足前述條件 并且 合并後的總數 <= k 即可)
反複進行上述操作,最終可得到 sz (sz <= k) 塊,并且其中每一步都取了盡可能的最大值
2.
存在節點 u,其 sz 值 > k,
對于 u 節點與另外一點 v 而言:
1)若 v 在以 u 為根的子樹中,顯然,u 必有 sz 值 <= k 的孩子節點,若 v 不是,則可遞歸地再往下找。
不妨就設其為 v,則 v 的 sz 值 <= k,并且已經被劃分好為 szv 塊。那麼,對于 u 的其他沒有被劃分的孩子節點,隻需将 szu - szv 個節點劃分為 k - szv 塊即可。
2)若 v 不在以 u 為根的子樹中,u 也不在以 v 為根的子樹中,合并方式同 1. 2) 相同
反複進行上述操作,最終可得到 k 塊,并且其中每一步都取了盡可能的最大值
綜上,可以構造出這樣的集合使得答案取到最大值。
反思:
一開始思維局限在了每一個子樹内要劃分出一個集合,壓根沒想到可以把(相對)不相幹的點塞到同一個集合中,主要還是受沒改題面之前要求 最小值 的影響...
再提一提最小值怎麼做吧,學姐給出的做法是,将根節點到其餘每個節點的距離從小到大排個序(包括1到1的距離0),要注意的是,節點有幾個子樹就把這段距離算幾遍,用bfs+貪心(優先隊列),取前 k 個最小的,然後加上樹上本身每條邊的權值,即為答案。
後話:
看了題解後寫了一個 dfs,結果用測試資料測結果卡了...一狠心在 hdu 上交了一遍竟然 AC 了。跑去看了 std,回來乖乖地又寫了一遍 bfs,然後測試資料就也能跑通了...真是......厲害了。第一場的 12 題也是這樣,總都是搞得深度很深來卡遞歸,也是...沒辦法了Orz
(後面的代碼上面 bfs 和 dfs 都有)
AC代碼如下:
#include <bits/stdc++.h>
#define maxn 1000010
int ne[maxn], sz[maxn], w[maxn], n, k, tot, q[maxn], fa[maxn];
bool vis[maxn];
inline min(int a, int b) { return a < b ? a : b; }
typedef long long LL;
struct Edge {
int to, dist, ne;
Edge(int a = 0, int b = 0, int c = 0) : to(a), dist(b), ne(c) {}
}edge[maxn * 2];
void add(int x, int y, int d) {
edge[tot] = Edge(y, d, ne[x]);
ne[x] = tot++;
}
void dfs(int u, int fa) {
sz[u] = 1;
for (int i = ne[u]; i != -1; i = edge[i].ne) {
Edge e = edge[i]; int v = e.to;
if (v == fa) continue;
w[v] = e.dist;
dfs(v, u);
sz[u] += sz[v];
}
}
void bfs(int src) {
memset(vis, 0, sizeof(vis));
int f = 0, r = 0;
q[r++] = src; vis[src] = true;
while (r > f) {
int u = q[f]; sz[u] = 1;
++f;
for (int i = ne[u]; i != -1; i = edge[i].ne) {
Edge e = edge[i]; int v = e.to;
if (vis[v]) continue;
q[r++] = v; vis[v] = true;
fa[v] = u; w[v] = e.dist;
}
}
for (int i = r - 1; i >= 0; --i) sz[fa[q[i]]] += sz[q[i]];
}
void work() {
memset(ne, -1, sizeof(ne));
tot = 0;
for (int i = 1; i < n; ++i) {
int x, y, d;
scanf("%d%d%d", &x, &y, &d);
add(x, y, d);
add(y, x, d);
}
// dfs(1, -1);
bfs(1);
LL ans = 0;
for (int i = 2; i <= n; ++i) {
ans += (LL)w[i] * min(sz[i], k);
}
printf("%lld\n", ans);
}
int main() {
while (scanf("%d%d", &n, &k) != EOF) work();
return 0;
}