题目来源:2019湖南省省赛 C题
现在有一个长为n(1e6)、字符集为m(1e6)的字符串S,对于一个字符c定义 h ( c ) = S c 的 不 同 子 串 个 数 − S 的 不 同 子 串 个 数 h(c)=Sc的不同子串个数-S的不同子串个数 h(c)=Sc的不同子串个数−S的不同子串个数,即在尾部添加c后多了几个不同子串,求 Σ c = 1 m h ( c ) \Sigma_{c=1}^mh(c) Σc=1mh(c)。
多组数据, Σ n , Σ m < = 5 e 5 \Sigma n,\Sigma m<=5e5 Σn,Σm<=5e5,空间限制32M,时间限制1s
对S建立后缀自动机,因为字符集较大,使用map来保存转移。
在后缀自动机中, 每个节点对应的不同子串个数等于len(u)-len(link(u))。
枚举每个字符c会添加多少个不同子串,注意到尾部添加一个字符最多会创建两个节点:cur和clone。
而link(clone) = 原有的link(q),且新的link(q) = clone,所以clone节点并不会存储新的字符串。
而cur的link要么是0,要么是某个q,依照插入规则去找即可得到答案。
inline int try_extend(int c) const
{
int p = lst;
while(!ch[p].count(c) && p)
p = link[p];
if(ch[p].count(c))
{
return len[lst] - len[p];
}
else
{
return len[lst] + 1;
}
}
此方法失败,爆空间。
首先求出S的后缀数组,注意其本质不同的子串个数等于 n ∗ ( n + 1 ) / 2 − Σ H e i g h t n * (n+1)/2 - \Sigma Height n∗(n+1)/2−ΣHeight。
尾部新插入一个字符c时,新增加的不同子串个数等于 n + 1 − h e i g h t 的 增 量 n+1 - height的增量 n+1−height的增量。
当S中没有c时,height增量为0;有c时,至少为1,因为新添加了一个"c"后缀。
其余哪些height会增加?
注意height[i]表示sa[i]与sa[i-1]的lcp,
所以当 h e i g h t [ i ] = l e n ( s a [ i − 1 ] ) height[i] = len(sa[i-1]) height[i]=len(sa[i−1])且 S [ s a [ i ] + h e i g h t [ i ] ] = c S[sa[i]+height[i]]=c S[sa[i]+height[i]]=c时,S末尾添加一个字符c,height会加一。
其中len(sa[i-1])表示第sa[i-1]个后缀的长度,值为n-sa[i-1]。
记cnt[c]为字符为c时height的增量,遍历一次height数组即可找到所有答案。
for(int i=2; i<=n; ++i)
if(height[i] == n-sa[i-1])
++cnt[str[sa[i]+height[i]]];
此方法失败,倍增法求SA超时,DC3求SA爆空间。
仔细思考上面两种方法的原理,Sc相比S,新增加了n+1个子串,分别是[1…n+1],[2…n+1],…,[n+1…n+1].
这些子串如果在S中已经出现过,就不能再被计数。注意当[k…n+1]出现过时,[k+1…n+1],[k+2,n+1],…作为它的后缀也一定出现过。
所以如果能够找到一个Sc的最长后缀,使得其在S中出现过,那么 h ( c ) = n + 1 − 后 缀 长 度 h(c)=n+1-后缀长度 h(c)=n+1−后缀长度。
注意这个后缀的结尾一定是c。
现在对于S的每个真前缀s[1…i],求其与S的最长公共后缀lcs,然后就可以更新 h ( s [ i + 1 ] ) = m i n ( h ( s [ i + 1 ] ) , ( n + 1 ) − ( l c s + 1 ) ) h(s[i+1]) = min(h(s[i+1]), (n+1)-(lcs+1)) h(s[i+1])=min(h(s[i+1]),(n+1)−(lcs+1)).
将S翻转后使用Z函数,就可以得到所有的lcs。
/* LittleFall : Hello! */
#include
using namespace std; using ll = long long; inline int read();
const int M = 1000016, MOD = 1000000007;
int z[M]; //z函数,z[i]表示字符串s的第i个后缀和s的lcp,不包括0位置。
void exkmp(int *s)
{
for(int i=1,l=0,r=0; s[i]; ++i)
{
z[i] = i>r ? 0 : min(r-i+1, z[i-l]);
while(s[i+z[i]] && s[z[i]]==s[i+z[i]]) ++z[i];
if(i + z[i] - 1 > r) l=i, r=i+z[i]-1;
}
}
int str[M], h[M];
ll p3[M];
int main(void)
{
#ifdef _LITTLEFALL_
freopen("in.txt","r",stdin);
#endif
p3[0] = 1;
for(int i=1; i<M; ++i)
p3[i] = p3[i-1] * 3 % MOD;
int n=0, m=0;
while(~scanf("%d%d",&n,&m))
{
for(int i=1; i<=m; ++i)
h[i] = n+1;
for(int i=n-1; i>=0; --i)
{
str[i] = read();
h[str[i]] = n;
}
str[n]=0;
exkmp(str);
for(int i=1; i<n; ++i)
h[str[i-1]] = min(h[str[i-1]], n-z[i]);
ll ans = 0;
for(int i=1; i<=m; ++i)
ans ^= h[i]*p3[i] % MOD;
printf("%lld\n",ans );
}
return 0;
}
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
总结: