題目大意:
就是現在對于一個n個點的樹 (1 <= n <= 10^5), 其中有m個點出現了惡魔(1 <= m <= n), 給出m個點的編号(p1, p2, ... pm) 1 <= pi <= n, 現在已知這些惡魔出現的原因是有一本惡魔之書造成的, 而且書到所有惡魔出現的點的距離不能超過d (0 <= d <= n - 1), 問這棵樹中有哪些點可能是惡魔之書出現的位置, 輸出可能的位置的數量
大緻思路:
這個題試了兩種解法, 第一種是官方題解的樹狀DP, 另外一種是比較巧妙的轉化成樹的直徑端點相關的問題
細節都寫在代碼注釋裡了, 詳情見代碼吧
代碼如下:
解法一: 樹狀DP解法
Result : Accepted Memory : 7504 KB Time : 216 ms
/*
* Author: Gatevin
* Created Time: 2015/3/4 20:09:20
* File Name: Kotori_Itsuka.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
/*
* 解法一: 樹狀DP
* 首先任意選擇一個點作為樹的根節點, 友善起見我們選擇點1作為根
* 用集合P表示被标記的點的集合
* 用disDown[i][0]表示從點i到位于以i為根的子樹上的被标記的點的距離的最大值
* 用disDown[i][1]表示從點i到位于i為根的子樹上的被标記的點的距離的次大值
* 注意這裡如果disDown[i][0]與disDown[i][1]如果同時存在,
* 那麼産生這兩個值得被标記的點一定在i的兩顆不同的子樹上
* 用disUp[i]表示點i到不位于以i為根的子樹上的P中的點的距離的最大值
* 注意disUp[i]的值如果存在其最短路徑一定經過i的父親節點
* 那麼滿足題意的點藥滿足的條件就是disDown[i][0] <= d && disUp[i] <= d
* disDown[i][0]與disDown[i][1]可以一遍dfs弄出來
* disUp滿足當u是v的父親節點時有
* 如果disDown[u][0] = disDown[v][0] + 1說明造成disDown[u][0]的點位于v所在子樹上
* 此時, disDown[u][1]的來源點一定是v的兄弟所在樹上的點造成的(如果存在的話)
* 那麼disUp[v] = max(disUp[u] + 1, disDown[u][1] + 1)
* 否則的話說明disDown[u][0]來自v以外的兄弟所在子樹, 由于disDown[u][0] >= disDown[u][1]
* 有disUp[v] = max(disUp[u] + 1, disDown[u][0] + 1)
* 于是disUp也可以一遍dfs解決
* 總體複雜度O(n)
*/
const int inf = 0x3f3f3f3f;
int n, m, d;
bool evil[100010];
vector <int> G[100010];
int disDown[100010][2], disUp[100010];
void dfs_disDown(int now, int father)
{
if(evil[now]) disDown[now][0] = disUp[now] = 0;
for(unsigned int i = 0, sz = G[now].size(); i < sz; i++)
{
int nex = G[now][i];
if(nex == father) continue;
dfs_disDown(nex, now);
if(disDown[now][0] < disDown[nex][0] + 1)//記錄第一第二大值
{
disDown[now][1] = disDown[now][0];
disDown[now][0] = disDown[nex][0] + 1;
}
else disDown[now][1] = max(disDown[now][1], disDown[nex][0] + 1);
}
return;
}
void dfs_disUp(int now, int father)
{
for(unsigned int i = 0, sz = G[now].size(); i < sz; i++)
{
int nex = G[now][i];
if(nex == father) continue;
if(disDown[now][0] == disDown[nex][0] + 1)//造成dis[now][0]的來自nex所在子樹
disUp[nex] = max(disUp[now] + 1, disDown[now][1] + 1);//比較第二大值
else disUp[nex] = max(disUp[now] + 1, disDown[now][0] + 1);//隻需要比較第一大值
dfs_disUp(nex, now);
}
return;
}
int main()
{
scanf("%d %d %d", &n, &m, &d);
for(int i = 1; i <= n; i++)
disDown[i][0] = disDown[i][1] = -inf;//注意初始化
fill(disUp, disUp + n + 1, -inf);
int tmp;
while(m--)
scanf("%d", &tmp), evil[tmp] = 1;
int tu, tv;
for(int i = 1; i < n; i++)
{
scanf("%d %d", &tu, &tv);
G[tu].push_back(tv);
G[tv].push_back(tu);
}
dfs_disDown(1, 0);
dfs_disUp(1, 0);
int ans = 0;
for(int i = 1; i <= n; i++)
if(disDown[i][0] <= d && disUp[i] <= d)
ans++;
printf("%d\n", ans);
return 0;
}
解法二: 轉化為到樹的直徑端點距離的問題
Result : Accepted Memory : 5596 KB Time : 186 ms
/*
* Author: Gatevin
* Created Time: 2015/3/4 21:50:07
* File Name: Kotori_Itsuka.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
/*
* 解法二: 求包含所有被标記的點的最小子樹的直徑
* 然後比較其他點直徑的兩個端點的距離即可
* 注意到有這樣一個事實:
* 如果一棵樹T的直徑上的兩個端點分别是A, B
* 且T是樹S的一部分
* 那麼如果S上某個點到A, B的距離不超過D
* 那麼這個點到這棵子樹上的所有點的距離不超過D
* 是以隻需要找出包含所有點P[1~m]的最小的樹T之後
* 判斷其他點到這棵樹直徑上的兩個端點的距離是否 <= d即可
*/
int n, m, d;
vector <int> G[100010];
int d1[100010], d2[100010];
int p[100010];
int bfs(int start)
{
queue <int> Q;
memset(d1, -1, sizeof(d1));
d1[start] = 0;
Q.push(start);
while(!Q.empty())
{
int now = Q.front();
Q.pop();
for(unsigned int i = 0, sz = G[now].size(); i < sz; i++)
if(d1[G[now][i]] == -1)
d1[G[now][i]] = d1[now] + 1, Q.push(G[now][i]);
}
int ret = 1;
for(int i = 1; i <= m; i++)
if(d1[p[i]] > d1[p[ret]]) ret = i;
return p[ret];
}
int main()
{
scanf("%d %d %d", &n, &m, &d);
for(int i = 1; i <= m; i++)
scanf("%d", p + i);
int u, v;
for(int i = 1; i < n; i++)
{
scanf("%d %d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
int A = bfs(p[1]);//A是直徑的一端
int B = bfs(A);//B是直徑的另外一端
for(int i = 1; i <= n; i++)
d2[i] = d1[i];
//d2[i]為點i到A的距離
bfs(B);
//d1[i]為點i到B的距離
int ans = 0;
for(int i = 1; i <= n; i++)
if(d1[i] <= d && d2[i] <= d)
ans++;
printf("%d\n", ans);
return 0;
}