Description
给定一颗n个结点的无根树,树上的每个点有一个非负整数点权,定义一条路径的价值为路径上的点权和-路径的点权最大值。
给定参数p,我们想知道,有多少不同的树上简单路径,满足它的价值恰好是p的倍数。
注意:单点算作一个路径;u ≠ v时,(u,v)和(v,u)只算一次。
Data Constraint
对所有测试点,我们有:
n≤10^5,p≤10^7,val_i≤10^9
Solution
这是道树分治的题。我们找出重心的位置,每次从重心往四周遍历,找出每条到重心的路径的点权和%p和路径的点权最大值,然后将路径按点权最大值从小大大排序,用个桶维护当前的路径的点权和,每次在桶中查找路径的点权和-路径的点权最大值的数量。由于可能会算重,所以要先重心的每颗子树自己先搞一下,减去重复。
Code
#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=e5+,maxn1=e7+;
struct code{
int mx,sum;
}b[maxn];
int first[maxn],last[maxn],next[maxn],a[maxn],size[maxn],mx[maxn];
int n,i,t,j,k,l,m,x,y,z,num,p,ans,ln,s,cnt[maxn1][],bz[maxn],fa[maxn];
void lian(int x,int y){
last[++num]=y;next[num]=first[x];first[x]=num;
}
bool cmp(code x,code y){
return x.mx<y.mx;
}
void dg1(int x,int y){
int t,p=num;size[x]=;mx[x]=;
for (t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]])continue;
b[++num].sum=(b[p].sum+a[last[t]])%m;
b[num].mx=max(a[last[t]],b[p].mx);
dg1(last[t],x);size[x]+=size[last[t]];mx[x]=max(mx[x],size[last[t]]);
}
}
int find(int x,int y){
int t,k;mx[x]=max(mx[x],p-size[x]);
if (mx[x]*2<=p||p==) return x;
for(t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]]) continue;
k=find(last[t],x);
if (k) return k;
}
return ;
}
void dg(int x){
int t,k;
bz[x]=;num=;
for (t=first[x];t;t=next[t]){
if (bz[last[t]]) continue;
k=num+;
b[++num].mx=max(a[x],a[last[t]]);b[num].sum=(a[x]+a[last[t]])%m;
dg1(last[t],);
sort(b+k,b+num+,cmp);
for (i=k;i<=num;i++){
k=((b[i].sum-b[i].mx)%m+m)%m;
if (k) k=m-k;
if (cnt[k][]==last[t]) ans-=cnt[k][];
l=((b[i].sum-a[x])%m+m)%m;
if (cnt[l][]!=last[t]) cnt[l][]=last[t],cnt[l][]=;
cnt[l][]++;
}
}
sort(b+,b+num+,cmp);
for (i=;i<=num;i++){
k=((b[i].sum-b[i].mx)%m+m)%m;
if (k) k=m-k;else ans++;
if (cnt[k][]==x) ans+=cnt[k][];
l=((b[i].sum-a[x])%m+m)%m;
if (cnt[l][]!=x) cnt[l][]=x,cnt[l][]=;
cnt[l][]++;
}
for (t=first[x];t;t=next[t]){
if (bz[last[t]])continue;p=size[last[t]];
k=find(last[t],x);
dg(k);
}
}
int main(){
freopen("path.in","r",stdin);freopen("path.out","w",stdout);
scanf("%d%d",&n,&m);
for (i=;i<n;i++)
scanf("%d%d",&x,&y),lian(x,y),lian(y,x);
for (i=;i<=n;i++)
scanf("%d",&a[i]);num=;
dg1(,);p=n;
k=find(,);
dg(k);
ans+=n;
printf("%d\n",ans);
}