题解 P3311 【[SDOI2014]数数】AC自动机-dp
wzj423
2018-07-17 15:30:24
看到这道题,很容易想到以下AC自动机上的dp方程:
```cpp
/**
f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数
f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数
f[2][i][j]表示该位以及之前都是前导零,
因而有:
for(int nx=0;nx<10;++nx)
f[0][i][j]->f[0][c[i][nx]][j+1]
for(int nx=0;nx<N[j];++nx)
f[1][i][j]->f[0][c[i][nx]][j+1]
f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1]
*/
```
但这样只能得到80分,会Wa两个点
这是为什么?我们可以看一看数据:
```cpp
0821463255
0383320100380
001153611528
066512243919
0321614037728
0835268196243
(4.in)
```
这时835268196243XXXXX在AC自动机上的路径为0-8-3-5...
这本来是一组可行解,但因为错误地考虑了前导零而排除了
因而我们需要重新设计方程:
```cpp
/**
f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数
f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数
f[2][i][j]表示该位以及之前都是前导零, AC自动机节点为i,匹配到了第j位的方案数
则有转移:
for(int nx=0;nx<10;++nx)
f[0][i][j]->f[0][c[i][nx]][j+1]
for(int nx=0;nx<N[j];++nx)
f[1][i][j]->f[0][c[i][nx]][j+1]
f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1]
f[1][0][0]->f[2][0][1] nx==0
f[1][0][0]->f[0][c[0][nx]][1] nx!=0
f[2][0][j]->f[2][0][j+1] nx==0
f[2][0][j]->f[0][c[0][i][j+1]
*/
```
这样就能AC了
80pts Ver:
```cpp
#include <bits/stdc++.h>
/*80*/
using namespace std;
//defs==============================
const int mod=1e9+7,MAXP=1700,MAXN=1300;
int M;
char N[MAXN],p[MAXN];int Nlen;
long long ans;
bool tot0;
//AC==========================================
int c[MAXP][10],val[MAXP],fail[MAXP],cnt,_rt=0;
bool is_tot0(char *s) {
int len=strlen(s);
for(int i=0;i<len;++i) if(s[i]!='0') return false;
return true;
}
void ins(char *s) {
int len=strlen(s),now=_rt;
for(int i=0;i<len;++i) {
int v=s[i]-'0';
if(!c[now][v]) c[now][v]=++cnt;
now=c[now][v];
}
++val[now];
}
void build_fail() {
queue<int> q;
for(int i=0;i<10;++i) if(c[0][i]) q.push(c[0][i]);
while(!q.empty()) {
int u=q.front();q.pop();
for(int i=0;i<10;++i)
if(c[u][i]) fail[c[u][i]]=c[fail[u]][i],val[c[u][i]]+=val[fail[c[u][i]]],q.push(c[u][i]);
else c[u][i]=c[fail[u]][i];
}
}
//dp======================
int f[2][MAXP][MAXN];
/*
f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数
f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数
f[2][i][j]表示该位以及之前都是前导零,
for(int nx=0;nx<10;++nx)
f[0][i][j]->f[0][c[i][nx]][j+1]
for(int nx=0;nx<N[j];++nx)
f[1][i][j]->f[0][c[i][nx]][j+1]
f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1]
f[1][0][0]->f[2][0][1] nx==0
f[1][0][0]->f[0][c[0][nx]][1] nx!=0
f[2][0][j]->f[2][0][j+1] nx==0
f[2][0][j]->f[0][c[0][i][j+1]
*/
int vis[2][MAXP][MAXN];
void pp(int i,int j,int k) {
// printf("f[%d][%d][%d]=%d\n",i,j,k,f[i][j][k]);
}
//main=================================
int main() {
scanf("%s",N+1);Nlen=strlen(N+1);
scanf("%d",&M);
for(int i=1;i<=M;++i) scanf("%s",p),ins(p),tot0|=is_tot0(p);
build_fail();
//for(int i=0;i<=cnt;++i)
// printf("id=%d 0=%d 1=%d 2=%d 3=%d 4=%d val=%d fail=%d\n",i,c[i][0],c[i][1],c[i][2],c[i][3],c[i][4],val[i],fail[i]);
vis[1][0][0]=f[1][0][0]=1;
for(int j=0;j<Nlen;++j){
for(int i=0;i<=cnt;++i) {
for(int limit=0;limit<=1;++limit) {
if(!vis[limit][i][j]) continue;
// printf("from (%d,%d,%d)\n",limit,i,j);
if(limit==0) {
//puts("limit=0");
for(int nx=0;nx<10;++nx)
if(!val[c[i][nx]])
(f[0][c[i][nx]][j+1]+=f[0][i][j])%=mod,vis[0][c[i][nx]][j+1]=true,pp(0,c[i][nx],j+1);
} else {
//puts("limit=1");
for(int nx=0;nx<N[j+1]-'0';++nx)
if( /*!(j==0&&nx==0) &&*/ !val[c[i][nx]])
(f[0][c[i][nx]][j+1]+=f[1][i][j])%=mod,vis[0][c[i][nx]][j+1]=true,pp(0,c[i][nx],j+1);
if(!val[c[i][ N[j+1]-'0' ]]) {
(f[1][c[i][ N[j+1]-'0' ]][j+1]+=f[1][i][j])%=mod;
vis[1][ c[i][ N[j+1]-'0' ] ][j+1]=true;
//puts("special");
pp(1,c[i][N[j+1]-'0'],j+1);
}
}
}
}
}
for(int i=0;i<=cnt;++i) {
(ans+=f[0][i][Nlen])%=mod;
(ans+=f[1][i][Nlen])%=mod;
}
if(tot0)
printf("%lld\n",ans);
else
printf("%lld\n",ans-1);
return 0;
}
```
100pts Ver
```cpp
#include <bits/stdc++.h>
using namespace std;
//defs==============================
const int mod=1e9+7,MAXP=1700,MAXN=1300;
int M;
char N[MAXN],p[MAXN];int Nlen;
long long ans;
bool tot0;
//AC==========================================
int c[MAXP][10],val[MAXP],fail[MAXP],cnt,_rt=0;
bool is_tot0(char *s) {
int len=strlen(s);
for(int i=0;i<len;++i) if(s[i]!='0') return false;
return true;
}
void ins(char *s) {
int len=strlen(s),now=_rt;
for(int i=0;i<len;++i) {
int v=s[i]-'0';
if(!c[now][v]) c[now][v]=++cnt;
now=c[now][v];
}
++val[now];
}
void build_fail() {
queue<int> q;
for(int i=0;i<10;++i) if(c[0][i]) q.push(c[0][i]);
while(!q.empty()) {
int u=q.front();q.pop();
for(int i=0;i<10;++i)
if(c[u][i]) fail[c[u][i]]=c[fail[u]][i],val[c[u][i]]+=val[fail[c[u][i]]],q.push(c[u][i]);
else c[u][i]=c[fail[u]][i];
}
}
//dp======================
int f[3][MAXP][MAXN];
/*
f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数
f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数
f[2][i][j]表示该位以及之前都是前导零,
for(int nx=0;nx<10;++nx)
f[0][i][j]->f[0][c[i][nx]][j+1]
for(int nx=0;nx<N[j];++nx)
f[1][i][j]->f[0][c[i][nx]][j+1]
f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1]
f[1][0][0]->f[2][0][1] nx==0
f[1][0][0]->f[0][c[0][nx]][1] nx!=0
f[2][0][j]->f[2][0][j+1] nx==0
f[2][0][j]->f[0][c[0][i]][j+1] nx!=0
*/
int vis[3][MAXP][MAXN];
void pp(int i,int j,int k) {
//printf("f[%d][%d][%d]=%d\n",i,j,k,f[i][j][k]);
}
inline void add(int &a,int b) {
(a+=b)%=mod;
}
//main=================================
int main() {
scanf("%s",N+1);Nlen=strlen(N+1);
scanf("%d",&M);
for(int i=1;i<=M;++i) scanf("%s",p),ins(p),tot0|=is_tot0(p);
build_fail();
//for(int i=0;i<=cnt;++i)
// printf("id=%d 0=%d 1=%d 2=%d 3=%d 4=%d val=%d fail=%d\n",i,c[i][0],c[i][1],c[i][2],c[i][3],c[i][4],val[i],fail[i]);
vis[1][0][0]=f[1][0][0]=1;
for(int j=0;j<Nlen;++j){
for(int i=0;i<=cnt;++i) {
for(int limit=0;limit<=2;++limit) {
if(!vis[limit][i][j]) continue;
if(limit==0) {
for(int nx=0;nx<10;++nx)
if(!val[c[i][nx]])
add(f[0][c[i][nx]][j+1],f[0][i][j]),vis[0][c[i][nx]][j+1]=true,pp(0,c[i][nx],j+1);
} else if(limit==1){
for(int nx=(j==0);nx<N[j+1]-'0';++nx)
if(!val[c[i][nx]]) {
add(f[0][c[i][nx]][j+1],f[1][i][j]);
vis[0][c[i][nx]][j+1]=true;
pp(0,c[i][nx],j+1);
}
if(j==0){
add(f[2][0][1],f[1][0][0]);
vis[2][0][1]=true;
pp(2,0,1);
}
if(!val[c[i][ N[j+1]-'0' ]]) {
add(f[1][ c[i][ N[j+1]-'0' ]][j+1],f[1][i][j]);
vis[1][ c[i][ N[j+1]-'0' ] ][j+1]=true;
pp(1,c[i][N[j+1]-'0'],j+1);
}
} else if(limit==2) {
add(f[2][0][j+1],f[2][0][j]);
vis[2][0][j+1]=1;
pp(2,0,j+1);
for(int nx=1;nx<10;++nx)
if(!val[c[i][nx]]) {
add(f[0][c[i][nx]][j+1],f[2][0][j]);
vis[0][c[i][nx]][j+1]=true;
pp(0,c[i][nx],j+1);
}
}
}
}
}
for(int i=0;i<=cnt;++i) {
(ans+=f[0][i][Nlen])%=mod;
(ans+=f[1][i][Nlen])%=mod;
}
printf("%lld\n",ans);
return 0;
}
```