AtCoder Beginner Contest 113: D - Number of Amidakuji

問題

問題文

https://atcoder.jp/contests/abc113/tasks/abc113_d

問題概要

あみだくじを作る.

まず W 本の平行な縦線を引き, 次にそれらを繋ぐ横線を引いていく.

それぞれの縦棒の長さは H + 1 であり, 横線の端点となれるのは上から 1, 2, 3 ... H の位置のみ.

また, あみだくじは以下のような条件を満たす必要がある.

  • どの2つの横棒も端点を共有しない.
  • 横棒の 2 つの端点は同じ高さ.
  • 横棒は隣り合う縦線を繋がなければならない.

縦棒 1 の上端からあみだくじを開始したとき, 最終的にたどり着く縦棒の番号が K となるあみだくじの総数を 109+7 で割った余りを求めよ.

  •  1 \leqq H \leqq 100
  •  1 \leqq W \leqq 8
  •  1 \leqq K \leqq W

解答例

指針

  • bitDP

解説

まず全探索ができるかを考える.

H = 100, W = 8 のケースでは, 横のつなげ方が隣り合う縦棒を横棒でつなぐか繋がないかを考えれば良いので 27 通り (実際には連続してつなげてはいけないという制約があるのでもっと小さい値になる. editorial によればその値は 34).

また, 縦についても考えると全部で (27)100 = 2700 となり全てを試すのは到底無理であることが分かる.

 dp(i,j) を高さが i までのあみだくじとみなしたときに左から j 番目の縦棒に到達するあみだくじの総数と定義する.

高さが 0 のあみだくじは横線の端点となる場所がない.

したがって縦棒 0 の上端からあみだくじを開始 (実装のことを考え問題設定とは異なり 0-based index で考えた) するので, 以下のような結果になることは容易に分かる.

{\displaystyle
\begin{eqnarray}
dp(0, x) =
  \left\{
    \begin{array}{l}
        1 \quad (x = 0) \\
        0 \quad (x \neq 0)
    \end{array}
  \right.
\end{eqnarray}
}

次に  dp(i, j) が分かっているときに  dp(i+1, j) がどうなるか考える.

これは高さ(i+1) における, 横棒の状態を全て列挙し, 左から j (0 <= j <= W-1) 番目からたどったとき, どこに遷移するかをシミュレートすることで  dp(i+1, j) を求めることができる.

横棒の状態の列挙に bit を用いることができる, このように bit を用いて状態を表して行う DP を俗に bitDP という.

具体例を考えてみよう.

入力例4 では以下のような入力が与えられる.

2 3 1

H = 2, W = 3, K = 1 となる.

高さ 0 のときはすでに分かっているので, 高さ 1 のときどのようにDPテーブルが更新されるか考える. 実際に紙に書いて実験すると分かるように dp[1][0] = 2, dp[1][1] = 1, dp[1][2] = 0 となる.

以下のようなコードでDPテーブルを更新できる.

高さ i のとき, 左から j 番目の縦棒にいたと仮定し, 全通りシミュレートで求めた遷移先 dp[i+1, x] に dp(i, j) の値を足している.

#include <iostream>

using namespace std;
 
typedef long long ll;
const ll mod = 1e9 + 7;

int main() {
    int H, W, K;
    cin >> H >> W >> K;

    ll dp[102][8] = {0};
    dp[0][0] = 1;

    int i = 0;

    for (int j = 0; j < W; j++) {
        // bit の有無で横棒を表現
        for (int bit = 0; bit < 1<<(W-1); bit++) {
            bool is_valid = true;
            for (int k = 0; k < W - 2; k++) {
                // 連続して bit が立っていたら invalid
                if ( ((bit >> k) & 1) && ((bit >> (k + 1)) & 1) ) {
                    is_valid = false;
                }
            }
            if (!is_valid) { continue; }
            // j 番目の bit が立っていた => 右側に遷移する
            if (bit >> j & 1) {
                dp[i+1][j+1] += dp[i][j];
                dp[i+1][j+1] %= mod;
            }
            // j-1 番目の bit が立っていた => 左側に遷移する
            else if (j > 0 && (bit >> (j - 1)) & 1) {
                dp[i+1][j-1] += dp[i][j];
                dp[i+1][j-1] %= mod;
            }
            // どちらの bit も立っていない => そのまま下に遷移
            else {
                dp[i+1][j] += dp[i][j];
                dp[i+1][j] %= mod;
            }
        }
    }

    for (int j = 0; j < W; j++) {
        cout << dp[1][j] << " ";
    }
    cout << endl;

    return 0;
}
  • 出力結果
$ ./a.out
2 3 1
2 1 0

この操作を繰り返すことで制限時間以内に答えを出力することができる.

時間計算量は O(H \times W \times 2^{W})

  • C++ による実装例
/* submission: https://atcoder.jp/contests/abc113/submissions/3854771
   Language: C++14 (GCC 5.4.1)
   time: 3 ms
   Memory: 256 KB */
#include <iostream>

using namespace std;
 
typedef long long ll;
const ll mod = 1e9 + 7;

int main() {
    int H, W, K;
    cin >> H >> W >> K;

    ll dp[102][8] = {0};
    dp[0][0] = 1;

    for (int i = 0; i < H; i++) {
        for (int j = 0; j < W; j++) {
            // bit の有無で横棒を表現
            for (int bit = 0; bit < 1<<(W-1); bit++) {
                bool is_valid = true;
                for (int k = 0; k < W - 2; k++) {
                    // 連続して bit が立っていたら invalid
                    if ( ((bit >> k) & 1) && ((bit >> (k + 1)) & 1) ) {
                        is_valid = false;
                    }
                }
                if (!is_valid) { continue; }
                // j 番目の bit が立っていた => 右側に遷移する
                if (bit >> j & 1) {
                    dp[i+1][j+1] += dp[i][j];
                    dp[i+1][j+1] %= mod;
                }
                // j-1 番目の bit が立っていた => 左側に遷移する
                else if (j > 0 && (bit >> (j - 1)) & 1) {
                    dp[i+1][j-1] += dp[i][j];
                    dp[i+1][j-1] %= mod;
                }
                // どちらの bit も立っていない => そのまま下に遷移
                else {
                    dp[i+1][j] += dp[i][j];
                    dp[i+1][j] %= mod;
                }
            }
        }
    }
    cout << dp[H][K-1] << endl;
    return 0;
}
  • Python3 による実装例
# submission: https://atcoder.jp/contests/abc113/submissions/3854758
# Language: Python3 (3.4.3)
# time: 208 ms
# Memory: 3064 KB

H, W, K = map(int, input().split())

dp = [[0]*W for _ in range(H+1) ] 
dp[0][0] = 1

mod = int(1e9 + 7)

for i in range(H):
    for j in range(W):
        for bit in range(2**(W-1)):
            is_valid = True
            for k in range(W-2):
                if (bit >> k & 1) and ((bit >> (k + 1)) & 1):
                    is_valid = False
            if not is_valid:
                continue

            if bit >> j & 1:
                dp[i+1][j+1] += dp[i][j]
                dp[i+1][j+1] %= mod
            elif j > 0 and (bit >> (j-1)) & 1:
                dp[i+1][j-1] += dp[i][j]
                dp[i+1][j-1] %= mod
            else:
                dp[i+1][j] += dp[i][j]
                dp[i+1][j] %= mod

print (dp[H][K-1])

参考文献