POJ 1795 DNA Laboratory 状态压缩DP(旅行商问题)

一、题目大意

我们有N个字符串,每个长度介于1到100,现要求构建一个组合串,使得所有字符串都为组合串的子串,找到长度最小的组合串,如果有多种可能,输出字典序排序最小的组合串。

二、解题思路

我们来回忆下状压DP解决旅行商问题,DP[S][v]代表已经走过的点为S,并从v开始走完剩余节点的最小距离。

其实你仔细思考,发现过滤掉那些 互为子串的字符串,之后剪掉首尾相接的公共部分,其实最终的组合串其实就是这些字符串互相拼接,而本题目其实就是要找出最佳的拼接顺序,就和旅行商问题异曲同工。

那么本题差不多,DP[S][v]代表已经包含字符串的集合为S,且结尾的字符串为 v时,去连接剩余字符串的最小长度(同时也记录连接的下一个字符串)。

首先预处理:1、把互为子串的字符串只留下长的那个。2、对于多个字符串相等只留下其中一个。

然后我们可知 dp[全集][ 0.. n] 都为0。

之后依次循环大小为 n-1,n-2...1的集合,执行如下递推式

1、如果 dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度  < dp[S][v]

则 dp[S][v] =  dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度

2、如果 dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度  == dp[S][v]

则对比较两种情况的字符串,如果以u开头字典序更小,dp[S][v] =  dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度。

我们可以在dp数组保存两个变量,下一次连接的字符串和拼接剩余所有字符串的长度。

这样比较字符字典序的时候,一开始设置 nxt = u,used = S,然后使 used = used | 1<< u,找到 dp[used][nxt][0]就是下一个字符串。下一个字符串的有效起始下标就是编号为 nxt 的字符串的尾 和 编号为 dp[used][nxt][0]字符串的首的公共长度。之后更新 nxt = dp[used][nxt][0]即可。

这样就可以通过某个位置,计算出全部的链路,进行字典序比较。

DP计算结束后,需要记录答案的起始字符串下标 ans,和答案长度,然后从0到n定义变量 i 循环。

如果 dp[1 << i] +len[i] < ansLen

则 ansLen=dp[1 << i] +len[i], ans = i;

如果 dp[1 << i] +len[i] == ansLen 且 i 编号的字符串的字典序 比 ans编号的字符串小

则 ansLen=dp[1 << i] +len[i], ans = i;

(过滤掉了互为子串的情况,则只比较开头的字符串即可)

我们定义数组 merg[i][j]代表字符串 i 的尾 和 字符串 j 的首的公共部分长度。

然后输出答案的时候,可以定义 nxt = ans,定义used = 0,然后定义下标 i,执行 ansLen次循环,每次循环输出一个字符。

for k in [0,ansLen) {

if i >= len[nxt] {

 used = used | 1<< nxt, i = merg[ nxt ][ dp[ used ][ nxt ][ 0 ] ]; nxt = dp[ used ][ nxt ][ 0 ];

}

输出 nxt下标的字符串的第 i 个字符,不要换行。

i++

}

最后需要输出两个 \n\n,不然 presentation error。

备注:本题目我挂了很多次,后来不使用 %s输入字符,不使用 strcmp 和 strlen,而是使用 %c输入,自己计数,最终过了,不清楚是不是因为不能用 %s、strcmp 和 strlen。

三、代码

