Description
Bobo has a tree whose vertices are conveniently labeled $1, 2, \ldots, n$. Initially, the $i$-th vertex is assigned weight $w_i$.
There are $q$ operations. Each operation is of one of the following 2 types:
Change the weight of vertex $v$ to $x$ (denoted as "! v x");
Ask for the total weight of the vertices whose distance from vertex $v$ is at most $d$ (denoted as "? v d").
Note that the distance between vertices $u$ and $v$ is the number of edges on the shortest path between them.
Input
The input consists of several test cases. For each test case:
The first line contains $n, q$ ($1 \le n, q \le 10^5$). The second line contains $n$ integers $w_1, w_2, \ldots, w_n$ ($0 \le w_i \le 10^4$). Each of the following $n - 1$ lines contains 2 integers $a_i, b_i$ denoting an edge between vertices $a_i$ and $b_i$ ($1 \le a_i, b_i \le n$). Each of the following $q$ lines contains one operation ($1 \le v \le n$, $0 \le x \le 10^4$, $0 \le d \le n$).
Output
For each test case:
For each query, output a single number: the total weight.
Sample Input
4 3
1 1 1 1
1 2
2 3
3 4
? 2 1
! 1 0
? 2 1
3 3
1 2 3
1 2
1 3
? 1 0
? 1 1
? 1 2
Sample Output
3
2
1
6
6
#include<map>
#include<vector>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long LL;

// Lowest set bit of x (Fenwick-tree step). Expects x > 0.
inline int low(int x) { return x & -x; }

// Capacity of every per-vertex / per-edge array. The edge arrays hold 2*(n-1)
// directed slots, so this must exceed both n and 2*(n-1) for the largest test.
const int maxn = 300000 + 10;

// Sentinel "infinitely bad" centroid score (equals INT_MAX).
const int INF = 0x7FFFFFFF;

// N, Q: sizes of the current test; w[v]: current weight of vertex v (1-based);
// x, y: scratch ints used by main() for edge endpoints and operation arguments.
int N, Q, w[maxn], x, y;

// Operation token read by main() ("?" or "!"); 4 chars + terminator.
char ch[5];
struct Tree
{
int ft[maxn], nt[maxn], u[maxn], sz;
int mx[maxn], ct[maxn], vis[maxn];
int pre[maxn];
struct point
{
int x, y;
point(int x = 0, int y = 0) :x(x), y(y) {}
bool operator<(const point &a)const { return x < a.x; }
};
vector<int> d[maxn], D[maxn];
vector<point> dis[maxn];
void clear(int n)
{
mx[sz = 0] = INF;
for (int i = 1; i <= n; i++)
{
ft[i] = -1, vis[i] = 0;
d[i].clear(), D[i].clear();
dis[i].clear();
}
}
void AddEdge(int x, int y)
{
u[sz] = y; nt[sz] = ft[x]; ft[x] = sz++;
u[sz] = x; nt[sz] = ft[y]; ft[y] = sz++;
}
int dfs(int x, int fa, int sum)
{
int y = mx[x] = (ct[x] = 1) - 1;
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa || vis[u[i]]) continue;
int z = dfs(u[i], x, sum);
ct[x] += ct[u[i]];
mx[x] = max(mx[x], ct[u[i]]);
y = mx[y] < mx[z] ? y : z;
}
mx[x] = max(mx[x], sum - ct[x]);
return mx[x] < mx[y] ? x : y;
}
int getdep(int x, int fa, int dep, int rt)
{
if (rt) dis[rt].push_back(point(x, dep));
int ans = dep;
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa || vis[u[i]]) continue;
ans = max(getdep(u[i], x, dep + 1, rt), ans);
}
return ans;
}
void put(int x, int fa, int dep, vector<int> &p)
{
if (dep)
{
for (int i = dep; i < p.size(); i += low(i)) p[i] += w[x];
}
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa || vis[u[i]]) continue;
put(u[i], x, dep + 1, p);
}
}
int build(int x, int sum, int fa)
{
int y = dfs(x, -1, sum);
pre[y] = fa; vis[y] = 1;
int len = getdep(y, -1, 0, y);
sort(dis[y].begin(), dis[y].end());
for (int i = 0; i <= len; i++) d[y].push_back(0);
put(y, -1, 0, d[y]);
for (int i = ft[y]; i != -1; i = nt[i])
{
if (vis[u[i]]) continue;
int z = build(u[i], ct[u[i]] > ct[y] ? sum - ct[y] : ct[u[i]], y);
len = getdep(u[i], y, 1, 0);
for (int j = 0; j <= len; j++) D[z].push_back(0);
put(u[i], y, 1, D[z]);
}
vis[y] = 0;
return y;
}
int work(int rt, int x, int y)
{
int ans = 0;
if (x == rt)
{
for (int i = min(y, (int)d[rt].size() - 1); i; i -= low(i)) ans += d[rt][i];
ans += w[x];
}
if (pre[x] != -1)
{
int k = dis[pre[x]][lower_bound(dis[pre[x]].begin(), dis[pre[x]].end(), point(rt, 0)) - dis[pre[x]].begin()].y;
if (k <= y) {
ans += w[pre[x]];
for (int i = min(y - k, (int)d[pre[x]].size() - 1); i; i -= low(i)) ans += d[pre[x]][i];
for (int i = min(y - k, (int)D[x].size() - 1); i; i -= low(i)) ans -= D[x][i];
}
ans += work(rt, pre[x], y);
}
return ans;
}
void change(int rt, int x, int y)
{
if (pre[x] == -1) return;
int k = dis[pre[x]][lower_bound(dis[pre[x]].begin(), dis[pre[x]].end(), point(rt, 0)) - dis[pre[x]].begin()].y;
for (int i = k; i < d[pre[x]].size(); i += low(i)) d[pre[x]][i] += y;
for (int i = k; i < D[x].size(); i += low(i)) D[x][i] += y;
change(rt, pre[x], y);
}
}solve;
// Reads test cases until input is exhausted. Each test is a weighted tree followed
// by Q operations: "? v d" prints the total weight within distance d of v, and
// "! v x" sets w[v] = x.
int main()
{
    // Require both numbers; a bare != EOF check would spin on a partial read.
    while (scanf("%d%d", &N, &Q) == 2)
    {
        solve.clear(N);
        for (int i = 1; i <= N; i++) scanf("%d", &w[i]);
        for (int i = 1; i < N; i++)
        {
            scanf("%d%d", &x, &y);
            solve.AddEdge(x, y);
        }
        solve.build(1, N, -1);
        while (Q--)
        {
            // %4s bounds the token to ch's 5-byte buffer; stop on malformed input.
            if (scanf("%4s%d%d", ch, &x, &y) != 3) break;
            if (ch[0] == '?') printf("%d\n", solve.work(x, x, y));
            else
            {
                // The index stores deltas, so pass the difference before updating w[].
                solve.change(x, x, y - w[x]);
                w[x] = y;
            }
        }
    }
    return 0;
}