【矩阵优化DP】Subsequence Count

传送门:HDU6155


思路:

  • 计算某个串的子序列个数可以考虑把所有子序列建成一棵字典树,维护一下向左扩展边数,向右扩展边数,节点数
  • 推出方程:
    • $\huge f_{i,a_i\hat{}1}=f_{i-1,a_i}+f_{i-1,a_i\hat{}1}$
    • $\huge f_{i,a_i}=f_{i-1,a_i}$
    • $\huge f_{i,2}=f_{i-1,2}+f_{i-1,a_i}$
  • 然后推一下矩阵;
  • 由于涉及到翻转,维护每个区间矩阵对应的翻转矩阵(或者可以强行交换矩阵的元素)

注意:

  • 注意常数优化;
  • 矩阵不要推错。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define LL long long
#define N 300000
#define mod 1000000007
LL t,n,m,op,x,y,a[N];
bool tag[N];
struct matrix{
LL s[3][3];
matrix operator * (const matrix &mm) const{
matrix ret={0,0,0,0,0,0,0,0,1};
for (register int i=0;i<=1;i++)
for (register int j=0;j<=2;j++)
for (register int k=0;k<=2;k++) ret.s[i][j]=(ret.s[i][j]+s[i][k]*mm.s[k][j])%mod;
return ret;
}
}ans,seg[N],Rseg[N];
const matrix mtx[3]={{1,1,1,0,1,0,0,0,1},{1,0,0,1,1,1,0,0,1},{1,0,0,0,1,0,0,0,1}};
void update(int cur){seg[cur]=seg[cur<<1]*seg[cur<<1|1]; Rseg[cur]=Rseg[cur<<1]*Rseg[cur<<1|1];}
void pushdown(int cur){
if (tag[cur]){
tag[cur<<1]^=1; swap(seg[cur<<1],Rseg[cur<<1]);
tag[cur<<1|1]^=1; swap(seg[cur<<1|1],Rseg[cur<<1|1]);
tag[cur]=0;
}
}
void build(int cur,int l,int r){
if (l==r){seg[cur]=mtx[a[l]]; Rseg[cur]=mtx[a[l]^1]; return;}
int mid=(l+r)>>1; build(cur<<1,l,mid); build(cur<<1|1,mid+1,r); update(cur);
}
void rev(int cur,int l,int r,int L,int R){
if (l>=L && r<=R){tag[cur]^=1; swap(seg[cur],Rseg[cur]); return;}
pushdown(cur);
int mid=(l+r)>>1; if (L<=mid) rev(cur<<1,l,mid,L,R); if (R>mid) rev(cur<<1|1,mid+1,r,L,R);
update(cur);
}
void mul(int cur,int l,int r,int L,int R){
if (l>=L && r<=R){ans=ans*seg[cur]; return;}
pushdown(cur); int mid=(l+r)>>1;
if (L<=mid) mul(cur<<1,l,mid,L,R); if (R>mid) mul(cur<<1|1,mid+1,r,L,R);
}
int main(){
scanf("%lld",&t);
while (t--){
scanf("%lld%lld",&n,&m); memset(tag,0,sizeof tag);
for (int i=1;i<=n;i++){for (a[i]=getchar();a[i]!='0' && a[i]!='1';a[i]=getchar()); a[i]-='0';}
build(1,1,n);
for (int i=1;i<=m;i++){
scanf("%lld%lld%lld",&op,&x,&y);
if (op==1) rev(1,1,n,x,y);
if (op==2){ans=mtx[2]; mul(1,1,n,x,y); printf("%lld\n",(ans.s[0][2]+ans.s[1][2])%mod);}
}
}
return 0;
}