ひらめの日常

日常のメモをつらつらと

AtCoder: ABC123-D Cake 123 (400)

問題はこちら

atcoder.jp

美味しさは以下のように表される。

  •  X 種類のケーキ  A _ 1, A _ 2, ..., A _ X
  •  Y 種類のケーキ  B _ 1, B _ 2, ..., B _ Y
  •  Z 種類のケーキ  C _ 1, C _ 2, ..., C _ Z

この時、それぞれのケーキ美味しさの合計として大きい順に  K 個出力しなさい。

 \begin{array}{l}{1 \leq X \leq 1000}, \ {1 \leq Y \leq 1000}, \ {1 \leq Z \leq 1000}, \ {1 \leq K \leq \min (3000, X \times Y \times Z)}\end{array}

考え方

 X \times Y \times Z の全探索をすると  O(10 ^ 9) のループが周り、さらに大きい順に出力するので間に合わない。そこで、計算量を落とすことを考える。

いろんな解法があり、勉強になったので3つほど載せておく。元となる考え方は K が最大で3000 というところに注目するところにある。

解答1 - 解の候補を絞る

 A B の組み合わせを  X \times Y 分だけ全列挙すると、 O(10 ^ 6) となるので、sortしても大丈夫。この配列を  A \times B とおく。

大きい方から  K 個ということは、以下が成り立つ。

  •  A \times B からは最大でも  K 個使われる。
  •  C からは最大でも  K 個使われる。

よって、それぞれの配列の大きい方から  K ずつだけループを回して、大きい順に値を保持し、最後にsortして  K 個分出力すれば良い。計算量はここのsortがボトルネックとなって、  O(K ^ 2 \log(K))

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
using vl = vector<ll>;

#define rep(i, n) for(ll i = 0; i < (ll)(n); i++)
#define all(obj) (obj).begin(), (obj).end()

int main() {
    ll x, y, z, k;
    cin >> x >> y >> z >> k;
    vl a(x), b(y), c(z);
    rep(i, x) cin >> a[i];
    rep(i, y) cin >> b[i];
    rep(i, z) cin >> c[i];

    vl ab;
    rep(i, x) rep(j, y) ab.emplace_back(a[i] + b[j]);
    sort(all(ab), greater<>());
    sort(all(c), greater<>());

    vl ans;
    rep(i, min(k, x * y)) rep(j, min(k, z)) ans.emplace_back(ab[i] + c[j]);
    sort(all(ans), greater<>());

    rep(i, k) cout << ans[i] << '\n';

    return 0;
}

Submission #7171632 - AtCoder Beginner Contest 123

解答2 - 貪欲とpriority_queueを使う

 A, B, C をそれぞれ大きい順にsortしておくとする。

最大値は  A _ 0 + B _ 0 + C _ 0 である。この次に大きいのはどれかということを考える。しかし、一意に定まりそうではないので候補を絞ることにする。すると、次の候補は以下の3つであるということがわかる。

  •  A _ 1 + B _ 0 + C _ 0 A のindexを一つだけ進めた。
  •  A _ 0 + B _ 1 + C _ 0 B のindexを一つだけ進めた。
  •  A _ 0 + B _ 0 + C _ 1 C のindexを一つだけ進めた。

なので、現在の最大値をpopした上で、これらを全て priority_queue にpushする。すると、次はtopにあるものが2番目に大きいものとなり、順に大きいものをpopしていくことができる。計算量は、priority_queue からpopする回数が  K 回、上記のように候補を3つpushする回数が  3K 回なので、 O(K log(K)) で間に合う。

ここで実装上の注意点は以下のようになる。

  • priority_queue にそれぞれのindexも含めて管理する。なぜなら、popした後にindexを一つ進める作業が必要になるため、和の値とそれぞれのindexの情報が必要。
  • 一度見たindexの和の値は priority_queue にpushしないように気をつける。例えば  A _ 1 + B _ 0 + C _ 1 は以下の二つのものから到達可能であり、重複して出力してしまう可能性があるからだ。
    •  A _ 1 + B _ 0 + C _ 0 から、 C のindexを増やした時。
    •  A _ 0 + B _ 0 + C _ 1 から、 A のindexを増やした時。

この辺の実装は以下の記事を参考にして実装させていただいた。

drken1215.hatenablog.com

#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using vl = vector<ll>;

#define rep(i, n) for(ll i = 0; i < (ll)(n); i++)
#define all(obj) (obj).begin(), (obj).end()

using data = pair<ll, vl>;

