AtCoder Beginner Contest 107: D - Median of Medians

問題

問題文

https://atcoder.jp/contests/abc107/tasks/arc101_b

問題概要

長さ N の数列 a が与えられる.

区間 [l, r] (1≤l≤r≤N) について, a のすべての連続部分列の中央値を各々求め, その中央値を並べ新たに数列 m を作る.

このとき m の中央値を求めよ.

ただし, ここで中央値は 長さ M の数列を昇順ソートしたとき M/2 + 1 番目の要素の値とする. (ここでの除算は切り捨て)

制約

  •  1 \leqq N \leqq 10^{5}

  •  a_i は整数である。

  •  1 ≤ a_i ≤ 10^{9}

a = {10, 30, 20} のとき,

部分列は {10}, {30}, {20}, {10, 30}, {30, 20}, {10, 30, 20} の 6 つがあり, それぞれの中央値を並べた数列を m とすると,

m = {10, 30, 20, 30, 30, 20} となり m の中央値は 30 である.

よって求める値は 30 となる.

解答例

指針

  • 答えで二分探索

  • 転倒数(のようなもの)に落とし込む

解説

数列 a = {a_1, a_2, ... a_n} の中央値が X以上であるかどうかを判定するには数列の各要素 a_i のうち X以上のものが ceil(n/2) 個以上あるかどうかを見ればいい事がわかる.

数列 a の中央値が X 以上であるかどうかの判定ができるので, 中央値を二分探索により求めることができる.

今回求めるのは数列 a の連続部分列の中央値を並べた数列 m の中央値なので aの連続部分列 a_{l,r} のうち X 以上の要素が ceil( (r-l+1) / 2 ) 個以上となるものが何通りであるかを知る必要がある. 愚直に計算すると間に合わないので高速な方法を考える.

数列 a の要素のうち X 以上のものを +1, X 未満のものを -1 に置き換えた数列を T とすると, T の連続部分列のうち総和が0以上となるものが何通りかを調べればよくなる.

置き換えて得られた数列の累積和をとりその数列を S とする. (初項を 0 をとする)

このとき, 任意の区間 [l, r] で数列 T の和が 0 以上となるのは 数列 S において i < j かつ S[j] - S[i] >= 0 が成り立つような (i, j) の組の個数となる.

これは転倒数とほとんど同じである.

転倒数は  O(n \log n) で求めるアルゴリズムが存在する.

https://kira000.hatenadiary.jp/entry/2019/02/23/053917

  • 転倒数

    • 数列 a において i < j かつ a_i > a_j を満たす (i, j) の組の個数
  • 今回

    • 数列 a において i < j かつ a_j >= a_i を満たす (i, j) の組の個数

このとき求めた転倒数のようなものの値がすべての部分列の要素の個数の和の過半数であればよい. 部分列の要素の個数の和 は n + n-1 + n-2 + ... + 1 = n(n+1)/2 なので, n(n+1)/2/2 より大きいかを調べればよい.

  • C++ による実装例 (分割統治)
// submission: https://atcoder.jp/contests/abc107/submissions/4356101
// Language: C++14 (GCC 5.4.1)
// Time: 860 ms
// Memoly: 3424 KB

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

typedef long long ll;

ll merge_cnt(vector<ll> &a) {
    int n = a.size();
    if (n <= 1) { return 0; }

    ll cnt = 0;
    vector<ll> b(a.begin(), a.begin()+n/2);
    vector<ll> c(a.begin()+n/2, a.end());

    cnt += merge_cnt(b);
    cnt += merge_cnt(c);

    int ai = 0, bi = 0, ci = 0;
    // merge の処理
    while (ai < n) {
        if ( bi < b.size() && (ci == c.size() || b[bi] >= c[ci]) ) {
            a[ai++] = b[bi++];
        } else {
            cnt += n/2 - bi;
            a[ai++] = c[ci++];
        }
    }
    return cnt;
}

bool judge(vector<ll> &a, ll x) {
    ll n = a.size();
    ll sum_v = 0;
    vector<ll> s = {0};

    for (auto &e: a) {
        int t = (e <= x) ? 1 : -1;
        sum_v += t;
        s.push_back( sum_v );
    }

    return merge_cnt(s) > n*(n+1)/2/2;
}

int main() {
    int n; cin >> n;

    vector<ll> a(n);

    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    ll high = 1e18, low = 0;

    while (low < high) {
        ll mid = ( high + low ) / 2;

        if ( judge(a, mid) ) {
            high = mid;
        } else {
            low = mid+1;
        }
    }

    cout << high << endl;

    return 0;
}
  • C++ による実装例 (Fenwick Tree)
// submission: https://atcoder.jp/contests/abc107/submissions/4356196
// Language: C++14 (GCC 5.4.1)
// Time: 178 ms
// Memoly: 1804 KB

#include <iostream>
#include <vector>

using namespace std;

typedef long long ll;

struct fenwick_tree {
    typedef int T;
    T n;
    vector<T> bit;
    fenwick_tree(T num) : bit(num+1, 0) { n = num; }
    void add(T i, T w) {
        for (T x = i; x <= n; x += x & -x) {
            bit[x] += w;
        }
    }
    T sum(T i) {
        T ret = 0;
        for (T x = i; x > 0; x -= x & -x) {
            ret += bit[x];
        }
        return ret;
    }
};

ll a[100005];

bool judge(ll mid, ll n) {
    // offset => 履かせる下駄
    int offset = n+1;
    ll num_inverse = 0;
    fenwick_tree bit(n*2+10);

    int sum_v = 0;
    bit.add(offset, 1);

    for (int i = 0; i < n; i++) {
        int s = (a[i] <= mid) ? 1 : -1;
        sum_v += s;
        num_inverse += (bit.sum(sum_v-1+offset));
        bit.add(sum_v+offset, 1);
    }
    return num_inverse > (n+1)*n/2/2; 
}

int main() {
    ll n; cin >> n;

    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    ll high = 1e18, low = 0; 

    while (low < high) {
        ll mid = ( high + low ) / 2;
        if ( judge(mid, n) ) {
            high = mid;
        } else {
            low = mid+1;
        }
    }

    cout << high << endl;

    return 0;
}

感想

説明を聞けばそれはそうという感じだけど, まだ自力でこのレベルを解ける気がしない.

参考文献