【树链剖分】[Mz]树链剖分练习

传送门:CodeVS4633


可以算是树剖模板吧

轻重链剖分:主要思路就是按子节点的size划分轻边和重链,然后所有相邻的重链上的点dfs序是连着的

于是自根向下每经过一条轻边size减半,两条轻边间只能有一条重链(两条轻边在一个链上)

于是成功吧树转换成序列,用线段树等数据结构乱搞

讲的似乎有点抽象


看看代码吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
int n,m,x,y,op,tot,tOt;
int sum[300030],cov[300030];
int deep[300030],size[300030],top[300030],fa[300030],head[300030],dfn[300030];
//deep存节点深度,size存以该节点为根的子树大小,top存所在重链的最上端(轻边则是本身),fa存父亲节点,dfn存dfs序(dfs遍历顺序,由于一段重链dfn序连续,因此可以用这个找到对应在序列上的区间,核心啊)
struct edge{
int v,next;
}e[300030];
void add(int x,int y){
e[++tot].v=y; e[tot].next=head[x]; head[x]=tot;
}
void dfs1(int cur){
if (fa[cur]==0) fa[cur]=cur;
deep[cur]=deep[fa[cur]]+1;
size[cur]=1;
for (int i=head[cur];i;i=e[i].next)
if (fa[cur]!=e[i].v){
fa[e[i].v]=cur;
dfs1(e[i].v);
size[cur]+=size[e[i].v];
}
}
void dfs2(int cur){
if (top[cur]==0) top[cur]=cur;
dfn[cur]=++tOt;
int t=0;
for (int i=head[cur];i;i=e[i].next)
t=(e[i].v!=fa[cur] && size[e[i].v]>size[t])?e[i].v:t;
if (t==0) return;
top[t]=top[cur],dfs2(t);
for (int i=head[cur];i;i=e[i].next)
if (e[i].v!=fa[cur] && e[i].v!=t) dfs2(e[i].v);
}
void update(int cur){
sum[cur]=sum[cur<<1]+sum[cur<<1|1];
}
void pushdown(int cur,int l,int r,int mid){
cov[cur<<1]+=cov[cur]; cov[cur<<1|1]+=cov[cur];
sum[cur<<1]+=cov[cur]*(mid-l+1); sum[cur<<1|1]+=cov[cur]*(r-mid);
cov[cur]=0;
}
void s_modify(int L,int R,int l,int r,int cur){
if (l>=L && r<=R){cov[cur]++; sum[cur]+=r-l+1; return;}
int mid=(l+r)>>1;
pushdown(cur,l,r,mid);
if (L<=mid) s_modify(L,R,l,mid,cur<<1);
if (R>mid) s_modify(L,R,mid+1,r,cur<<1|1);
update(cur);
}
void modify(int x,int y){
for (;top[x]!=top[y];x=fa[top[x]]){
if (deep[top[x]]<deep[top[y]]) swap(x,y); //注意此处比的是top的deep
s_modify(dfn[top[x]],dfn[x],1,n,1);
}
if (deep[x]>deep[y]) swap(x,y);
s_modify(dfn[x],dfn[y],1,n,1);
}
int s_query(int L,int R,int l,int r,int cur){
if (l>=L && r<=R) return sum[cur];
int mid=(l+r)>>1,SUM=0;
pushdown(cur,l,r,mid);
SUM+=(L<=mid)?s_query(L,R,l,mid,cur<<1):0;
SUM+=(R>mid)?s_query(L,R,mid+1,r,cur<<1|1):0;
update(cur);
return SUM;
}
int query(int x,int y){
int SUM=0;
for (;top[x]!=top[y];x=fa[top[x]]){
if (deep[top[x]]<deep[top[y]]) swap(x,y); //注意此处比的是top的deep
SUM+=s_query(dfn[top[x]],dfn[x],1,n,1);
}
if (deep[x]>deep[y]) swap(x,y);
SUM+=s_query(dfn[x],dfn[y],1,n,1);
return SUM;
}
int main(){
scanf("%d",&n);
for (int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs1(1);
dfs2(1);
scanf("%d",&m);
for (int i=1;i<=m;i++){
scanf("%d%d%d",&op,&x,&y);
if (op==1) modify(x,y);
if (op==2) printf("%d\n",query(x,y));
}
return 0;
}