#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; #define N 6000000 #define LL long long int n,m,x,y,p,np,q,nq,tot,cnt,d[N],head[N],fa[N],c[N][12],col[N],stt[N]; LL ans; struct edge{int v,nxt;}e[N]; void add(int x,int y){e[++tot].v=y; e[tot].nxt=head[x]; head[x]=tot;} int extend(int p,int w){ np=++cnt; stt[np]=stt[p]+1; for (;p && !c[p][w];p=fa[p]) c[p][w]=np; if (!p) fa[np]=1; else{ if (stt[q=c[p][w]]==stt[p]+1) fa[np]=q; else{ stt[nq=++cnt]=stt[p]+1; memcpy(c[nq],c[q],sizeof c[q]); fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;p && c[p][w]==q;p=fa[p]) c[p][w]=nq; } } return np; } void dfs(int x,int fa,int p){ int t=extend(p,col[x]); for (int i=head[x];i;i=e[i].nxt) if (e[i].v!=fa) dfs(e[i].v,x,t); } int main(){ scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&col[i]); for (int i=1;i<n;i++){scanf("%d%d",&x,&y); add(x,y); add(y,x); d[x]++; d[y]++;} cnt=1; for (int i=1;i<=n;i++) if (d[i]==1) dfs(i,0,1); for (int i=1;i<=cnt;i++) ans+=stt[i]-stt[fa[i]]; printf("%lld\n",ans); return 0; }
|