http://www.elijahqi.win/archives/3566
對於本題，顯然有一個暴力做法，可以拿到 40 分：在每個非葉子節點處暴力枚舉權值，統計另一個兒子（例如右兒子）中比它大的權值有多少個、算出這些權值的機率和，然後再乘上這個權值自身的機率即可。
題目中有個重要條件：每個點的權值均不同，因此應該考慮線段樹合併。問題在於：在線段樹合併的時候，機率要怎麼計算？
設 greatr[i] 表示右子樹中權值比 i 大的機率之和，greatl[i] 同理（對應左子樹）。線段樹合併時不妨從大往小合併（先遞迴右兒子、再遞迴左兒子），這樣 greatl、greatr 就可以一路累加下來。
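最後，設 P 表示這個非葉子節點在題目中定義的機率（取兩個兒子權值中最大值的機率）。對於只出現在左子樹中的權值 i，它成為該節點權值的機率等於它在左子樹中的機率乘上 P·(1−greatr[i]) + (1−P)·greatr[i]；右子樹中的權值同理（把 greatr 換成 greatl）。答案為 Σ i·V_i·D_i²，其中 V_i 為第 i 小的權值，D_i 為它成為根節點權值的機率。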
#include<queue>
#include<cstdio>
#include<cctype>
#include<algorithm>
#define fi first
#define se second
#define ll long long
#define pa pair<int,int>
#define mp(x,y) make_pair(x,y)
using namespace std;
// Returns the next byte of stdin, refilling a 64 KiB buffer with fread()
// whenever it runs dry. Once fread() yields nothing, returns EOF (converted
// to char).
inline char gc(){
    static char buf[1<<16], *head, *tail;
    if (head==tail){
        head=buf;
        tail=buf+fread(buf,1,sizeof buf,stdin);
        if (head==tail) return EOF;
    }
    return *head++;
}
// Reads the next decimal integer from the gc() byte stream.
// Any non-digit characters before the number are skipped; if a '-' appears
// among them the result is negated. If the input ends before a digit is seen,
// returns 0 instead of looping forever on EOF.
// NOTE: gc() returns EOF as a char, so a literal 0xFF byte in the input is
// indistinguishable from end of input here.
inline int read(){
    int x=0,f=1;
    char ch=gc();
    // isdigit() is only defined for values representable as unsigned char
    // (or EOF); cast so bytes >= 0x80 on signed-char platforms are not UB.
    while(!isdigit(static_cast<unsigned char>(ch))){
        if (ch==static_cast<char>(EOF)) return 0;
        if (ch=='-') f=-1;
        ch=gc();
    }
    while(isdigit(static_cast<unsigned char>(ch))){
        x=x*10+ch-'0';
        ch=gc();
    }
    return x*f;
}
const int N=3e5+10;
const int mod=998244353;
// Fast modular exponentiation: returns b^t mod `mod` for t >= 0
// (a negative t returns 1 rather than looping forever).
// The base is reduced into [0, mod) first, so b*b always fits in 64 bits even
// when the caller passes a large or negative base.
inline int ksm(long long b,int t){
    b%=mod;
    if (b<0) b+=mod;
    long long res=1;
    for (;t>0;t>>=1){
        if (t&1) res=res*b%mod;
        b=b*b%mod;
    }
    return static_cast<int>(res);
}
// Node of the pooled segment trees built over leaf-value ranks 1..top.
// Pool index 0 acts as the null child.
struct node{
    int left;   // pool index of the left child, 0 if absent
    int right;  // pool index of the right child, 0 if absent
    int v;      // sum of the leaf values below this node (mod); new nodes start at 1
    int tag;    // pending multiplier not yet applied to the children; 1 = none
};
// Pool capacity: insert1() allocates one node per level (at most 20 levels for
// up to ~3e5 ranks) and merge() never allocates, so N*20 nodes suffice.
node tree[N*20];
pa q[N];          // (value, node id) for every leaf; q[1..top], sorted by value in main
int c[N][2];      // up to two children per input-tree node
int num;          // number of segment-tree pool nodes allocated so far
int d[N];         // child count of each input-tree node (0 means leaf)
int n;            // number of input-tree nodes
int top;          // number of leaves, i.e. the size of q
int rt[N];        // segment-tree root for each input-tree node
int ans;          // accumulated answer (mod)
int greatl;       // running sum used by merge() for the left operand
int greatr;       // running sum used by merge() for the right operand
int a[N];         // raw input number for each node
int p[N];         // internal nodes: a[i] / 10000 as a residue mod `mod`
inline bool cmp(const pa &a,const pa &b){return a.fi<b.fi;}
// Builds a fresh root-to-leaf chain for position p inside [l, r] and stores the
// chain's root index in x. Every node created gets v = 1 and tag = 1; the
// child links that are not followed stay 0.
inline void insert1(int &x,int l,int r,int p){
    int *link=&x;                 // slot that receives the next allocated index
    while (true){
        const int id=++num;
        *link=id;
        tree[id].v=1;
        tree[id].tag=1;
        if (l==r) break;
        const int mid=(l+r)>>1;
        if (p<=mid){
            link=&tree[id].left;
            r=mid;
        }else{
            link=&tree[id].right;
            l=mid+1;
        }
    }
}
inline void pushdown(int x){
if (tree[x].tag==1) return;
int l=tree[x].left,r=tree[x].right,tag=tree[x].tag;
if (l) tree[l].tag=(ll)tree[l].tag*tag%mod,
tree[l].v=(ll)tree[l].v*tag%mod;
if (r) tree[r].tag=(ll)tree[r].tag*tag%mod,
tree[r].v=(ll)tree[r].v*tag%mod;tree[x].tag=1;
}
// (x + v) mod `mod` for residues x, v in [0, mod).
inline int inc(int x,int v){
    int sum=x+v;
    if (sum>=mod) sum-=mod;
    return sum;
}
// (x - v) mod `mod` for residues x, v in [0, mod).
inline int dec(int x,int v){
    int diff=x-v;
    if (diff<0) diff+=mod;
    return diff;
}
// Merges the segment tree rt2 (right child's subtree) into rt1 (left child's
// subtree) for an internal node whose probability parameter is p, and returns
// the merged root.
// The recursion visits the right half before the left half, so ranges are seen
// from larger values to smaller ones. The globals greatl / greatr (reset to 0 by
// the caller) therefore hold the total v already seen from the left / right
// tree, i.e. the mass of values larger than the current range.
inline int merge(int rt1,int rt2,int p){
if (!rt1&&!rt2) return 0;
// Only the left tree has values in this range.
if (rt1&&!rt2){
greatl=inc(greatl,tree[rt1].v);
// tmp = p + greatr - 2*p*greatr = p*(1-greatr) + (1-p)*greatr.
// With greatr = mass of right values above this range, and assuming the right
// tree's values sum to 1, this is p*P(right < v) + (1-p)*P(right > v).
int tmp=dec(inc(greatr,p),2LL*greatr*p%mod);
// Scale the whole subtree lazily by tmp.
tree[rt1].tag=(ll)tree[rt1].tag*tmp%mod;
tree[rt1].v=(ll)tree[rt1].v*tmp%mod;
return rt1;
}
// Only the right tree has values in this range: symmetric, using greatl.
if (!rt1&&rt2){
greatr=inc(greatr,tree[rt2].v);
int tmp=dec(inc(greatl,p),2LL*greatl*p%mod);
tree[rt2].tag=(ll)tree[rt2].tag*tmp%mod;
tree[rt2].v=(ll)tree[rt2].v*tmp%mod;
return rt2;
}pushdown(rt1);pushdown(rt2); // both present: flush pending multipliers before descending
// Right (larger values) first, so greatl/greatr accumulate in descending order.
tree[rt1].right=merge(tree[rt1].right,tree[rt2].right,p);
tree[rt1].left=merge(tree[rt1].left,tree[rt2].left,p);
int l=tree[rt1].left,r=tree[rt1].right;
// Recompute this node's sum from its (possibly null, v = 0) children.
tree[rt1].v=inc(tree[l].v,tree[r].v);return rt1;
}
inline void dfs(int x){
if (!d[x]) return;
if (d[x]==1) {dfs(c[x][0]);rt[x]=rt[c[x][0]];return;}
if (d[x]==2) {
dfs(c[x][0]);dfs(c[x][1]);
greatl=0;greatr=0;
rt[x]=merge(rt[c[x][0]],rt[c[x][1]],p[x]);
}
}
// x^2 mod `mod`; x is expected to already be a residue, so x*x fits in 64 bits.
inline int sqr(ll x){
    const ll squared=x*x;
    return squared%mod;
}
inline void get_ans(int x,int l,int r){
if (l==r){
ans=inc(ans,(ll)l*q[l].fi%mod*sqr(tree[x].v)%mod);return;
}int mid=l+r>>1;pushdown(x);
get_ans(tree[x].left,l,mid);get_ans(tree[x].right,mid+1,r);
}
// Reads the tree (parent list, then one number per node), builds a one-leaf
// segment tree per leaf, merges bottom-up via dfs(), and prints the answer.
int main(){
#ifdef LOCAL
    // Local debugging only. An unconditional freopen of a file the judge does
    // not have closes stdin and returns NULL, so all input would be lost.
    freopen("loj2537.in","r",stdin);
#endif
    n=read();
    // Parent list: node 1 is the root, and its parent entry is read and ignored.
    for (int i=1;i<=n;++i){
        const int f=read();
        if (i==1) continue;
        ++d[f];
        if (!c[f][0]) c[f][0]=i;
        else c[f][1]=i;
    }
    const int inv=ksm(10000,mod-2);   // modular inverse of 10000
    for (int i=1;i<=n;++i){
        a[i]=read();
        if (!d[i]) q[++top]=mp(a[i],i);   // leaf: remember (value, node)
        else p[i]=(ll)a[i]*inv%mod;       // internal: a[i] / 10000 as a residue
    }
    // Rank leaves by value; rank i becomes the leaf position in every segment tree.
    sort(q+1,q+top+1,cmp);
    for (int i=1;i<=top;++i) insert1(rt[q[i].se],1,top,i);
    dfs(1);
    get_ans(rt[1],1,top);
    printf("%d\n",ans);
    return 0;
}