天天看点

【JZOJ5055】【GDOI2017模拟二试4.12】树上路径

Description

给定一颗n个结点的无根树,树上的每个点有一个非负整数点权,定义一条路径的价值为路径上的点权和-路径的点权最大值。

给定参数p,我们想知道,有多少不同的树上简单路径,满足它的价值恰好是p的倍数。

注意:单点算作一个路径;u ≠ v时,(u,v)和(v,u)只算一次。

Data Constraint

对所有测试点,我们有:

n≤10^5,p≤10^7,val_i≤10^9

【JZOJ5055】【GDOI2017模拟二试4.12】树上路径

Solution

这是道树分治的题。我们找出重心的位置,每次从重心往四周遍历,找出每条到重心的路径的点权和%p和路径的点权最大值,然后将路径按点权最大值从小大大排序,用个桶维护当前的路径的点权和,每次在桶中查找路径的点权和-路径的点权最大值的数量。由于可能会算重,所以要先重心的每颗子树自己先搞一下,减去重复。

Code

#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=e5+,maxn1=e7+;
struct code{
    int mx,sum;
}b[maxn];
int first[maxn],last[maxn],next[maxn],a[maxn],size[maxn],mx[maxn];
int n,i,t,j,k,l,m,x,y,z,num,p,ans,ln,s,cnt[maxn1][],bz[maxn],fa[maxn];
void lian(int x,int y){
    last[++num]=y;next[num]=first[x];first[x]=num;
}
bool cmp(code x,code y){
    return x.mx<y.mx;
}
void dg1(int x,int y){
    int t,p=num;size[x]=;mx[x]=;
    for (t=first[x];t;t=next[t]){
        if (last[t]==y || bz[last[t]])continue;
        b[++num].sum=(b[p].sum+a[last[t]])%m;
        b[num].mx=max(a[last[t]],b[p].mx);
        dg1(last[t],x);size[x]+=size[last[t]];mx[x]=max(mx[x],size[last[t]]);
    }
}
int find(int x,int y){
    int t,k;mx[x]=max(mx[x],p-size[x]);
    if (mx[x]*2<=p||p==) return x;
    for(t=first[x];t;t=next[t]){
        if (last[t]==y || bz[last[t]]) continue;
        k=find(last[t],x);
        if (k) return k;
    }
    return ;
}
void dg(int x){
    int t,k;
    bz[x]=;num=;
    for (t=first[x];t;t=next[t]){
        if (bz[last[t]]) continue;
        k=num+;
        b[++num].mx=max(a[x],a[last[t]]);b[num].sum=(a[x]+a[last[t]])%m;
        dg1(last[t],);
        sort(b+k,b+num+,cmp);
        for (i=k;i<=num;i++){
            k=((b[i].sum-b[i].mx)%m+m)%m;
            if (k) k=m-k;
            if (cnt[k][]==last[t]) ans-=cnt[k][];
            l=((b[i].sum-a[x])%m+m)%m;
            if (cnt[l][]!=last[t]) cnt[l][]=last[t],cnt[l][]=;
            cnt[l][]++;
        }
    }
    sort(b+,b+num+,cmp);
    for (i=;i<=num;i++){
        k=((b[i].sum-b[i].mx)%m+m)%m;
        if (k) k=m-k;else ans++;
        if (cnt[k][]==x) ans+=cnt[k][];
        l=((b[i].sum-a[x])%m+m)%m;
        if (cnt[l][]!=x) cnt[l][]=x,cnt[l][]=;
        cnt[l][]++;
    }
    for (t=first[x];t;t=next[t]){
        if (bz[last[t]])continue;p=size[last[t]];
        k=find(last[t],x);
        dg(k);
    }
}
int main(){
    freopen("path.in","r",stdin);freopen("path.out","w",stdout);
    scanf("%d%d",&n,&m);
    for (i=;i<n;i++)
        scanf("%d%d",&x,&y),lian(x,y),lian(y,x);
    for (i=;i<=n;i++)
        scanf("%d",&a[i]);num=;
    dg1(,);p=n;
    k=find(,);
    dg(k);
    ans+=n;
    printf("%d\n",ans);
}
           

继续阅读