ひらめの日常

プログラミングと読書と

ABC127-E Cell Distance (500)

Cell Distance

問題はこちら atcoder.jp

数式があるので、それを文章にする。

NM 列のマス目のうち、K マスに駒をおく。このコストは全ての駒のペアのx座標の差 + y座標の差の和で計算される。これを全ての配置について和を取りなさい。」

考え方

愚直にやると、以下のようになり到底間に合わない。

  1. 全てのペアについて距離の合計を求める:O((NM)^{2})
  2. 全ての配置について試す: O(_ {NM} \mathrm{C} _ {K})(指数時間)

よって、それぞれについて高速化が必要である。

step1 - 和の計算の分解

 \sum _ {i=1}^{K-1} \sum _ {j=i+1}^{K}\left(\left|x _ {i}-x _ {j}\right|+\left|y _ {i}-y _ {j}\right|\right) について、xyは分解して考えることができる。

 \sum _ {i=1}^{K-1} \sum _ {j=i+1}^{K}\left|x _ {i}-x _ {j}\right| + \sum _ {i=1}^{K-1} \sum _ {j=i+1}^{K}\left|y _ {i}-y _ {j}\right|

よって、 x座標についてのコストの総和を計算し、同様にして y座標のコスト総和を計算し、それぞれを足すことで答えを得ることができる。これからはx座標のコストのみを考えていくことにする。

step2 - 計算式の整理

2点間の距離 dを固定して考えてみる。 すると、「考え方」で述べた二つの過程は次のように言い換えることができる。

  1. 全てのペアについて距離の合計を求める -> 距離 d となるようなペアの個数
  2. 全ての配置について試す -> そのペアが使われるような配置の場合の数
  3. 上記を全ての距離について試す。

step3 - 距離dとなるようなペアの個数

距離 dとなるようなペアの個数は、以下のようにして求めることができる。

まず、距離 dとなるような列の取り方が、 (M - d) 通り。そして、各列においてどの行の座標を使うかの組み合わせが N^{2} 通り。以上より、 (M - d) N^{2} 通り。

step4 - そのペアが使われるような配置の場合の数

 NM 個の座標のうち、そのペア以外の  (NM - 2) 個の候補から、座標を  (K - 2) 個えらぶような場合の数なので、 _ {NM - 2} \mathrm{C} _ {K - 2} 通り。

step 5 - 全ての距離について計算する。

疑似言語で書くとこんな感じになる。

int sumx = 0
for d in 0 to M-1:
  sumx += d * (M - d) * N * N

sumx *= combination(N * M - 2,  K  - 2)

...
同様にしてN, Mを入れ替えてsumyも計算する
...
print(sumx + sumy)

解答

modを取ったり、combinationでmodの逆元をとることなどを忘れないようにする。

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const ll mod = 1000000007;

#define rep(i, n) for(ll i = 0; i < (ll)(n); i++)

/* ------------- ANSWER ------------- */
/* ---------------------------------- */


const int MAX = 510000;
long long fac[MAX], finv[MAX], inv[MAX];

// テーブルを作る前処理
void COMinit() {
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < MAX; i++) {
        fac[i] = fac[i - 1] * i % mod;
        inv[i] = mod - inv[mod % i] * (mod / i) % mod;
        finv[i] = finv[i - 1] * inv[i] % mod;
    }
}

// 二項係数計算
ll COM(int n, int k) {
    if (n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n - k] % mod) % mod;
}

int main() {

    ll n, m, k;
    cin >> n >> m >> k;
    COMinit();

    // sum for x
    ll sum_x = 0;
    rep(i, m) sum_x += i * (m - i) * n * n;
    sum_x %= mod;
    sum_x *= COM(n * m - 2, k - 2);
    sum_x %= mod;

    // sum for y
    ll sum_y = 0;
    rep(i, n) sum_y += i * (n - i) * m * m;
    sum_y %= mod;
    sum_y *= COM(n * m - 2, k - 2);
    sum_y %= mod;

    cout << (sum_y + sum_x) % mod << endl;
    return 0;
}

Submission #6260827 - AtCoder Beginner Contest 127