Educational Codeforces Round 83: D. Count the Arrays

問題

問題文

問題概要

素数 n の数列 a を考える.

  • 各要素は 1 \leq a_i \leq m を満たす.
  • 各数列について, 等しい値をもつ要素が, ちょうど1組存在する.
  • 各数列 a について, 以下を満たす インデックス i が存在する.
    • i番目の前の数列は狭義単調増加, i番目以降の数列は狭義単調減少となる

このような数列は何通りあるか?

制約

  •  2 \leq \ n \leq m \leq 2 \cdot 10^{5}

解答例

指針

  • 数学を頑張る

解説

使う数字の種類は, 重複する要素が1組しかないので, 必ず n-1 種類となる. したがって. m 種類から n-1 種類を選ぶので, 使う数字の選び方は全部で _m \text{C}_{n-1} 通りとなる.

n-1種類の数字のどれかひとつを複製する必要があるが, 複製できるのは (n-1) 通りではない.

なぜなら, 最大値を複製すると3つ目の条件 (狭義単調増加/減少に関する条件),を満たすことはできないからである. よって複製できるのは最大値以外なので,  (n-2) 通り.

使う数字を決めたとき, それらをどのように並べるかを考える.

最大値および複製した値以外の要素について最大値よりも右に置くか, 左に置くかを考えれば良いので, 2^{n-3} 通り.

各要素を最大値よりも右に置くか, 左に置くかを決めると数列は一意に定まる.

したがって, 答えは,  _{m} \text{C}_{n-1} \cdot (n-2) \cdot 2^{n-3}

実装例

#include <iostream>
#include <algorithm>
#include <vector>
 
using namespace std;
 
const int mod = 998244353;
 
long long mod_pow(long long x, long long n, long long mod) {
  long long res = 1;
  while (n > 0) {
    if (n & 1) { res = res * x % mod; }
    x = x * x % mod;
    n >>= 1;
  }
  return res;
}
 
long long mod_inverse(long long x, long long mod) {
  return mod_pow(x, mod-2, mod);
}
 
long long mod_comb(long long n, long long k, long long mod) {
  long long numer = 1, denom = 1;
  for (long long i = 0; i < k; i++) {
    numer = numer * ((n-i) % mod) % mod;
    denom = denom * ((i+1) % mod) % mod;
  }
  return numer * mod_inverse(denom, mod) % mod;
}
 
int main() {
  long long n, m;
  cin >> n >> m;
 
  long long ans = mod_comb(m, n-1, mod);
  ans = (ans * (n-2)) % mod;
  ans = (ans * mod_pow(2, n-3, mod)) % mod;
 
  cout << ans << endl;
  return 0;
}
  • Python3
def mod_inverse(x, mod):
    return pow(x, mod-2, mod)
 
def mod_comb(n, k, mod):
    numer, denom = 1, 1
    for i in range(k):
        numer = numer * ((n-i) % mod) % mod
        denom = denom * ((i+1) % mod) % mod
 
    return numer * mod_inverse(denom, mod) % mod
 
 
n, m = map(int, input().split())
mod = 998244353
ans = mod_comb(m, n-1, mod)
ans = (ans * (n-2)) % mod
ans = (ans * pow(2, n-3, mod)) % mod
 
print(ans)

参考文献