天天看点

Distance in Tree — Codeforces 161D

http://codeforces.com/problemset/problem/161/D

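点分治模板题：统计树中距离恰好为 k 的点对数量。(Centroid-decomposition template problem: count the unordered pairs of vertices whose distance is exactly k.)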

#include <bits/stdc++.h>
using namespace std;
// 64-bit counter type: the ordered-pair counts accumulated by getnum()/dfs()
// can exceed INT_MAX for large trees.
using ll = long long;
// "Infinity" sentinel used to seed the minimum in the centroid search; it must
// be larger than any component size.
constexpr int N = 0x3f3f3f3f;

// One arc in the edge pool: it points to vertex `v`. `next` is the pool index
// of the next arc leaving the same source vertex, and -1 ends the list.
struct node
{
    int v, next;
};

// Scratch buffer filled by getdis(); getnum() clears and reuses it.
vector <int> dis;
// Arc pool backing the adjacency lists (main() stores each input edge twice).
node edge[100010];
// Per-vertex arrays, indexed by vertex number (0 is used as "no parent").
int first[50010];  // index in edge[] of the vertex's first arc, -1 when none
int sum[50010];    // subtree size written by getsize()
int maxx[50010];   // size of the largest piece hanging off the vertex (getsize/getroot)
int book[50010];   // set to 1 once the vertex has been chosen as a centroid
// n: vertex count, k: target distance, num: next free slot in edge[].
int n;
int k;
int num;

void addedge(int u,int v)
{
    edge[num].v=v;
    edge[num].next=first[u];
    first[u]=num++;
}

// Roots the component of not-yet-removed vertices (book == 0) at cur, where fa
// is the vertex we came from. For every vertex x it visits, it fills:
//   sum[x]  - number of vertices in x's subtree
//   maxx[x] - size of x's largest child subtree
void getsize(int cur,int fa)
{
    sum[cur]=1;
    maxx[cur]=0;
    for(int e=first[cur];e!=-1;e=edge[e].next)
    {
        const int to=edge[e].v;
        if(to==fa||book[to]!=0) continue;
        getsize(to,cur);
        sum[cur]+=sum[to];
        if(sum[to]>maxx[cur]) maxx[cur]=sum[to];
    }
}

// Searches the active component (total size tot) for the vertex whose largest
// remaining piece is smallest, i.e. a centroid. A vertex's pieces are its child
// subtrees plus the part "above" it, which has tot - sum[cur] vertices.
// minn holds the best piece size found so far and root holds the vertex that
// achieved it; ties keep the first vertex found.
// Expects sum/maxx to have been filled by getsize() from the same start vertex.
// Note: it overwrites maxx[] with the enlarged value.
void getroot(int cur,int fa,int tot,int &minn,int &root)
{
    const int above=tot-sum[cur];
    if(above>maxx[cur]) maxx[cur]=above;
    if(maxx[cur]<minn)
    {
        minn=maxx[cur];
        root=cur;
    }
    for(int e=first[cur];e!=-1;e=edge[e].next)
    {
        const int to=edge[e].v;
        if(to!=fa&&!book[to]) getroot(to,cur,tot,minn,root);
    }
}

// Appends d to the global `dis` for cur. It then recurses into every active
// neighbour other than fa with d + 1, so `dis` ends up holding the value
// d + (edge count from cur) for each vertex reached.
void getdis(int cur,int fa,int d)
{
    dis.push_back(d);
    for(int e=first[cur];e!=-1;e=edge[e].next)
    {
        const int to=edge[e].v;
        if(to==fa||book[to]) continue;
        getdis(to,cur,d+1);
    }
}

// Collects, into the global `dis`, the distance of every active vertex reachable
// from cur, each offset by det. dfs() passes det = 0 when cur is the centroid and
// det = 1 when cur is a neighbour of the just-removed centroid, so in both cases
// the values are distances from that centroid.
// Returns the number of ORDERED index pairs (i, j) of collected entries, with
// i == j allowed, such that dis[i] + dis[j] == k.
ll getnum(int cur,int det)
{
    dis.clear();
    getdis(cur,0,det);
    sort(dis.begin(),dis.end());
    ll res=0;
    for(const int d : dis)
    {
        // dis is sorted ascending: once d exceeds k, no later entry can pair up.
        if(d>k) break;
        // One binary search yields the whole run of entries equal to k - d.
        const auto run=equal_range(dis.begin(),dis.end(),k-d);
        res+=run.second-run.first;
    }
    return res;
}

// Centroid-decomposition driver for the active component containing cur.
// Returns the number of ordered pairs (x, y) of vertices in that component at
// distance exactly k. For k >= 1 this is twice the unordered count.
ll dfs(int cur)
{
    // Pick the centroid of cur's component and remove it from later searches.
    getsize(cur,0);
    int best=N;
    int root=cur;
    getroot(cur,0,sum[cur],best,root);
    book[root]=1;

    // Pairs whose path goes through root: every pair in the component, minus
    // pairs that lie entirely inside a single child subtree.
    ll total=getnum(root,0);
    for(int e=first[root];e!=-1;e=edge[e].next)
        if(!book[edge[e].v]) total-=getnum(edge[e].v,1);

    // Every other pair lies wholly inside one of the remaining pieces.
    for(int e=first[root];e!=-1;e=edge[e].next)
        if(!book[edge[e].v]) total+=dfs(edge[e].v);
    return total;
}

// Reads n and k, then n - 1 undirected edges "u v" (vertices 1..n), and prints
// the number of unordered vertex pairs at distance exactly k.
// Exits with status 1 if the input ends early or is malformed, instead of
// continuing with uninitialized values.
int main()
{
    if(scanf("%d%d",&n,&k)!=2) return 1;
    memset(first,-1,sizeof(first));
    num=0;
    for(int i=1;i<n;i++)
    {
        int u,v;
        if(scanf("%d%d",&u,&v)!=2) return 1;
        // Store both directions so the adjacency lists describe an undirected tree.
        addedge(u,v);
        addedge(v,u);
    }
    // dfs() counts each unordered pair twice, as (x, y) and (y, x).
    // NOTE: this assumes k >= 1 (the problem's constraint). For k == 0 each
    // vertex would additionally be paired with itself once.
    printf("%lld\n",dfs(1)/2);
    return 0;
}