POJ 3415 Common Substrings(后缀数组)

Description

A substring of a string T is defined as:

 

T( ik)= TiTi +1... Ti+k -1, 1≤ ii+k-1≤| T|.

 

Given two strings AB and one integer K, we define S, a set of triples (ijk):

 

S = {( ijk) |  kKA( ik)= B( jk)}.

 

You are to give the value of |S| for specific AB and K.

Input

The input file contains several blocks of data. For each block, the first line contains one integer K, followed by two lines containing strings A and B, respectively. The input file is ended by K=0.

1 ≤ |A|, |B| ≤ 105
1 ≤ K ≤ min{|A|, |B|}
Characters of A and B are all Latin letters.

Output

For each case, output an integer |S|.

 

题目大意:给两个字符串,问有多少个长度大于等于K的公共子串。

思路:首先,把两个字符串用一个未出现过的字符(如'$')连起来,求后缀数组和height[]数组。

用每个后缀的所有前缀代表一个字符串的所有子串。

然后,按height[]的顺序从前往后扫描。

遇到第一个字符串的,就压入栈中。遇到第二个字符串的,就计算栈中与第二个字符串的长度大于等于K的公共前缀。

对于栈中每一个height[],它与当前第二个字符串的长度大于等于K的公共前缀一共有height[]-k+1个。

sum{height[]-k+1}可以在压栈的同时统计。

用一个单调栈维护,让每个height[]只入栈和出栈一次。

最后rank小的第一个字符串和rank大的第二个字符串的长度大于等于K的公共前缀就统计出来了,统计复杂度为O(n)。

此时两个字符串反过来再做一遍即可。

 

代码(1469MS):

  1 #include <cstdio>

  2 #include <iostream>

  3 #include <cstring>

  4 #include <algorithm>

  5 #include <stack>

  6 using namespace std;

  7 typedef long long LL;

  8 

  9 const int MAXN = 200010;

 10 

 11 char s[MAXN];

 12 int sa[MAXN], rank[MAXN], height[MAXN], c[MAXN], tmp[MAXN];

 13 int n, apart, k;

 14 

 15 void makesa(int m) {

 16     memset(c, 0, m * sizeof(int));

 17     for(int i = 0; i < n; ++i) ++c[rank[i] = s[i]];

 18     for(int i = 1; i < m; ++i) c[i] += c[i - 1];

 19     for(int i = 0; i < n; ++i) sa[--c[rank[i]]] = i;

 20     for(int k = 1; k < n; k <<= 1) {

 21         for(int i = 0; i < n; ++i) {

 22             int j = sa[i] - k;

 23             if(j < 0) j += n;

 24             tmp[c[rank[j]]++] = j;

 25         }

 26         int j = c[0] = sa[tmp[0]] = 0;

 27         for(int i = 1; i < n; ++i) {

 28             if(rank[tmp[i]] != rank[tmp[i - 1]] || rank[tmp[i] + k] != rank[tmp[i - 1] + k])

 29                 c[++j] = i;

 30             sa[tmp[i]] = j;

 31         }

 32         memcpy(rank, sa, n * sizeof(int));

 33         memcpy(sa, tmp, n * sizeof(int));

 34     }

 35 }

 36 

 37 void calheight() {

 38     for(int i = 0, k = 0; i < n; height[rank[i++]] = k) {

 39         k -= (k > 0);

 40         int j = sa[rank[i] - 1];

 41         while(s[i + k] == s[j + k]) ++k;

 42     }

 43 }

 44 

 45 struct Node {

 46     int height, cnt;

 47     Node(int height = 0, int cnt = 0): height(height), cnt(cnt) {}

 48 };

 49 

 50 LL solve() {

 51     LL ans = 0, sum = 0;

 52     stack<Node> stk;

 53 

 54     for(int i = 1; i < n; ++i) {

 55         int cnt = 0;

 56         while(!stk.empty() && stk.top().height >= height[i]) {

 57             Node t = stk.top(); stk.pop();

 58             cnt += t.cnt;

 59             sum -= t.cnt * (t.height - k + 1LL);

 60         }

 61         if(height[i] >= k) {

 62             cnt += (sa[i - 1] < apart);

 63             if(cnt) stk.push(Node(height[i], cnt));

 64             sum += cnt * (height[i] - k + 1LL);

 65         }

 66         if(sa[i] > apart) ans += sum;

 67     }

 68 

 69     while(!stk.empty()) stk.pop();

 70     sum = 0;

 71 

 72     for(int i = 1; i < n; ++i) {

 73         int cnt = 0;

 74         while(!stk.empty() && stk.top().height >= height[i]) {

 75             Node t = stk.top(); stk.pop();

 76             cnt += t.cnt;

 77             sum -= t.cnt * (t.height - k + 1LL);

 78         }

 79         if(height[i] >= k) {

 80             cnt += (sa[i - 1] > apart);

 81             stk.push(Node(height[i], cnt));

 82             sum += cnt * (height[i] - k + 1LL);

 83         }

 84         if(sa[i] < apart) ans += sum;

 85     }

 86 

 87     return ans;

 88 }

 89 

 90 int main() {

 91     while(scanf("%d", &k) != EOF && k) {

 92         scanf("%s", s);

 93         apart = strlen(s);

 94         s[apart] = '$';

 95         scanf("%s", s + apart + 1);

 96         n = strlen(s) + 1;

 97         makesa(128);

 98         calheight();

 99         cout<<solve()<<endl;

100     }

101 }
View Code

 

你可能感兴趣的:(substring)