Description
給定一棵樹,每個節點有一個權值ai和一個字元串si
q組詢問,每次詢問一個字元串S和兩個節點x,y
求x到y路徑上每個節點的字元串在S中出現的次數乘上各自的權值總和。
有單點修改權值的操作。
n,q<=200000, ∑ s i , ∑ S ≤ 400000 \sum si,\sum S\leq 400000 ∑si,∑S≤400000
強制線上
Solution
首先對于樹上路徑的問題,并且隻是加和減,我們可以考慮括号序。
對于每個節點進入的時候打一個1的标記,出的時候打一個-1的标記
那麼一條路徑 x t o y x\ to\ y x to y,令 p = l c a ( x , y ) p=lca(x,y) p=lca(x,y),那麼這條路徑可以表示為兩個區間和 [ i n [ p ] , i n [ x ] ] , ( i n [ p ] , i n [ y ] ] \left[in[p],in[x]\right],\left(in[p],in[y]\right] [in[p],in[x]],(in[p],in[y]]
這樣就将樹上問題轉化成了序列問題。
對于這個括号序,我們建一個線段樹,線段樹上的每個區間建AC自動機
這樣總的AC自動機大小是 2 n log n 2n\log n 2nlogn的
我們還要支援修改權值,那麼在AC自動機fail樹上修改鍊,相當于子樹修改,直接用樹狀數組維護fail樹的DFS序即可。
總的複雜度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)
出題人還卡了空間
代碼及其惡心(出題人已被群毆,不省人事…)
Code
#pragma GCC optimize(2)
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstdio>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 200005
#define M 800005
#define L 8000005
using namespace std;
int n1,n,m,m1,tp,a1[N][2],t[M][2],pr[N],fs[N],da[M],len,dep[M],n3;
int nt[2*N],dt[2*N],f[N][20];
int t1[L][5],wz[22][M],fi[L],nx[L],d1[L],dfn[L],d[L],fail[L],sz[L],c[L];
int m2,n2,in[N],out[N],rt[M],de[M],c1[26];
char st[M],ch[M];
void link(int x,int y)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
}
void read(int &x)
{
x=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
}
void lk(int x,int y)
{
nx[++m2]=fi[x];
d1[fi[x]=m2]=y;
}
void dfs(int k,int fa)
{
dep[k]=dep[fa]+1;
f[k][0]=fa;
da[in[k]=++da[0]]=k;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa) dfs(p,k);
}
da[out[k]=++da[0]]=k;
}
int lca(int x,int y)
{
if(dep[x]>dep[y]) swap(x,y);
for(int l=dep[y]-dep[x],j=0;l;l>>=1,j++) if(l&1) y=f[y][j];
for(int j=17;x!=y;)
{
while(j&&f[x][j]==f[y][j]) j--;
x=f[x][j],y=f[y][j];
}
return x;
}
void dg(int k)
{
dfn[k]=++dfn[0];
sz[k]=1;
for(int i=fi[k];i;i=nx[i]) dg(d1[i]),sz[k]+=sz[d1[i]];
}
int lowbit(const int &k)
{
return k&(-k);
}
void put(int k,const int &v)
{
while(k<=n3) c[k]+=v,k+=lowbit(k);
}
int get(int k)
{
int s=0;
while(k) s+=c[k],k-=lowbit(k);
return s;
}
void change(int l,int r,int v)
{
put(l,v),put(r+1,-v);
}
void make(int k2,int l,int r)
{
int k1=rt[k2];
fo(i1,l,r)
{
int k=k1,i=da[i1];
fo(j,a1[i][0],a1[i][1])
{
int c=c1[st[j]-'A'];
if(!t1[k][c]) t1[k][c]=++n2;
k=t1[k][c];
}
wz[de[k2]][i1]=k;
}
int x=0,y=1;
d[1]=k1;
while(x<y)
{
int k=d[++x];
fo(c,0,4)
{
if(t1[k][c])
{
d[++y]=t1[k][c];
int p=fail[k];
while(p&&!t1[p][c]) p=fail[p];
if(!p) fail[d[y]]=k1;
else fail[d[y]]=t1[p][c];
lk(fail[d[y]],d[y]);
}
}
}
dg(k1);
fo(i1,l,r)
{
int i=da[i1],i2=wz[de[k2]][i1];
if(i1==in[i]) change(dfn[i2],dfn[i2]+sz[i2]-1,pr[i]);
else change(dfn[i2],dfn[i2]+sz[i2]-1,-pr[i]);
}
}
void build(int k,int l,int r)
{
rt[k]=++n2;
make(k,l,r);
if(l==r) return;
int mid=(l+r)>>1;
t[k][0]=++n1,de[n1]=de[k]+1,build(t[k][0],l,mid);
t[k][1]=++n1,de[n1]=de[k]+1,build(t[k][1],mid+1,r);
}
int gs(int k)
{
int s=0;
fo(i,1,len)
{
int c=c1[ch[i]-'A'];
while(fail[k]&&!t1[k][c]) k=fail[k];
if(t1[k][c]) k=t1[k][c];
s+=get(dfn[k]);
}
return s;
}
int query(int k,int l,int r,int x,int y)
{
if(x>y||!k||y<l||x>r) return 0;
if(x<=l&&r<=y) return gs(rt[k]);
int mid=(l+r)>>1;
return query(t[k][0],l,mid,x,y)+query(t[k][1],mid+1,r,x,y);
}
void ins(int k,int l,int r,int w,int y)
{
change(dfn[wz[de[k]][w]],dfn[wz[de[k]][w]]+sz[wz[de[k]][w]]-1,y);
if(l==r) return;
int mid=(l+r)>>1;
if(w<=mid) ins(t[k][0],l,mid,w,y);
else ins(t[k][1],mid+1,r,w,y);
}
int main()
{
cin>>n>>tp;
int le=0;
bool pd=1;
fo(i,1,n)
{
scanf("\n%s",ch+1);
int l1=strlen(ch+1);
a1[i][0]=le+1;
fo(j,1,l1) st[++le]=ch[j];
a1[i][1]=le;
}
fo(i,1,n) read(pr[i]);
fo(i,1,n-1)
{
int x,y;
if(x>y) swap(x,y);
read(x),read(y);
if(x!=y-1) pd=0;
link(x,y),link(y,x);
}
c1[0]=0;
c1[2]=1;
c1['G'-'A']=2;
c1['T'-'A']=3;
c1['U'-'A']=4;
dfs(1,0);
fo(j,1,18) fo(i,1,n) f[i][j]=f[f[i][j-1]][j-1];
n1=1;
n3=n*27;
build(1,1,2*n);
int q,ans=0;
cin>>q;
le=0;
fo(i,1,q)
{
int ip,x,y;
read(ip),read(x),read(y);
x^=(ans*tp),y^=(ans*tp);
if(ip==1)
{
scanf("%s",ch+1);
len=strlen(ch+1);
int lp=lca(x,y);
ans=query(1,1,2*n,in[lp],in[x])+query(1,1,2*n,in[lp]+1,in[y]);
printf("%d\n",ans);
}
else
{
int v=y-pr[x];
ins(1,1,2*n,in[x],v),ins(1,1,2*n,out[x],-v);
pr[x]=y;
}
}
}