算法介绍

这是一类把集合转化为整数记录在DP状态中的算法。假设 1 代表持有某种状态,那么对于十进制数 (11)D = (1011)B 来说可以表示为持有第 4 2 1 3种状态。

如下是 num 所持有的状态的所有子集(包含空集)。

int sub = num;
do {
	// do sth
    sub = num & (sub - 1);
} while(sub != num)

通过子集预处理,可以减少枚举状态和一系列多余状态的判断

例题讲解

1. 棋盘类问题(基于连通性的DP)

棋盘类问题需要想清楚层与层之间的状态转移关系,通常是以 dp[i][s1] ~ dp[i-1][s2] 表示第 i 层状态 s1 是由 i-1 层状态转移而来,有时会牵扯到三层之间的关系,即 dp[i][s1] ~ dp[i-1][s2] and dp[i-2][s3]
本站相关例题讲解

例题 做法描述 代码
AcWing 1064. 小国王 dp[i][j][k]表示第i行状态为k且前i行共可以放j个国王 🔗
AcWing 327. 玉米田 dp[i][s] 表示第i层压缩状态位s的总种法 🔗
AcWing 292. 炮兵阵地 滚动数组 + dp[i][j][k]表示第i层状态j,i-1层状态为k 🔗

2. 集合类问题

// TODO

AcWing 1064. 小国王

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long LL;

int n, m;

vector<int> stateVec, head[2000];
LL dp[12][110][2000];
int cnt[2000];

bool check(int state) {
    for (int i = 0; i < n; i++) {
        if ((state >> i & 1) && (state >> (i + 1) & 1)) return false;
    }
    return true;
}

int count(int state) {
    int cnt = 0;
    for (int i = 0; i < n; i++) if (state >> i & 1) ++cnt;
    return cnt;
}

int main() {
    cin >> n >> m;

    for (int i = 0; i < (1 << n); i++) if (check(i)) stateVec.push_back(i), cnt[i] = count(i);
    for (int i = 0; i < stateVec.size(); i++) {
        for (int j = 0; j < stateVec.size(); j++) {
            int a = stateVec[i], b = stateVec[j];
            if (!(a & b) && check(a | b)) {
                head[a].push_back(b);
            }
        }
    }
    dp[0][0][0] = 1;
    for (int i = 1; i <= n + 1; i++) {
        for (int j = 0; j <= m; j++) {
            for (int k = 0; k < stateVec.size(); k++) {
                int a = stateVec[k];
                for (int t = 0; t < head[a].size(); t++) {
                    int c = cnt[a];
                    if (j >= c) dp[i][j][a] += dp[i - 1][j - c][head[a][t]];
                }
            }
        }
    }
    cout << dp[n + 1][m][0] << endl;
    return 0;
}

AcWing 327. 玉米田

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 1 << 13 + 10;
const int M = 14, MOD = 1e8;

int n, m, t;

int id[M];

long long dp[M][N];

bool check(int s) {
    for (int i = 0; i < m; i++) {
        if ((s >> i & 1) && (s >> (i + 1) & 1)) return false;
    }
    return true;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < m; j++) {
            cin >> t;
            id[i] = (id[i] << 1) + t;
        }
    }
    dp[0][0] = 1;
    long long ans = 0;
    for (int i = 1; i <= n; i++) {
        int pre = id[i - 1], cur = id[i];
        int s1 = cur;
        do {
            int s2 = pre;
            do {
                if (check(s1) && check(s2) && (s1 & s2) == 0) {
                    dp[i][s1] = (dp[i][s1] + dp[i - 1][s2]) % MOD;
                }
                s2 = (s2 - 1) & pre;
            } while (s2 != pre);
            if (i == n) ans = (ans + dp[i][s1]) % MOD;
            s1 = (s1 - 1) & cur;
        } while (s1 != cur);
    }
    cout << ans << endl;
    return 0;
}

AcWing 292. 炮兵阵地

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;
const int N = 110;
const int M = 1 << 10;
int n, m;

vector<int> id(N);
int dp[2][M][M];

bool check(int s) {
    for (int i = 0; i < m; i++) {
        int t = ((s >> i) & 1) + ((s >> (i + 1)) & 1) + ((s >> (i + 2)) & 1);
        if (t >= 2) return false;
    }
    return true;
}

int count(int s) {
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        if (s >> i & 1) ret++;
    }
    return ret;
}

int main() {
    cin >> n >> m;
    vector<int> head[N];
    head[0].push_back(0);
    head[1].push_back(0);
    for (int i = 2; i <= n + 1; i++) {
        for (int j = 0; j < m; j++) {
            char c;
            cin >> c;
            id[i] = (id[i] << 1) + (c == 'P');
        }
        int s = id[i];
        do {
            if (check(s)) head[i].push_back(s);
            s = (s - 1) & id[i];
        } while (s != id[i]);
    }


    int turn = 1;
    int ans = 0;
    for (int i = 2; i <= n + 1; i++) {
        turn ^= 1;
        for (int j = 0; j < head[i].size(); j++) {
            for (int k = 0; k < head[i - 1].size(); k++) {
                for (int v = 0; v < head[i - 2].size(); v++) {
                    int c = head[i][j], p = head[i - 1][k], pp = head[i - 2][v];
                    if (((c & p) == 0) && ((pp & p) == 0) && ((pp & c) == 0)) {
                        dp[turn][c][p] = max(dp[turn ^ 1][p][pp] + count(c), dp[turn][c][p]);
                        ans = max(ans, dp[turn][c][p]);
                    }
                }
            }
        }
    }
    cout << ans << endl;
    return 0;
}