#include<bits/stdc++.h> using namespace std; #define LL long long #define N 800000 #define INF 0x3f3f3f3f3f3f3f3fLL LL n,m,x,y,z,p1,p2,ans,rt,sz[N],mx[N],c[N]; bool vis[N]; struct edge{ LL v,c; bool operator < (const edge p) const {return c<p.c;} }; vector<edge> g[N]; struct Segment_Tree{ LL tot,rt,s[N],ls[N],rs[N]; inline LL max(LL x,LL y){return x>y?x:y;} Segment_Tree(){rt=tot=0; s[0]=-INF;} void clear(){rt=tot=0; s[0]=-INF;} void mdf(LL &x,LL l,LL r,LL xb,LL k){ if (x==0){x=++tot; s[x]=-INF; ls[x]=rs[x]=0;} if (l==r){s[x]=max(s[x],k); return;} LL mid=(l+r)>>1; if (xb<=mid) mdf(ls[x],l,mid,xb,k); else mdf(rs[x],mid+1,r,xb,k); s[x]=max(s[ls[x]],s[rs[x]]); } void add(LL x,LL y){mdf(rt,0,n,x,y);} LL qry(LL x,LL l,LL r,LL L,LL R){ if (x==0) return -INF; if (l>=L && r<=R) return s[x]; LL mid=(l+r)>>1,tmp=-INF; if (L<=mid) tmp=max(tmp,qry(ls[x],l,mid,L,R)); if (R>mid) tmp=max(tmp,qry(rs[x],mid+1,r,L,R)); return tmp; } LL ask(LL l,LL r){ l=max(l,0); if (l>r) return -INF; return qry(rt,0,n,l,r); } }s1,s2; LL getLL(){ char ch; LL sum=0,fh=1; for (ch=getchar();ch<'0' || ch>'9';ch=getchar()) fh=ch=='-'?-1:1; for (;ch>='0' && ch<='9';ch=getchar()) sum=sum*10+ch-'0'; return sum*fh; } void merge(LL &x1,LL x2){ if (x2==0) return; if (x1==0){x1=++s1.tot; s1.s[x1]=-INF; s1.ls[x1]=s1.rs[x1]=0;} s1.s[x1]=max(s1.s[x1],s2.s[x2]); merge(s1.ls[x1],s2.ls[x2]); merge(s1.rs[x1],s2.rs[x2]); } void getroot(LL u,LL fa,LL sum){ sz[u]=1; mx[u]=0; for (LL i=0,v;i<(LL)g[u].size();i++) if (!vis[v=g[u][i].v] && v!=fa){ getroot(v,u,sum); sz[u]+=sz[v]; mx[u]=max(mx[u],sz[v]); } mx[u]=max(mx[u],sum-sz[u]); if (mx[u]<mx[rt]) rt=u; } void getans(LL u,LL col,LL col2,LL d,LL fa,LL val){ if (d>p2) return; ans=max(ans,max(s1.ask(p1-d,p2-d),s2.ask(p1-d,p2-d)-c[col2])+val); for (LL i=0,v;i<(LL)g[u].size();i++) if (!vis[v=g[u][i].v] && v!=fa) getans(v,g[u][i].c,col2,d+1,u,val+(g[u][i].c==col?0:c[g[u][i].c])); } void update(LL u,LL col,LL d,LL fa,LL val){ if (d>p2) return; s2.add(d,val); for (LL i=0,v;i<(LL)g[u].size();i++) if (!vis[v=g[u][i].v] && v!=fa) update(v,g[u][i].c,d+1,u,val+(g[u][i].c==col?0:c[g[u][i].c])); } void calc(LL u){ LL lstc=0; s1.clear(); s2.clear(); s1.add(0,0); for (LL i=0,v;i<(LL)g[u].size();i++) if (!vis[v=g[u][i].v]){ if (g[u][i].c!=lstc){merge(s1.rt,s2.rt); s2.clear();} getans(v,g[u][i].c,g[u][i].c,1,u,c[g[u][i].c]); update(v,g[u][i].c,1,u,c[g[u][i].c]); lstc=g[u][i].c; } } void solve(LL u){ vis[u]=1; calc(u); for (LL i=0,v;i<(LL)g[u].size();i++) if (!vis[v=g[u][i].v]){ rt=0; getroot(v,0,sz[v]); solve(rt); } } int main(){ n=getLL(); m=getLL(); p1=getLL(); p2=getLL(); for (LL i=1;i<=m;i++) c[i]=getLL(); for (LL i=1;i<n;i++){ x=getLL(); y=getLL(); z=getLL(); g[x].push_back((edge){y,z}); g[y].push_back((edge){x,z}); } for (LL i=1;i<=n;i++) sort(g[i].begin(),g[i].end()); mx[rt=0]=INF; ans=-INF; getroot(1,0,n); solve(rt); printf("%lld\n",ans); return 0; }
|