天天看點

POJ 1741(點分治)

Tree

Time Limit: 1000MS Memory Limit: 30000K
Total Submissions: 33152 Accepted: 11082

Description

Given a tree with n vertices, each edge has a length (a positive integer less than 1001).

Define dist(u,v) = the minimum distance between nodes u and v.

Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.

Write a program that counts how many pairs are valid for a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n<=10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.

The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
           

Sample Output

8
           

Source

LouTiancheng@POJ

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>

using namespace std;
typedef long long ll;
const ll mod = 998244353;       // not referenced by the code visible in this file
const int INF = 0x3f3f3f3f;     // not referenced by the code visible in this file
// Capacity of the per-vertex arrays. Written with integer literals so the
// array bound is a plain integral constant (same value as 1e5+9).
const int maxn = 100000 + 9;

/*** Centroid decomposition (point divide and conquer) on a tree ***/

int n;   // first value read for each test case in main (number of vertices)
int m;   // second value read for each test case in main (distance bound used by calc)

/* One directed edge of the adjacency lists built by add(). */
struct rt{
    int v;      // vertex this edge leads to
    int c;      // length of the edge
    int next;   // index in e[] of the next edge leaving the same vertex; -1 ends the list
}e[maxn*2];     // main stores each input edge twice, once per direction

int head[maxn];   // head[u]: index in e[] of the first edge leaving u, -1 if none
int tot;          // next unused slot of e[]
int cnt;          // number of distances currently held in a[1..cnt]
int sz[maxn];     // sz[u]: subtree size of u as recorded by the latest get_root pass
int f[maxn];      // f[u]: largest piece left after removing u, as computed by get_root
int vis[maxn];    // vis[u]: set to 1 once solve() has processed u
int dep[maxn];    // dep[u]: distance assigned to u by calc()/getdeep()
int a[maxn];      // distances collected by getdeep(), 1-based

/**
 * Insert the directed edge u -> v with length c at the front of u's
 * adjacency list.
 *
 * @param u  source vertex
 * @param v  destination vertex
 * @param c  edge length
 */
void add(int u,int v,int c){
    const int slot=tot++;      // claim the next unused entry of e[]
    e[slot].v=v;
    e[slot].c=c;
    e[slot].next=head[u];      // previous first edge becomes the successor
    head[u]=slot;
}

int root;   // best split vertex found so far by get_root (0 is the reset value used by callers)
int sum;    // component size that get_root measures the "remaining" part against
int ans;    // running pair count for the current test case

/**
 * Walk the not-yet-visited part of the tree reachable from u and record,
 * for every vertex reached, its subtree size (sz) and the size of the
 * largest piece that would remain if that vertex were removed (f).
 * Whenever a vertex has a strictly smaller f than the current `root`,
 * it becomes the new `root`.
 *
 * Callers set `sum`, `root` and f[root] before the outermost call.
 *
 * @param u   vertex being processed
 * @param fa  vertex we arrived from (callers pass 0 at the top level)
 */
void get_root(int u,int fa){
    sz[u]=1;
    f[u]=0;
    for(int idx=head[u];idx!=-1;idx=e[idx].next){
        const int child=e[idx].v;
        if(child!=fa&&!vis[child]){
            get_root(child,u);
            sz[u]+=sz[child];
            if(f[u]<sz[child])f[u]=sz[child];   // heaviest child subtree so far
        }
    }
    // Everything that is not in u's subtree forms one more piece.
    const int above=sum-sz[u];
    if(f[u]<above)f[u]=above;
    if(f[u]<f[root])root=u;
}

/**
 * Append dep[u] to a[] and continue into every neighbour that is neither
 * the vertex we came from nor already marked in vis, giving each such
 * neighbour the distance dep[u] plus the connecting edge length.
 *
 * The caller initialises dep[u] and cnt before the outermost call.
 *
 * @param u   vertex being processed
 * @param fa  vertex we arrived from (callers pass 0 at the top level)
 */
void getdeep(int u,int fa){
    cnt+=1;
    a[cnt]=dep[u];
    for(int idx=head[u];idx!=-1;idx=e[idx].next){
        const int to=e[idx].v;
        const bool reachable=(to!=fa)&&!vis[to];
        if(reachable){
            dep[to]=dep[u]+e[idx].c;
            getdeep(to,u);
        }
    }
}

/**
 * Collect the distances of all unvisited vertices reachable from u, with u
 * itself starting at distance w, and count the unordered pairs of collected
 * vertices whose two distances add up to at most m.
 *
 * @param u  start vertex
 * @param w  distance assigned to u
 * @return   number of pairs (i < j in sorted order) with a[i] + a[j] <= m
 */
int calc(int u,int w){
    dep[u]=w;
    cnt=0;
    getdeep(u,0);
    sort(a+1,a+cnt+1);
    int pairs=0;
    // Two pointers on the sorted distances: if the smallest remaining value
    // fits with the largest, it fits with every value in between as well.
    for(int lo=1,hi=cnt;lo<hi;){
        if(a[lo]+a[hi]>m){
            --hi;
        }else{
            pairs+=hi-lo;
            ++lo;
        }
    }
    return pairs;
}

/**
 * Process the component whose split vertex is u and accumulate into `ans`
 * the number of valid pairs whose path passes through u, then recurse into
 * every piece left after removing u.
 *
 * Counting: calc(u,0) counts every pair in the component by the sum of the
 * two distances to u, which also includes pairs lying in the same branch.
 * For each branch, calc(v, edge length) counts exactly those same-branch
 * pairs (distances still measured from u), so they are subtracted.
 *
 * Expects `sum` to hold the size of u's component on entry, and sz[] to be
 * the values left by the get_root pass that selected u.
 *
 * @param u  split vertex of the current component
 */
void solve(int u){
    const int total=sum;        // component size; `sum` is overwritten below
    ans+=calc(u,0);
    vis[u]=1;
    for(int i=head[u];i!=-1;i=e[i].next){
        int v=e[i].v;
        if(vis[v])continue;
        ans-=calc(v,e[i].c);
        // sz[] was measured from wherever the last get_root pass started,
        // not from u. If v was on the parent side of u in that pass, sz[v]
        // also counts u's own subtree; the real size of v's piece is then
        // everything outside u's subtree.
        sum=(sz[v]>sz[u])?total-sz[u]:sz[v];
        root=0;
        get_root(v,0);
        solve(root);
    }
}

/**
 * Read test cases until the "0 0" terminator (or until input ends or is
 * malformed) and print the number of valid pairs for each one.
 */
int main(){
    // Require both values to be converted; a partial match must not start
    // a test case with stale data.
    while(scanf("%d%d",&n,&m)==2&&(n||m)){
        // Reset per-test state: empty adjacency lists, nothing visited.
        for(int i=0;i<=n;i++)head[i]=-1,vis[i]=0;
        tot=0;
        for(int i=1;i<n;i++){
            int u,v,c;
            if(scanf("%d%d%d",&u,&v,&c)!=3)return 0;   // truncated input
            add(u,v,c);
            add(v,u,c);
        }
        ans=0;
        // root 0 with f[0]=n is the starting candidate that get_root
        // compares every vertex against.
        root=0; f[0]=sum=n;
        get_root(1,0);
        solve(root);
        printf("%d\n",ans);
    }
    return 0;
}