問題はこちら
頂点数 の木が与えられる(木なので辺の数は )。
以下のような 個のクエリ が投げられるので、それぞれに対する答えを出力せよ。
- 求めるものは以下を満たす頂点 , , のペアの個数。
- 頂点 と を結ぶパスの中で、辺の最大の重みが を超えない。
考え方
以下のみの辺で連結になっている頂点数 が分かれば、その中を結ぶ頂点のペアの個数は だとわかる。よって、重み 以下の辺が出てきたら、その二つの頂点を同じグループに入れればいい。
このようなグループの構成は Union find木 を使えばいいということがわかる。
愚直に全てのクエリに対して順番にやると、毎回 Union find木を1から構成する必要があり、間に合わない。そこで、この問題を思い出す。
、クエリを先に読んでおいて、wが大きい順にクエリを処理する。こうすることで、既に存在する木に新たに条件を満たす辺を加えていくだけでよくなる。
ABC040-D 道路の老朽化対策について (500) - ひらめの日常
今回のポイントとして、 が小さい時に条件を満たすような頂点のペアは、それよりも大きい の時も条件を満たし、同じグループに属するということがある。なので、クエリを小さい順に処理する。こうすることで、上記の引用部分と同様に、すでに存在する木に新たに条件を満たす辺を加えていくだけでよくなる。
また、個数をカウントする方法も工夫が必要で累積和的に管理する。 が新たにグループになる時、今までの が属するそれぞれのグループで作れるペアの個数を引き、新たに作られたグループで作れるペアの個数を足す。
ペア u, v が新たに条件を満たす時。 count -= unionfind.size(u) * (unionfind.size(u) - 1) / 2; count -= unionfind.size(v) * (unionfind.size(v) - 1) / 2; unionfind.unite(u, v); count += unionfind.size(u) * (unionfind.size(v) - 1) / 2;
解答
#include <bits/stdc++.h> using namespace std; using ll = long long; #define rep(i, n) for(ll i = 0; i < (ll)(n); i++) class UnionFind { private: vector<ll> size; // グループに属する物の数. public: vector<ll> par; // 親 vector<ll> rank; // 木の深さ explicit UnionFind(unsigned int n) { par.resize(n); rank.resize(n); size.resize(n); rep(i, n) { par[i] = i; rank[i] = 0; size[i] = 1; } } // 木の根を求める ll find(ll x) { if (par[x] == x) { return x; } else { return par[x] = find(par[x]); } } // グループのサイズを求める. ll calc_size(ll x) { return size[find(x)]; } // xとyの属する集合を併合 void unite(ll x, ll y) { x = find(x); y = find(y); if (x == y) return; if (rank[x] < rank[y]) { par[x] = y; } else { par[y] = x; if (rank[x] == rank[y])rank[x]++; } size[x] = size[y] = size[x] + size[y]; } // xとyが同じ集合に属するか否か bool is_same(ll x, ll y) { return find(x) == find(y); } }; template<typename T> using minpq = priority_queue<T, vector<T>, greater<T>>; using P = pair<ll, ll>; using edge = pair<ll, P>; int main() { ll n, m; cin >> n >> m; minpq<edge> que; rep(i, n - 1) { ll u, v, c; cin >> u >> v >> c; u--, v--; que.push(edge(c, P(u, v))); } vector<P> q(m); rep(i, m) { cin >> q[i].first; q[i].second = i; } sort(q.begin(), q.end()); vector<ll> ans(m); UnionFind uf(n); ll cnt = 0; rep(i, m) { ll target = q[i].first, idx = q[i].second; while (!que.empty() && que.top().first <= target) { edge now = que.top(); ll u = now.second.first, v = now.second.second; ll s1 = uf.calc_size(u), s2 = uf.calc_size(v); cnt -= s1 * (s1 - 1) / 2; cnt -= s2 * (s2 - 1) / 2; uf.unite(u, v); ll s3 = uf.calc_size(u); cnt += s3 * (s3 - 1) / 2; que.pop(); } ans[idx] = cnt; } rep(i, m) cout << ans[i] << ' '; return 0; }