AtCoder Beginner Contest 098: D - Xor Sum 2

問題

問題文

https://atcoder.jp/contests/abc098/tasks/arc098_b

問題概要

長さ N の整数列 A が与えられる.

次の条件を満たす整数 l, r (1 <= l <= r <= N) の組の個数を求めよ.

 A_l xor A_{l+1} xor ... xor A_r = A_l + A_{l+1} + ... + A_r

制約

  •  1 \leqq N \leqq 2 \times 10^{5}

  •  0 \leqq A_i \leqq 2^{20}

  • 入力は整数

解答例

指針

  • 尺取法

解説

条件を満たすのはどのようなときかを考える.

要素が1つだけのときは当然条件を満たす.

要素が2つのときはどうなるだろうか.

A = 4, B = 9 のとき

A = 0100(2)
B = 1001(2)

A + B = 13, A xor B = 1101(2) = 13 となり条件を満たす.

A = 13, B = 5 のとき

A = 1101(2)
B = 0101(2)

A + B = 18, A xor B = 1000(2) = 8 となり条件を満たさない.

条件を満たすときと満たさないときの違いはどこから生まれるのかを考える.

結論から言うと部分列の各要素を2進数表記したときの各桁について 1 の出現回数の合計が 1 以下のとき条件を満たす.

これは2進数同士の足し算を行った際, 繰り上がりが起こらないとき条件を満たすと言い換えることができる.

2進数の各桁ごと見ると,

0 + 0  = 0 xor 0 = 0
0 + 1  = 0 xor 1 = 1
1 + 0  = 1 xor 0 = 1
1 + 1 != 1 xor 1

であり, 1 がたかだか1回しか出現しないとき加算と排他的論理和は演算結果が一致する. 1 が2回以上出現すると, 加算では繰り上がりが起こるが XOR では必ず結果は 1 か 0 であり, 加算で得られる結果より小さくなる.

もしも部分列 T が条件を満たさないとき, T を含む部分列T'も条件を満たさない. したがって, 尺取法の要領でこの問題を O(N) で解くことができる.

  • C++ による実装例
// submission: https://atcoder.jp/contests/abc098/submissions/4748189 
// Language: C++14 (GCC 5.4.1)
// Time: 80 ms
// Memoly: 1024 KB

#include <iostream>

using namespace std;

typedef long long ll;

int main() {
    int n; cin >> n;
    int a[200005] = {0};

    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    ll ans = 0;
    ll s = 0, x = 0, r = 0;

    for (int l = 0; l < n; l++) {
        while ( r < n && ( (s+a[r]) == (x^a[r]) ) ) {
            s += a[r];
            x ^= a[r];
            r++;
        }
        ans += (r-l);

        if (l == r) { r++; }
        else {
            s -= a[l];
            x ^= a[l];
        }
    }
    cout << ans << endl;

    return 0;
}
  • Python3 による実装例
# submission: https://atcoder.jp/contests/abc098/submissions/4747917
# Language: Python3 (3.4.3)
# Time: 362 ms
# Memoly: 23560 KB

n = int(input())
a = [ int(x) for x in input().split() ]

ans = 0
s = 0 # 区間和
x = 0 # 区間排他的論理和?

r = 0
for l in range(n):
    while r < n and s+a[r] == x^a[r]:
        s += a[r]
        x ^= a[r]
        r += 1

    ans += (r-l) # 条件を満たす部分列の長さだけ ans を加算

    if l == r:
        r += 1

    else:
        s -= a[l]
        x ^= a[l]

print(ans)

参考文献