Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
刚好再练下树链剖分的区间修改。我们维护最左边的颜色,最右边的颜色和区间颜色段数即可。合并的时候判断下两端子区间合并位置颜色是否相同,相同的话则把总颜色数-1。
【我的代码总是那么长】
#include<string>
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
struct line
{
int s,t;
int next;
}a[300001];
int edge;
int head[100001];
int fx[100001];
inline void add(int s,int t)
{
a[edge].next=head[s];
head[s]=edge;
a[edge].s=s;
a[edge].t=t;
}
struct tree
{
int l,r;
int ll,rr;
int s;
int tag;
}tr[800001];
int val[100001];
inline int max(int x,int y)
{
if(x>y)
return x;
return y;
}
inline void up(int p)
{
tr[p].s=tr[p*2].s+tr[p*2+1].s;
if(tr[p*2].rr==tr[p*2+1].ll)
tr[p].s--;
tr[p].ll=tr[p*2].ll;
tr[p].rr=tr[p*2+1].rr;
}
inline void down(int p)
{
if(tr[p].tag!=-1)
{
tr[p*2].tag=tr[p].tag;
tr[p*2+1].tag=tr[p].tag;
tr[p*2].ll=tr[p].tag;
tr[p*2+1].ll=tr[p].tag;
tr[p*2].rr=tr[p].tag;
tr[p*2+1].rr=tr[p].tag;
tr[p*2].s=1;
tr[p*2+1].s=1;
tr[p].tag=-1;
}
}
inline void build(int p,int l,int r)
{
tr[p].l=l;
tr[p].r=r;
tr[p].tag=-1;
if(l!=r)
{
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
up(p);
}
else
{
tr[p].ll=val[fx[l]];
tr[p].rr=val[fx[l]];
tr[p].s=1;
}
}
inline void change(int p,int l,int r,int x)
{
if(l<=tr[p].l&&tr[p].r<=r)
{
tr[p].s=1;
tr[p].tag=x;
tr[p].ll=x;
tr[p].rr=x;
}
else
{
down(p);
int mid=(tr[p].l+tr[p].r)/2;
if(l<=mid)
change(p*2,l,r,x);
if(r>mid)
change(p*2+1,l,r,x);
up(p);
}
}
inline tree find(int p,int l,int r)
{
if(l<=tr[p].l&&tr[p].r<=r)
return tr[p];
else
{
down(p);
int mid=(tr[p].l+tr[p].r)/2;
bool flag1=false,flag2=false;
tree x1,x2;
if(l<=mid)
{
x1=find(p*2,l,r);
flag1=true;
}
if(r>mid)
{
x2=find(p*2+1,l,r);
flag2=true;
}
tree x;
if(flag1)
{
if(flag2)
{
x.s=x1.s+x2.s;
if(x1.rr==x2.ll)
x.s--;
x.ll=x1.ll;
x.rr=x2.rr;
}
else
x=x1;
}
else
x=x2;
up(p);
return x;
}
}
int dep[100001],size[100001],son[100001],fa[100001];
int top[100001],w[100001];
int lson[100001]/*最大节点位置*/,mson[100001]/*最大节点值*/;
int tot;
inline void dfs1(int d)
{
int i;
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(t!=fa[d])
{
dep[t]=dep[d]+1;
fa[t]=d;
dfs1(t);
son[d]+=son[t]+1;
if(son[t]>=mson[d])
{
mson[d]=son[t];
lson[d]=t;
}
}
}
}
inline void dfs2(int d)
{
int i;
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(t==lson[d])
{
top[t]=top[d];
tot++;
w[t]=tot;
fx[tot]=t;
dfs2(t);
}
}
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(t!=fa[d]&&t!=lson[d])
{
top[t]=t;
tot++;
w[t]=tot;
fx[tot]=t;
dfs2(t);
}
}
}
tree xx;
inline tree ask(int s,int t)
{
int u=top[s],v=top[t];
tree x1,x2,xt;
x1=xx;
x2=xx;
while(u!=v)
{
if(dep[u]>dep[v])
{
xt=find(1,w[u],w[s]);
if(x1.s!=0)
{
x1.s+=xt.s;
if(x1.ll==xt.rr)
x1.s--;
x1.ll=xt.ll;
}
else
x1=xt;
s=fa[top[s]];
}
else
{
xt=find(1,w[v],w[t]);
if(x2.s!=0)
{
x2.s+=xt.s;
if(x2.ll==xt.rr)
x2.s--;
x2.ll=xt.ll;
}
else
x2=xt;
t=fa[top[t]];
}
u=top[s];
v=top[t];
}
tree x;
if(w[s]<w[t])
{
xt=find(1,w[s],w[t]);
x.s=x1.s+xt.s+x2.s;
if(x1.ll==xt.ll)
x.s--;
if(x2.ll==xt.rr)
x.s--;
x.ll=x1.rr;
x.rr=x2.rr;
}
else
{
xt=find(1,w[t],w[s]);
x.s=x1.s+xt.s+x2.s;
if(x1.ll==xt.rr)
x.s--;
if(x2.ll==xt.ll)
x.s--;
x.ll=x1.rr;
x.rr=x2.rr;
}
return x;
}
inline void cover(int s,int t,int x)
{
int u=top[s],v=top[t];
while(u!=v)
{
if(dep[u]>dep[v])
{
change(1,w[u],w[s],x);
s=fa[top[s]];
}
else
{
change(1,w[v],w[t],x);
t=fa[top[t]];
}
u=top[s];
v=top[t];
}
if(w[s]<w[t])
change(1,w[s],w[t],x);
else
change(1,w[t],w[s],x);
}
int main()
{
// freopen("paint.in","r",stdin);
// freopen("paint.out","w",stdout);
int n,m;
scanf("%d%d",&n,&m);
int i,s,t;
for(i=1;i<=n;i++)
scanf("%d",&val[i]);
for(i=1;i<=n-1;i++)
{
scanf("%d%d",&s,&t);
edge++;
add(s,t);
edge++;
add(t,s);
}
dep[1]=1;
dfs1(1);
top[1]=1;
tot++;
w[1]=tot;
fx[tot]=1;
dfs2(1);
build(1,1,n);
char x;
int st;
for(i=1;i<=m;i++)
{
scanf("%c",&x);
while(x=='\n'||x=='\r')
scanf("%c",&x);
if(x=='Q')
{
scanf("%d%d",&s,&t);
printf("%d\n",ask(s,t).s);
}
else if(x=='C')
{
scanf("%d%d%d",&s,&t,&st);
cover(s,t,st);
}
}
return 0;
}