【AC自动机+KMP+数据结构】回忆树

传送门:BZOJ4231


思路:

  • 首先这道题是有向路径并且字母在边上非常难受然而并没有什么办法
  • 考虑一个长度为$len$的询问,发现转弯的最多只有$O(n)$个,所以取出来KMP即可;
  • 考虑所有不转弯的部分,一定是树上一条直的链;
  • 可以先把询问串建成AC自动机;
  • 然后遍历原树,在AC自动机上更新每个串出现的子树,等价于Fail树上到根的路径加,可以化成单点加子树查询;这样就可以实现在任意点查询$pre_u$,其中$pre_u$为$root$到$u$的路径上任意字符串的数量;
  • 然后考虑答案,$edge(u,v)|dep_u<dep_v$的答案可以通过$pre_v-pre_p$实现,其中$v,p$就是该询问的关键点;
  • 只要提前把询问挂到关键点上即可。

注意:

  • 由于是有向路径KMP只要顺着跑一遍;
  • 树状数组for到cnt为止,因为维护的是Fail树的dfs序。

代码如下(好长啊):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include<bits/stdc++.h>
using namespace std;
#define N 320000
#define ff first
#define ss second
#define P(x,y) make_pair(x,y)
#define gt(x) ((x)<=m?(x):((x)-m))
int n,m,x,y,z,cnt,a[N*2],bd[N],c[N][30],bit[N],ans[N],q[N],fa[N];
char s[N],ss[N],sss[N];
vector<pair<int,int> > cc[N];
struct edge{int v,nxt,ch;};
struct graph{
int tot,clk,dep[N],sz[N],head[N],fa[N],dfn[N],fst[N],end[N],top[N],ps[N];
edge e[N*2];
void add(int x,int y,char ch){e[++tot].v=y; e[tot].ch=ch; e[tot].nxt=head[x]; head[x]=tot;}
void dfs1(int u){
dep[u]=dep[fa[u]]+1;
for (int i=head[u],v;i;i=e[i].nxt)
if ((v=e[i].v)!=fa[u]){fa[v]=u; dfs1(v); sz[u]+=sz[v];}
if (!sz[u]) sz[u]++;
}
void dfs2(int u){
fst[u]=end[u]=dfn[u]=++clk; ps[clk]=u;
if (!top[u]) top[u]=u; int t=0;
for (int i=head[u],v;i;i=e[i].nxt)
if ((v=e[i].v)!=fa[u]&&sz[v]>sz[t]) t=v;
if (!t) return; top[t]=top[u]; dfs2(t); end[u]=end[t];
for (int i=head[u],v;i;i=e[i].nxt)
if ((v=e[i].v)!=fa[u]&&v!=t) dfs2(v),end[u]=end[v];
}
int lca(int u,int v){
for (;top[u]!=top[v];u=fa[top[u]]) if (dep[top[u]]<dep[top[v]]) swap(u,v);
return dep[u]<dep[v]?u:v;
}
}g1,g2;
inline int kmp(char ss[],char s[]){
int ret=0,l1=strlen(ss+1),l2=strlen(s+1);
bd[1]=0;
for (int i=2,j=0;i<=l2;i++){
for (j=bd[i-1];j&&s[j+1]!=s[i];j=bd[j]);
if (s[j+1]==s[i]) bd[i]=j+1; else bd[i]=0;
}
for (int i=1,j=0;i<=l1;i++){
for (;j&&(j>=l2||s[j+1]!=ss[i]);j=bd[j]);
if (s[j+1]==ss[i]) j++; if (j==l2) ret++;
}
return ret;
}
inline void add(int x,int k){for (int i=x;i<=cnt;i+=i&(-i)) bit[i]+=k;}
inline int qry(int x){int ret=0; for (int i=x;i;i-=i&(-i)) ret+=bit[i]; return ret;}
void dfs(int u){
for (int i=g1.head[u],v;i;i=g1.e[i].nxt)
if ((v=g1.e[i].v)!=g1.fa[u]){ss[v]=g1.e[i].ch; dfs(v);}
}
void dfs(int u,int now){
int w=ss[u]-'a';
for (;now&&(w<0||!c[now][w]);now=fa[now]); now=c[now][w]; if (!now) now=1;
add(g2.dfn[now],1);
for (int i=0;i<(int)cc[u].size();i++){
int j=cc[u][i].ff,k=cc[u][i].ss;
ans[j]+=k*(qry(g2.end[a[j]])-qry(g2.fst[a[j]]-1));
}
for (int i=g1.head[u],v;i;i=g1.e[i].nxt)
if ((v=g1.e[i].v)!=g1.fa[u]) dfs(v,now);
add(g2.dfn[now],-1);
}
int main(){
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++){
scanf("%d%d",&x,&y); char ch;
for (ch=getchar();ch==' ';ch=getchar());
g1.add(x,y,ch); g1.add(y,x,ch);
}
g1.dfs1(1); g1.dfs2(1); dfs(1);
cnt=1;
for (int i=1;i<=m;i++){
scanf("%d%d%s",&x,&y,s+1); z=g1.lca(x,y);
cc[x].push_back(P(i+m,1)); cc[y].push_back(P(i,1));
int d1=g1.dep[x],d2=g1.dep[y];
int len=strlen(s+1),now=1;
for (int j=1;j<=len;j++){
int w=s[j]-'a';
if (!c[now][w]) c[now][w]=++cnt;
now=c[now][w];
}
a[i]=now;
now=1;
for (int j=len;j>=1;j--){
int w=s[j]-'a';
if (!c[now][w]) c[now][w]=++cnt;
now=c[now][w];
}
a[i+m]=now;
for (;g1.dep[g1.top[x]]-g1.dep[z]>=len;x=g1.fa[g1.top[x]]);
if (g1.dep[x]-g1.dep[z]>=len) x=g1.ps[g1.dfn[x]-(g1.dep[x]-g1.dep[z]-len+1)];
for (;g1.dep[g1.top[y]]-g1.dep[z]>=len;y=g1.fa[g1.top[y]]);
if (g1.dep[y]-g1.dep[z]>=len) y=g1.ps[g1.dfn[y]-(g1.dep[y]-g1.dep[z]-len+1)];
assert(g1.dep[x]<=d1); assert(g1.dep[y]<=d2);
cc[x].push_back(P(i+m,-1)); cc[y].push_back(P(i,-1));
int tt=g1.dep[x]+g1.dep[y]-g1.dep[z]*2;
for (int j=1;x!=z;j++,x=g1.fa[x]) sss[j]=ss[x];
for (int j=tt;y!=z;j--,y=g1.fa[y]) sss[j]=ss[y];
sss[tt+1]=0; ans[i]+=kmp(sss,s);
}
int tt=0,ww=1; q[1]=1;
while (tt<ww){
int xb=q[++tt];
g2.add(fa[xb],xb,0);
for (int i=0;i<26;i++) if (c[xb][i]){
int tmp=fa[xb];
for (;tmp&&!c[tmp][i];tmp=fa[tmp]);
fa[c[xb][i]]=tmp?c[tmp][i]:1;
q[++ww]=c[xb][i];
}
}
g2.dfs1(1); g2.dfs2(1);
dfs(1,1);
for (int i=1;i<=m;i++) printf("%d\n",ans[i]+ans[i+m]);
return 0;
}