【点分治】树的难题

传送门:LOJ2179


思路:对树进行点分治,用线段树维护边数一定时的最大权值,注意要两棵线段树,维护相同和不同的颜色。


代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define N 800000
#define INF 0x3f3f3f3f3f3f3f3fLL
LL n,m,x,y,z,p1,p2,ans,rt,sz[N],mx[N],c[N];
bool vis[N];
struct edge{
LL v,c;
bool operator < (const edge p) const {return c<p.c;}
};
vector<edge> g[N];
struct Segment_Tree{
LL tot,rt,s[N],ls[N],rs[N];
inline LL max(LL x,LL y){return x>y?x:y;}
Segment_Tree(){rt=tot=0; s[0]=-INF;}
void clear(){rt=tot=0; s[0]=-INF;}
void mdf(LL &x,LL l,LL r,LL xb,LL k){
if (x==0){x=++tot; s[x]=-INF; ls[x]=rs[x]=0;}
if (l==r){s[x]=max(s[x],k); return;}
LL mid=(l+r)>>1;
if (xb<=mid) mdf(ls[x],l,mid,xb,k); else mdf(rs[x],mid+1,r,xb,k);
s[x]=max(s[ls[x]],s[rs[x]]);
}
void add(LL x,LL y){mdf(rt,0,n,x,y);}
LL qry(LL x,LL l,LL r,LL L,LL R){
if (x==0) return -INF;
if (l>=L && r<=R) return s[x];
LL mid=(l+r)>>1,tmp=-INF;
if (L<=mid) tmp=max(tmp,qry(ls[x],l,mid,L,R));
if (R>mid) tmp=max(tmp,qry(rs[x],mid+1,r,L,R));
return tmp;
}
LL ask(LL l,LL r){
l=max(l,0);
if (l>r) return -INF;
return qry(rt,0,n,l,r);
}
}s1,s2;
LL getLL(){
char ch; LL sum=0,fh=1;
for (ch=getchar();ch<'0' || ch>'9';ch=getchar()) fh=ch=='-'?-1:1;
for (;ch>='0' && ch<='9';ch=getchar()) sum=sum*10+ch-'0';
return sum*fh;
}
void merge(LL &x1,LL x2){
if (x2==0) return;
if (x1==0){x1=++s1.tot; s1.s[x1]=-INF; s1.ls[x1]=s1.rs[x1]=0;}
s1.s[x1]=max(s1.s[x1],s2.s[x2]);
merge(s1.ls[x1],s2.ls[x2]); merge(s1.rs[x1],s2.rs[x2]);
}
void getroot(LL u,LL fa,LL sum){
sz[u]=1; mx[u]=0;
for (LL i=0,v;i<(LL)g[u].size();i++)
if (!vis[v=g[u][i].v] && v!=fa){
getroot(v,u,sum); sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],sum-sz[u]);
if (mx[u]<mx[rt]) rt=u;
}
void getans(LL u,LL col,LL col2,LL d,LL fa,LL val){
if (d>p2) return;
ans=max(ans,max(s1.ask(p1-d,p2-d),s2.ask(p1-d,p2-d)-c[col2])+val);
for (LL i=0,v;i<(LL)g[u].size();i++)
if (!vis[v=g[u][i].v] && v!=fa)
getans(v,g[u][i].c,col2,d+1,u,val+(g[u][i].c==col?0:c[g[u][i].c]));
}
void update(LL u,LL col,LL d,LL fa,LL val){
if (d>p2) return;
s2.add(d,val);
for (LL i=0,v;i<(LL)g[u].size();i++)
if (!vis[v=g[u][i].v] && v!=fa)
update(v,g[u][i].c,d+1,u,val+(g[u][i].c==col?0:c[g[u][i].c]));
}
void calc(LL u){
LL lstc=0; s1.clear(); s2.clear();
s1.add(0,0);
for (LL i=0,v;i<(LL)g[u].size();i++)
if (!vis[v=g[u][i].v]){
if (g[u][i].c!=lstc){merge(s1.rt,s2.rt); s2.clear();}
getans(v,g[u][i].c,g[u][i].c,1,u,c[g[u][i].c]); update(v,g[u][i].c,1,u,c[g[u][i].c]);
lstc=g[u][i].c;
}
}
void solve(LL u){
vis[u]=1; calc(u);
for (LL i=0,v;i<(LL)g[u].size();i++)
if (!vis[v=g[u][i].v]){
rt=0; getroot(v,0,sz[v]); solve(rt);
}
}
int main(){
n=getLL(); m=getLL(); p1=getLL(); p2=getLL();
for (LL i=1;i<=m;i++) c[i]=getLL();
for (LL i=1;i<n;i++){
x=getLL(); y=getLL(); z=getLL();
g[x].push_back((edge){y,z});
g[y].push_back((edge){x,z});
}
for (LL i=1;i<=n;i++) sort(g[i].begin(),g[i].end());
mx[rt=0]=INF;
ans=-INF; getroot(1,0,n); solve(rt);
printf("%lld\n",ans);
return 0;
}