#include 
using namespace std;
const int MAX_N = 15, MAX_LEN = 110, INF = 0x3f3f3f3f;
char tmpStr[MAX_N][MAX_LEN], dat[MAX_N][MAX_LEN];
int len[MAX_N], tmp[MAX_N], dp[1 << MAX_N][MAX_N][2], merg[MAX_N][MAX_N], n, num, ans, ansLen, all;
bool need[MAX_N];
void putAns()
{
    cout << "Scenario #" << num << ":" << endl;
    int used = 0;
    int nxt = ans;
    int j = 0;
    for (int k = 0; k < ansLen; k++)
    {
        if (j >= len[nxt])
        {
            used = used | 1 << nxt;
            j = merg[nxt][dp[used][nxt][0]];
            nxt = dp[used][nxt][0];
        }
        printf("%c", dat[nxt][j]);
        j++;
    }
    printf("\n");
    printf("\n");
}
bool compareStr(int prv, int a, int b, int used, int _len)
{
    int i = merg[prv][a], j = merg[prv][b], nxt1 = a, nxt2 = b, used1 = used, used2 = used;
    char c1, c2;
    for (int k = 0; k < _len; k++)
    {
        if (i >= len[nxt1])
        {
            used1 = used1 | 1 << nxt1;
            i = merg[nxt1][dp[used1][nxt1][0]];
            nxt1 = dp[used1][nxt1][0];
        }
        if (j >= len[nxt2])
        {
            used2 = used2 | 1 << nxt2;
            j = merg[nxt2][dp[used2][nxt2][0]];
            nxt2 = dp[used2][nxt2][0];
        }
        if (dat[nxt1][i] != dat[nxt2][j])
        {
            return dat[nxt1][i] < dat[nxt2][j];
        }
        i++;
        j++;
    }
    return false;
}
void handleStr(int used, int v, int u)
{
    if (dp[used | 1 << u][u][1] + len[u] - merg[v][u] > dp[used][v][1])
    {
        return;
    }
    if (dp[used | 1 << u][u][1] + len[u] - merg[v][u] == dp[used][v][1] && !compareStr(v, u, dp[used][v][0], used, dp[used][v][1]))
    {
        return;
    }
    dp[used][v][0] = u;
    dp[used][v][1] = dp[used | 1 << u][u][1] + len[u] - merg[v][u];
}
void handle(int size)
{
    int used = 0;
    for (int i = 0; i < size; i++)
    {
        used = used | 1 << i;
    }
    while (used < all)
    {
        for (int v = 0; v < n; v++)
        {
            for (int u = 0; u < n; u++)
            {
                if ((used >> v & 1) && !(used >> u & 1))
                {
                    handleStr(used, v, u);
                }
            }
        }
        int x = used & -used;
        int y = used & ~(used + x);
        used = used + x + (y / x / 2);
    }
}
bool compareAns(int i)
{
    for (int k = 0; k < min(len[i], len[ans]); k++)
    {
        if (dat[i][k] != dat[ans][k])
        {
            return dat[i][k] < dat[ans][k];
        }
    }
    return false;
}
void doDp()
{
    for (int i = 0; i < (1 << MAX_N); i++)
    {
        for (int j = 0; j < MAX_N; j++)
        {
            dp[i][j][0] = INF;
            dp[i][j][1] = INF;
        }
    }
    all = (1 << n) - 1;
    for (int i = 0; i < n; i++)
    {
        dp[all][i][0] = 0;
        dp[all][i][1] = 0;
    }
    for (int i = n - 1; i > 0; i--)
    {
        handle(i);
    }
    ansLen = INF;
    for (int i = 0; i < n; i++)
    {
        if (dp[1 << i][i][1] + len[i] < ansLen)
        {
            ans = i;
            ansLen = dp[1 << i][i][1] + len[i];
        }
        else if (dp[1 << i][i][1] + len[i] == ansLen && compareAns(i))
        {
            ans = i;
            ansLen = dp[1 << i][i][1] + len[i];
        }
    }
}
void mergeStr()
{
    for (int i = 0; i < MAX_N; i++)
    {
        for (int j = 0; j < MAX_N; j++)
        {
            merg[i][j] = 0;
        }
    }
    for (int v = 0; v < n; v++)
    {
        for (int u = 0; u < n; u++)
        {
            for (int st = len[u] - 1; st >= 0 && len[u] - st <= len[v]; st--)
            {
                for (int k = 0; st + k < len[u]; k++)
                {
                    if (dat[v][k] != dat[u][st + k])
                    {
                        break;
                    }
                    if (st + k + 1 == len[u])
                    {
                        merg[u][v] = max(merg[u][v], len[u] - st);
                    }
                }
            }
        }
    }
}
void filterInclude()
{
    for (int v = 0; v < n; v++)
    {
        tmp[v] = len[v];
        for (int k = 0; k < len[v]; k++)
        {
            tmpStr[v][k] = dat[v][k];
        }
    }
    int v = 0;
    for (int i = 0; i < n; i++)
    {
        if (!need[i])
        {
            continue;
        }
        for (int k = 0; k < tmp[i]; k++)
        {
            dat[v][k] = tmpStr[i][k];
        }
        len[v] = tmp[i];
        v++;
    }
    n = v;
}
void findInclude()
{
    fill(need, need + MAX_N, true);
    for (int v = 0; v < n; v++)
    {
        for (int u = 0; u < n; u++)
        {
            if (!need[v] || !need[u] || v == u || len[v] > len[u])
            {
                continue;
            }
            for (int st = 0; st < len[u] && st + len[v] <= len[u]; st++)
            {
                for (int k = 0; k < len[v]; k++)
                {
                    if (dat[v][k] != dat[u][st + k])
                    {
                        break;
                    }
                    if (k + 1 == len[v])
                    {
                        need[v] = false;
                    }
                }
            }
        }
    }
}
void input()
{
    char c;
    scanf("%d\n", &n);
    for (int i = 0; i < n; i++)
    {
        len[i] = 0;
        while (true)
        {
            scanf("%c", &c);
            if (c == '\n')
            {
                break;
            }
            else
            {
                dat[i][len[i]] = c;
                len[i] = len[i] + 1;
            }
        }
    }
}
int main()
{
    int T = 0;
    scanf("%d", &T);
    for (num = 1; num <= T; num++)
    {
        input();
        findInclude();
        filterInclude();
        mergeStr();
        doDp();
        putAns();
    }
    return 0;
}

你可能感兴趣的:(动态规划,算法)