天天看點

ccc 2016 s3

/* 
把選中的點和它們的lca标記一下,然後考慮标記的點組成的這棵樹。答案是:路徑長度和 * 2 - 直徑
想辦法把選中的點和連接配接它們的點和邊弄出來重建立棵樹
以一個選中的點作為根,然後dfs,對每一個非選中的點,根據子樹是否有選中的點來決定它是否要選
複雜度為 O(n)
*/ 
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
using namespace std;

int n, m;
const int maxn = 100000 + 5;
vector<int> g[maxn];
bool node[maxn], vis[maxn];

bool dfs(int i) {
	vis[i] = true;
	for (int k = 0; k < g[i].size(); k++) {
		int t = g[i][k]; 
		if (!vis[t] && dfs(t)) {
			node[i] = true;
		}
	}
	return node[i];
}

int dist(int i) {
	vis[i] = true;
	int temp = 0;
	for (int k = 0; k < g[i].size(); k ++ ) {
		int t = g[i][k];
		if (!vis[t] && node[t]) {
			temp += dist(t) + 1;
		}
	}
	return temp;
}

struct p {
	int num, dis;
	p(int aa, int bb) : num(aa), dis(bb) {}
};

int diameter(int i) {
	queue<int> q;
	q.push(i);
	vis[i] = true;
	int last;
	while (!q.empty()) {
		int k = q.front(); q.pop(); last = k;
		for (int w = 0; w < g[k].size(); w++) {
			int c_node = g[k][w];
			if (!vis[c_node] && node[c_node]) {
				q.push(c_node); vis[c_node] = true;
			}
		}
	}
	memset(vis, false, sizeof(vis));
	queue<p> Q;
	Q.push(p(last, 0)); vis[last] = true;
	int ans = 0;
	while (!Q.empty()) {
		p k = Q.front(); Q.pop(); ans = k.dis;
		for (int w = 0; w < g[k.num].size(); w++) {
			int c_node = g[k.num][w];
			if (!vis[c_node] && node[c_node]) {
				Q.push(p(c_node, k.dis + 1)); vis[c_node] = true;
			}
		}
	}
	return ans;
}

void solve() {
	memset(vis, false, sizeof(vis));
	int i;
	for (i = 0; i < n; i ++ ) if (node[i]) break;
	dfs(i);
	
	memset(vis, false, sizeof(vis));
	int d1 = dist(i);
	
	memset(vis, false, sizeof(vis));
	int d2 = diameter(i);
	
	printf("%d\n", 2 * d1 - d2);
}

void init() {
	memset(node, false, sizeof(node));
	scanf("%d%d", &n, &m);
	for (int i = 0; i < m; i ++ ) {
		int a; scanf("%d", &a); node[a] = true;
	}
	for (int i = 0; i < n; i ++ ) g[i].clear();
	for (int i = 1; i <= n - 1; i++) {
		int a, b; scanf("%d%d", &a, &b);
		g[a].push_back(b); g[b].push_back(a);
	}
}

int main() {
	init(); solve();
	return 0;
}