Description
Edward has a tree with n vertices conveniently labeled with 1,2,…,n.
Edward considers pairs of paths on the tree that share no more than k common vertices, and wants to know how many such ordered pairs there are.
Note that the path from vertex a to b is the same as the path from vertex b to a. An ordered pair means (A, B) is different from (B, A) unless A is equal to B.
Input
There are multiple test cases. The first line of input contains an integer T, the number of test cases. For each test case:
The first line contains two integers n, k (1 ≤ n, k ≤ 88888). Each of the following n - 1 lines contains two integers ai, bi, denoting an edge between vertices ai and bi (1 ≤ ai, bi ≤ n).
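The sum of the values of n over all test cases is bounded (the exact limit is missing from this copy of the statement).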
Output
For each case, output a single integer denoting the number of ordered pairs of paths sharing no more than k common vertices.
Sample Input
1
4 2
1 2
2 3
3 4
Sample Output
path A | paths sharing exactly 2 vertices with A | total |
1-2-3-4 | 1-2, 2-3, 3-4 | 3 |
1-2-3 | 1-2, 2-3, 2-3-4 | 3 |
2-3-4 | 1-2-3, 2-3, 3-4 | 3 |
1-2 | 1-2, 1-2-3, 1-2-3-4 | 3 |
2-3 | 1-2-3, 1-2-3-4, 2-3, 2-3-4 | 4 |
3-4 | 1-2-3-4, 2-3-4, 3-4 | 3 |
93
Hint
The number of path pairs that share no common vertex is 30.
The number of path pairs that share exactly 1 common vertex is 44.
The number of path pairs that share exactly 2 common vertices is 19.
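This kind of problem is really taxing on the brain: it took 10 wrong-answer submissions before it was accepted, and I am quite worn out.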
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
// 64-bit UNSIGNED on purpose: the final answer is (n(n+1)/2)^2, which for
// n = 88888 is about 1.56e19. That overflows signed long long (max ~9.2e18)
// but fits below 2^64. Intermediate values in the counting code may also wrap,
// which is harmless in unsigned (modulo 2^64) arithmetic.
typedef unsigned long long LL;
// Lowest set bit of x (the Fenwick-tree step); meaningful for x > 0.
// Returns plain int: a top-level const on a by-value return type is ignored.
int low(int x) { return x & -x; }
// Array capacity: must exceed both the vertex count n and the 2(n - 1)
// directed edge slots used by the adjacency lists.
const int maxn = 3e5 + 10;
// "Infinitely large" piece size; used as the sentinel mx[0] in the centroid search.
const int INF = 0x7FFFFFFF;
// T = number of test cases, n = vertex count, m = k from the statement,
// x / y = scratch variables for reading edges.
int T, n, m, x, y;
// Centroid-decomposition solver. work(1, n) returns the number of ordered
// pairs of paths (A, B) whose common part has MORE than m vertices; main()
// subtracts it from the total number of ordered pairs.
// Reads the globals n (vertex count of the whole tree) and m (= k).
// The intersection of two tree paths is itself a path p..q. For a fixed
// p..q, the pair (A, B) is determined by where A and B continue beyond each
// end, so the count for p..q is a product of a weight at p and a weight at q.
// NOTE(review): all traversals are recursive; on a long path the recursion
// depth is O(n), which may need a large stack.
struct Tree
{
// Adjacency list: ft[v] = first edge slot of v (-1 = none), nt[e] = next
// slot, u[e] = vertex the slot leads to, sz = number of slots in use.
int ft[maxn], nt[maxn], u[maxn], sz;
// vis[v] = 1 once v has been processed as a centroid (set in work()).
// mx[v] has two roles:
//   * unvisited v, during the centroid search: size of the largest piece
//     that would remain if v were removed from the current component;
//   * visited v (set in work()): number of vertices on v's side of the edge
//     leading into the piece currently being processed, so whole-tree
//     traversals can count "everything behind v" without walking through it.
// mx[0] = INF is the "no candidate yet" sentinel for the centroid search.
// ct[v] = subtree size of v from the most recent centroid-search DFS.
int vis[maxn], mx[maxn], ct[maxn];
// d = Fenwick tree indexed by depth; D[j] = summed weights of depth-j vertices.
LL d[maxn], D[maxn];
// Per-test reset: empty adjacency lists for vertices 1..n, nothing visited.
void clear(int n)
{
mx[sz = 0] = INF;
for (int i = 1; i <= n; i++) vis[i] = 0, ft[i] = -1;
}
// Adds the undirected edge x - y as two directed slots.
void AddEdge(int x, int y)
{
u[sz] = y; nt[sz] = ft[x]; ft[x] = sz++;
u[sz] = x; nt[sz] = ft[y]; ft[y] = sz++;
}
// Centroid search over the unvisited component containing x (of size sum),
// rooted at x. Fills ct[] and mx[] for that component and returns the vertex
// with the smallest mx[], i.e. a centroid.
int dfs(int x, int fa, int sum)
{
// ct[x] = 1, mx[x] = 0, best candidate y = 0 (sentinel, mx[0] = INF).
int y = mx[x] = (ct[x] = 1) - 1;
for (int i = ft[x]; i != -1; i = nt[i])
{
if (vis[u[i]] || u[i] == fa) continue;
int z = dfs(u[i], x, sum);
ct[x] += ct[u[i]];
mx[x] = max(mx[x], ct[u[i]]);
// Keep the better of the current best and the child's best candidate.
y = mx[y] < mx[z] ? y : z;
}
// The part of the component outside x's subtree is also a piece.
mx[x] = max(mx[x], sum - ct[x]);
return mx[x] < mx[y] ? x : y;
}
// Depth of the deepest unvisited vertex reachable from x without going
// through fa; dep is x's own depth (the centroid has depth 1).
int getdep(int x, int fa, int dep)
{
int ans = dep;
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa || vis[u[i]]) continue;
ans = max(ans, getdep(u[i], x, dep + 1));
}
return ans;
}
// Walks the whole tree away from fa. Visited vertices are not entered;
// their mx[] stands in for everything behind them.
// For each unvisited vertex v at depth dep, adds to D[dep] the weight
// cnt^2 - sum(branch^2), where cnt is the number of vertices on v's side and
// the branches are v's neighbours other than fa. This weight is the number
// of ordered vertex pairs (a, b) on v's side not lying in the same branch,
// i.e. the ways A and B can continue past v so that they split exactly at v.
// Returns cnt, except at dep == 1 (the centroid), where it returns
// sum(branch^2) over all of the centroid's neighbours.
LL get(int x, int fa, int dep)
{
LL cnt = 1, ans = 0;
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa) continue;
LL y = vis[u[i]] ? mx[u[i]] : get(u[i], x, dep + 1);
cnt += y; ans += y*y;
}
// D[1] (the centroid's own entry) is written here but never read by find().
D[dep] += cnt*cnt - ans;
if (dep == 1) return ans;
return cnt;
}
// Counts ordered path pairs whose intersection p..q has more than m vertices,
// passes through centroid x, and does not lie inside a single branch of x.
// Pairs whose intersection lies inside one branch are counted later, when
// that branch is processed by work().
LL find(int x)
{
LL ans = 0, sum = 0, tot = 0;
int len = getdep(x, -1, 1);
// The longest path through x has at most 2*len - 1 vertices; if that is
// <= m, no intersection through x can exceed m vertices.
if (len + len <= m + 1) return 0;
// sum = sum of squared sizes of all of x's branches.
sum = get(x, -1, 1);
for (int i = 1; i <= len; i++) d[i] = 0;
for (int i = ft[x]; i != -1; i = nt[i])
{
if (vis[u[i]]) continue;
int y = getdep(u[i], x, 2);
for (int j = 2; j <= y; j++) D[j] = 0;
// z = number of vertices on u[i]'s side of the edge x - u[i].
LL z = get(u[i], x, 2);
// Pair each depth-j vertex of this branch with vertices at depth k in
// earlier branches (already in the Fenwick tree). The path between them
// has j + k - 1 vertices, which exceeds m exactly when k > m + 1 - j.
// tot - s = total weight at such depths.
for (int j = 2; j <= y; j++)
{
LL s = 0;
for (int k = min(m + 1 - j, len); k > 0; k -= low(k)) s += d[k];
ans += D[j] * (tot - s);
}
for (int j = 2; j <= y; j++)
{
// Fenwick point-add of D[j] at depth j.
for (int k = j; k <= len; k += low(k)) d[k] += D[j];
// Intersection running from x itself to a depth-j vertex has j vertices.
// x's weight excluding this branch is (n - z)^2 - (sum - z^2);
// intermediate terms may wrap in unsigned arithmetic, but the result is
// exact modulo 2^64.
if (j > m) ans += (((LL)n - z)*((LL)n - z) + z*z - sum)*D[j];
tot += D[j];
}
}
return ans;
}
// Number of vertices on x's side of the edge fa - x in the whole tree.
// Visited vertices are not entered; their mx[] value is added instead.
int dfs(int x,int fa)
{
int cnt=1;
for (int i=ft[x];i!=-1;i=nt[i])
{
if (u[i]==fa) continue;
cnt+=vis[u[i]]?mx[u[i]]:dfs(u[i],x);
}
return cnt;
}
// Processes the unvisited component containing x (of size sum). It finds the
// centroid y, counts the pairs whose intersection passes through y, marks y
// visited, and recurses into each remaining piece. Returns the total count.
LL work(int x, int sum)
{
int y = dfs(x, -1, sum);
LL ans = find(y); vis[y] = 1;
for (int i = ft[y]; i != -1; i = nt[i])
{
if (vis[u[i]]) continue;
// Vertices on y's side of the edge y - u[i]; read through mx[y] by the
// traversals inside u[i]'s piece. Overwritten for each neighbour in turn.
mx[y] = n-dfs(u[i],y);
// ct[] comes from the search rooted at x. If u[i] is y's parent there, its
// piece is everything outside y's subtree; otherwise it is u[i]'s subtree.
ans += work(u[i], ct[u[i]] > ct[y] ? sum - ct[y] : ct[u[i]]);
}
return ans;
}
}solve;
// Reads T test cases. For each tree, prints the number of ordered pairs of
// paths sharing at most m (= k) vertices. There are n(n+1)/2 paths (including
// single-vertex paths), hence (n(n+1)/2)^2 ordered pairs; solve.work(1, n)
// counts the pairs sharing MORE than m vertices, which are subtracted.
// Every read is checked: on truncated or malformed input the program stops
// instead of reusing stale n / m / x / y values (which would add bogus edges
// or index the solver's arrays with garbage).
int main()
{
    if (scanf("%d", &T) != 1) return 0;
    while (T--)
    {
        if (scanf("%d%d", &n, &m) != 2 || n < 1) return 0;
        // Number of paths; the square is computed in unsigned 64-bit (LL).
        LL ans = (LL)n * (n + 1) >> 1;
        solve.clear(n);
        for (int i = 1; i < n; i++)
        {
            if (scanf("%d%d", &x, &y) != 2) return 0;
            solve.AddEdge(x, y);
        }
        printf("%llu\n", ans * ans - solve.work(1, n));
    }
    return 0;
}