天天看點

[SCOI2018]Numazu 的蜜柑 [ 二次剩餘 ]

​​傳送門​​

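直接解發現：記祖先節點的權值為 $x$、後代節點的權值為 $y$，下方程式碼處理的條件為 $y^2 + Axy + Bx^2 \equiv 0 \pmod p$。把它看成關於比值 $r = y/x$ 的一元二次方程 $r^2 + Ar + B \equiv 0 \pmod p$，可得

$$y \equiv \frac{-A \pm \sqrt{A^2 - 4B}}{2}\, x \pmod p$$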

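於是用二次剩餘（Cipolla 演算法）求出 $\sqrt{A^2 - 4B} \bmod p$，再在 DFS 時用 map 記錄路徑上每個祖先對應的兩個目標值 $r_1 x$、$r_2 x$ 即可計數。若 $A^2 - 4B$ 不是二次剩餘，則只有 $x \equiv y \equiv 0$ 的祖先–後代對滿足條件。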
#include<bits/stdc++.h>
#define LL long long
#define N 200050
using namespace std;
// Reads the next integer from stdin.
// Any '-' seen while skipping non-digit characters makes the result negative.
// The character right after the number is consumed.
// At EOF the function stops and returns what it has (0 if no digits remain)
// instead of spinning forever.
// `ch` is an int so isdigit() always receives an unsigned-char value or EOF.
long long read(){
  long long cnt = 0, f = 1;
  int ch = getchar();
  while(ch != EOF && !isdigit(ch)){
    if(ch == '-') f = -1;
    ch = getchar();
  }
  while(ch != EOF && isdigit(ch)){
    cnt = cnt*10 + (ch-'0');
    ch = getchar();
  }
  return cnt * f;
}
// Input: n nodes, modulus p, coefficients A and B, node values a[1..n].
int n; LL A, B, p, a[N];
// v[u] lists the children of node u; the traversal starts from node 1.
vector<int> v[N];
// det: A*A - 4*B mod p, later overwritten with its square root by main.
// w: the value t*t - x picked by Solve for Cipolla's extension (read by Node).
// a1, a2: (-A + det) / 2 and (-A - det) / 2 mod p after det becomes the root.
// ans: the pair count that main prints.
LL det, w, a1, a2, ans;
// Modular product a*b mod p via binary double-and-add.
// With 0 <= a < p, every intermediate sum stays below 2*p, so the
// product never overflows the way a direct a*b would.
// b is assumed non-negative: its bits are consumed with >>= until it is 0.
LL mul(LL a, LL b){
  LL acc = 0;
  while(b){
    if(b & 1) acc = (acc + a) % p;
    a = (a + a) % p;
    b >>= 1;
  }
  return acc;
}
// Modular exponentiation a^b mod p by square-and-multiply.
// Each product goes through mul. Returns 1 when b == 0.
LL power(LL a, LL b){
  LL result = 1;
  while(b){
    if(b & 1) result = mul(result, a);
    a = mul(a, a);
    b >>= 1;
  }
  return result;
}
// Euler's criterion: computes e = x^((p-1)/2) mod p.
// Returns -1 when e == p-1 (x is a quadratic non-residue).
// Otherwise returns e itself: 1 for a non-zero residue, 0 for x == 0,
// assuming p is an odd prime.
LL legander(LL x){
  LL e = power(x, (p-1)/2);
  return (e == p - 1) ? -1 : e;
}
// Element x + y*s of the extension field used by Cipolla's algorithm.
// Here s*s == w, where w is the global chosen by Solve.
struct Node{
  LL x, y;
  Node(LL _x=0, LL _y=0) : x(_x), y(_y) {}
  // (x1 + y1 s)(x2 + y2 s) = (x1 x2 + y1 y2 w) + (x1 y2 + y1 x2) s  (mod p).
  // const: multiplying never modifies either operand.
  Node operator * (const Node &o) const
  { return Node((mul(x, o.x) + mul(mul(y, o.y), w)) % p, (mul(x, o.y) + mul(y, o.x)) % p); }
};
// Raises an extension-field element to the b-th power by square-and-multiply.
// Starts from the identity Node(1, 0).
Node Power(Node a, LL b){
  Node result(1, 0);
  while(b){
    if(b & 1) result = result * a;
    a = a * a;
    b >>= 1;
  }
  return result;
}
// Cipolla's algorithm: returns r with r*r == x (mod p).
// Intended for an odd prime p and x a quadratic residue.
// x == 0 (mod p) is answered directly: for it, t*t - x is never a
// non-residue, so the search loop below would never terminate.
LL Solve(LL x){
  if(x % p == 0) return 0;
  LL t;
  // Search for t with t*t - x a non-residue. Roughly half of all t work.
  // NOTE(review): rand() % p only samples t <= RAND_MAX when p is larger.
  while(1){
    t = rand() % p;
    w = (mul(t, t) - x + p) % p;
    if(legander(w) == -1) break;
  }
  // (t + s)^((p+1)/2) has zero s-part and squares to x.
  Node res = Power(Node(t, 1ll), (p+1) / 2);
  return res.x;
}
map<LL, int> cnt;
void dfs1(int u){
  ans += cnt[a[u]];
  LL v1 = mul(a1, a[u]), v2 = mul(a2, a[u]);
  (v1 == v2) ? cnt[v1]++ : (cnt[v1]++, cnt[v2]++);
  for(int i=0; i<v[u].size(); i++) dfs1(v[u][i]);
  (v1 == v2) ? cnt[v1]-- : (cnt[v1]--, cnt[v2]--);
}
// Number of zero-valued nodes on the current root path, not counting the
// node being entered.
LL num;
// Adds to ans the number of ancestor/descendant pairs in u's subtree whose
// values are both 0.
// Uses an explicit stack instead of recursion, so a path-shaped tree with
// up to N nodes cannot overflow the call stack.
void dfs2(int u){
  struct Frame { int node; std::size_t next; };
  std::vector<Frame> st;
  auto enter = [&](int x){
    if(a[x] == 0) ans += num, num++;
    st.push_back({x, 0});
  };
  enter(u);
  while(!st.empty()){
    Frame &top = st.back();
    if(top.next < v[top.node].size()){
      int child = v[top.node][top.next++];
      enter(child);  // may reallocate st; `top` is not used afterwards
    } else {
      if(a[top.node] == 0) num--;
      st.pop_back();
    }
  }
}
int main(){
  n = read(), p = read(), A = read(), B = read();
  for(int i=1; i<=n; i++) a[i] = read();
  for(int i=2; i<=n; i++) v[read()].push_back(i);
  det = (mul(A, A) - (B * 4) % p + p) % p;
  if(legander(det) != -1){
    det = Solve(det);
    a1 = mul(((det - A + p) % p), power(2, p-2));
    a2 = mul(((-det - A + p + p) % p), power(2, p-2));
    dfs1(1);
  }
  else dfs2(1);
  printf("%lld", ans);
  return 0;
}