int main() {

    ll x, y, z, k;
    cin >> x >> y >> z >> k;
    vl a(x), b(y), c(z);
    rep(i, x) cin >> a[i];
    rep(i, y) cin >> b[i];
    rep(i, z) cin >> c[i];
    sort(all(a), greater<>()), sort(all(b), greater<>()), sort(all(c), greater<>());

    priority_queue<data> ans;
    set<data> used;
    ans.push(data(a[0] + b[0] + c[0], vl({0, 0, 0})));
    while (k-- > 0) {
        auto now = ans.top();
        ans.pop();
        cout << now.first << '\n';
        ll ia = now.second[0], ib = now.second[1], ic = now.second[2];

        data tmp;
        if (ia + 1 < a.size()) {
            tmp = data(a[ia + 1] + b[ib] + c[ic], vl({ia + 1, ib, ic}));
            if (used.find(tmp) == used.end()) {
                used.insert(tmp);
                ans.push(tmp);
            }
        }
        if (ib + 1 < b.size()) {
            tmp = data(a[ia] + b[ib + 1] + c[ic], vl({ia, ib + 1, ic}));
            if (used.find(tmp) == used.end()) {
                used.insert(tmp);
                ans.push(tmp);
            }
        }
        if (ic + 1 < c.size()) {
            tmp = data(a[ia] + b[ib] + c[ic + 1], vl({ia, ib, ic + 1}));
            if (used.find(tmp) == used.end()) {
                used.insert(tmp);
                ans.push(tmp);
            }
        }
    }
    return 0;
}

Submission #7171780 - AtCoder Beginner Contest 123

解答3 - K個以上になる境目の値を二分探索

ここでも  A, B, C をそれぞれ大きい順にsortしておくとする。

 K 個以上になる美味しさの合計の境目を二分探索で探索」し、二分探索内での判定方法として「美味しさの合計が  p 以上であるものが  K 個以上あるかどうか調べる」方法を考える。

まず後者は、以下のように枝刈りをすれば  O(K ^ 2) の計算量で抑えられる。

auto solve = [&](ll p) -> bool {
    ll cnt = 0;
    rep(i, x) { // ここは最大でK回ループがまわる
        rep(j, y) { // ここから下は最大でK回ループがまわる
            rep(l, z) {
                ll val = a[i] + b[j] + c[l];
                if (val < p) break;
                if (++cnt >= k) {
                    return true;
                }
            }
        }
    }
    return false;
};

なぜか???それは、 A, B, C をそれぞれ大きい順にsortしてあるので、以下の操作をすることにより  y, z の二重ループが高々  K 回しか回ることがないからだ。

  • 合計が  p より小さいなら一番ネストの深いループを抜ける。
  • 合計が  p 以上なら、カウントを一つ増やし、 K 以上になったら関数をreturnする。

よって、二分探索内での判定は可能になった。

最後に、境目がわかった後にどうすれば良いのかを考える。この境目を  Border とする。

  • 合計が   Border 以上のものは  K 個以上ある。(が、最大で何個あるかどうかはわからない...!)
  • 合計が  Border + 1 以上のものは  K 個より少ない。

よって、以下のようにして上位  K 個を求めることで間に合う。

  • 合計が  Border + 1 以上のものを全部列挙する。 これは先ほどの二分探索内での判定方法をほとんど同じ。
  • この個数が  K 個よりも少なければ、残りの美味しさは全て  Border であるので、それをpushする。

計算量は、二分探索に  O(log(A _ {max} + B _ {max} + C _ {max})) 、枝刈りに  O(K ^ 2 ) より、  O(K ^ 2 log(A _ {max} + B _ {max} + C _ {max})) で間に合う。

この考え方は公式の解説を参考にした。
https://img.atcoder.jp/abc123/editorial.pdf

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
using vl = vector<ll>;
#define rep(i, n) for(ll i = 0; i < (ll)(n); i++)
#define all(obj) (obj).begin(), (obj).end()

int main() {
    ll x, y, z, k;
    cin >> x >> y >> z >> k;
    vl a(x), b(y), c(z);
    rep(i, x) cin >> a[i];
    rep(i, y) cin >> b[i];
    rep(i, z) cin >> c[i];

    sort(all(a), greater<>()), sort(all(b), greater<>()), sort(all(c), greater<>());

    auto solve = [&](ll p) -> bool {
        ll cnt = 0;
        rep(i, x) { // ここは最大でK回ループがまわる
            rep(j, y) { // ここから下は最大でK回ループがまわる
                rep(l, z) {
                    ll val = a[i] + b[j] + c[l];
                    if (val < p) break;
                    if (++cnt >= k) {
                        return true;
                    }
                }
            }
        }
        return false;
    };

    ll s = -1, e = a[0] + b[0] + c[0] + 1;
    while (e - s > 1) {
        ll mid = (s + e) / 2;
        if (solve(mid)) s = mid;
        else e = mid;
    }
    vl ans;
    rep(i, x) {
        rep(j, y) {
            rep(l, z) {
                ll val = a[i] + b[j] + c[l];
                if (val < s + 1) break;
                ans.emplace_back(val);
            }
        }
    }
    while (ans.size() < k) ans.emplace_back(s);
    sort(all(ans), greater<>());
    for (auto val: ans) cout << val << '\n';
    return 0;
}

Submission #7171934 - AtCoder Beginner Contest 123