はじめに
CodeQUEEN 2023 予選 (AtCoder Beginner Contest 308) - AtCoder お疲れさまでした。
C - Standings の浮動小数点の誤差評価問題が、計算機科学的に面白かったです。nok0 さん、kyopro_friends さんの公式解説がためになります。
解説はコードが少ないです。日本語よりコードの方が読みやすい人としては、 C++, Rust の複数の実装例を書いてみようと思いました。
問題
入出力例など詳しくは元の問題をどうぞ。
問題文 (要約)
$N$ 個の $(A_i, B_i)$ の組が与えられます。
確率 $\frac{A_i}{A_i+ B_i}$ が「大きい順」になるように $i$ を並び替えてください。
同じ確率が複数現れる場合は、その $i$ を「小さい順」に並べ替えてください。
制約
- $2 ≤ N ≤ 2 × 10^5$
- $0 ≤ A_i, B_i ≤ 10^9$
- $A_i + B_i ≥ 1$
- 入力される数値は全て整数
入力
$N$
$A_1~ ~ B_1$
$⋮$
$A_N~ ~ B_N$
出力
$1,…,N$ の番号をこの順に空白区切りで出力せよ。
実装
実装1. double のソートを行う (WA)
$i$ 番目の値 $\frac{A_i}{A_i + B_i}$ と、$j$ 番目の値 $\frac{A_j}{A_j + B_j}$ を比較してソート、というのがまず最初に考えられそうです。
v = [0, 1, 2, ..., N - 1]
という配列を用意し、
std::stable_sort()
にこの v
と $A_i/(A_i + B_i)$ を使った比較方法を渡して並び替える感じです。
auto fn = [&ab](int l, int r) -> bool { return ab[r] < ab[l]; };
stable_sort(v.begin(), v.end(), fn);
C++ 実装例 1
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
int main() {
int n;
cin >> n;
vector<double> ab;
ab.reserve(n);
for (int i = 0; i < n; ++i) {
double a, b;
cin >> a >> b;
ab.push_back(a / (a + b));
}
vector<int> v;
v.reserve(n);
for (int i = 0; i < n; ++i) {
v.push_back(i);
}
auto fn = [&ab](int l, int r) -> bool { return ab[r] < ab[l]; };
stable_sort(v.begin(), v.end(), fn);
for (int i = 0; i < n; ++i) {
cout << v[i] + 1 << " ";
}
return 0;
}
Rust 実装例 1
use itertools::Itertools;
use proconio::input;
fn main() {
input! {
n: usize,
ab: [(f64, f64); n],
}
let mut v = Vec::with_capacity(n);
for (i, &(a, b)) in ab.iter().enumerate() {
v.push((i + 1, a / (a + b)));
}
v.sort_by(|(_, l), (_, r)| r.partial_cmp(l).unwrap());
let result = v.iter().map(|(i, _)| i).join(" ");
println!("{}", result);
}
しかしこれはテストを通過しません。というのは、double 型の精度が足りず、浮動小数点の下位部分が不安定になるためです。
問題文の制約 $0 ≤ A_i, B_i ≤ 10^9$ から、A, B ともに 32bit ギリギリくらい使うことが分かります。 $2^{10}=1024≒10^3$ ですから、$10^9≒2^{30}$ です。
$\frac{A_i}{A_i + B_i}$ の分母と分子にそれぞれ $10^9$ や $10^9-1$ といった大きな数が入るとすると、精度としては少なくとも $10^9 \times (2 \times 10^9)≒2^{61}$ くらい欲しくなりそうです。double 型は通常 64bit で、指数部も込みと考えると、この問題の数値上下関係を表現するには精度が足りなそうと分かります。 1
実装2. 整数で処理できる比較関数を書き、安定ソートする (AC)
$\frac{A_i}{A_i + B_i} < \frac{A_j}{A_j + B_j}$ が成り立つか double で調べるのは、誤差が現れて難しいです。
このままだと難しいですが、事前に $(A_i + B_i) (A_j + B_j)$ を掛け算すると、$A_i(A_j + B_j) < A_j(A_i + B_i)$ が成り立つかを調べるという話になります。整数演算ですので誤差が現れません。
A, B ともに 32bit ギリギリくらい使いますので、この掛け算のためには 64bit 整数型 int64_t
, long long
が必要になります。
- auto fn = [&ab](int l, int r) -> bool { return ab[r] < ab[l]; };
+ auto fn = [&ab](int l, int r) -> bool {
+ auto [la, lb] = ab[l];
+ auto [ra, rb] = ab[r];
+ return ra * (la + lb) < la * (ra + rb);
+ };
stable_sort(v.begin(), v.end(), fn);
C++ 実装例 2
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
int main() {
int n;
cin >> n;
vector<pair<int64_t, int64_t>> ab;
ab.reserve(n);
for (int i = 0; i < n; ++i) {
int a, b;
cin >> a >> b;
ab.push_back(pair(a, b));
}
vector<int> v;
v.reserve(n);
for (int i = 0; i < n; ++i) {
v.push_back(i);
}
auto fn = [&ab](int l, int r) -> bool {
auto [la, lb] = ab[l];
auto [ra, rb] = ab[r];
return ra * (la + lb) < la * (ra + rb);
};
stable_sort(v.begin(), v.end(), fn);
for (int i = 0; i < n; ++i) {
cout << v[i] + 1 << " ";
}
return 0;
}
Rust 実装例 2
use itertools::Itertools;
use proconio::input;
fn main() {
input! {
n: usize,
ab: [(usize, usize); n],
}
let mut v = Vec::with_capacity(n);
for (i, &(a, b)) in ab.iter().enumerate() {
v.push((i + 1, a, b));
}
v.sort_by(|l, r| {
let (_, la, lb) = l;
let (_, ra, rb) = r;
(ra * (la + lb)).cmp(&(la * (ra + rb)))
});
let result = v.iter().map(|(i, _, _)| i).join(" ");
println!("{}", result);
}
AC です。お疲れさまでした。
実装3. 整数で処理できる比較関数を書き、安定でないソートをする (AC)
C++ の std::sort()
は安定ソートではありません。安定ソートは比較結果がイコールになるときに元の順序を維持します。安定でない場合はどうなるか分かりません。
たとえば [3.1, 4.1, 5.9, 2.6]
を小数点以下の値が昇順 1, 1, 6, 9
になるようにソートしたとします。安定ソートでは [3.1, 4.1, 2.6, 5.9]
と [3.1, 4.1]
の順番が維持されます。安定でないソートではこの保証がなく、[4.1, 3.1, 2.6, 5.9]
になってしまるかもしれません。
同じ確率が複数現れる場合は、その $i$ を「小さい順」に並べ替えてください。
今回の問題のように同じ値の順番を気にする場合は、安定ソート stable_sort()
をするのがおすすめです。
sort()
はイコールのときに負安定になります、言い換えると、イコールを避ければ大丈夫です。たとえば次のように、インデックスを比較の第二条件に加えれば良いです。
auto fn = [&ab](int l, int r) -> bool {
auto [la, lb] = ab[l];
auto [ra, rb] = ab[r];
- return ra * (la + lb) < la * (ra + rb);
+ return pair(ra * (la + lb), l) < pair(la * (ra + rb), r);
};
- stable_sort(v.begin(), v.end(), fn);
+ sort(v.begin(), v.end(), fn);
第一条件が一致するときに第二条件を使う、というときに pair
が使えます。if 文よりスッキリ書けます。
C++ 実装例 3
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
int main() {
int n;
cin >> n;
vector<pair<int64_t, int64_t>> ab;
ab.reserve(n);
for (int i = 0; i < n; ++i) {
int a, b;
cin >> a >> b;
ab.push_back(pair(a, b));
}
vector<int> v;
v.reserve(n);
for (int i = 0; i < n; ++i) {
v.push_back(i);
}
auto fn = [&ab](int l, int r) -> bool {
auto [la, lb] = ab[l];
auto [ra, rb] = ab[r];
return pair(ra * (la + lb), l) < pair(la * (ra + rb), r);
};
sort(v.begin(), v.end(), fn);
for (int i = 0; i < n; ++i) {
cout << v[i] + 1 << " ";
}
return 0;
}
Rust 実装例 3
use itertools::Itertools;
use proconio::input;
fn main() {
input! {
n: usize,
ab: [(usize, usize); n],
}
let mut v = Vec::with_capacity(n);
for (i, &(a, b)) in ab.iter().enumerate() {
v.push((i + 1, a, b));
}
v.sort_unstable_by(|l, r| {
let (li, la, lb) = l;
let (ri, ra, rb) = r;
((ra * (la + lb)), li).cmp(&((la * (ra + rb)), ri))
});
let result = v.iter().map(|(i, _, _)| i).join(" ");
println!("{}", result);
}
実装4. 精度の高い実数を用いる (AC)
C++ gcc, clang では 80bit の long double
を使うことができます。精度が高く、最初の double
の型を差し替えるだけで通ります。
- vector<double> ab;
+ vector<long double> ab;
ab.reserve(n);
for (int i = 0; i < n; ++i) {
- double a, b;
+ long double a, b;
cin >> a >> b;
ab.push_back(a / (a + b));
}
C++ 実装例 4
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
int main() {
int n;
cin >> n;
vector<long double> ab;
ab.reserve(n);
for (int i = 0; i < n; ++i) {
long double a, b;
cin >> a >> b;
ab.push_back(a / (a + b));
}
vector<int> v;
v.reserve(n);
for (int i = 0; i < n; ++i) {
v.push_back(i);
}
auto fn = [&ab](int l, int r) -> bool { return ab[r] < ab[l]; };
stable_sort(v.begin(), v.end(), fn);
for (int i = 0; i < n; ++i) {
cout << v[i] + 1 << " ";
}
return 0;
}
しかし微妙なところがありまして……。 long double
が 80bit というのは C++ の規格で決まっているわけではありません。gcc, clang がたまたまそうなっているだけです。たとえば Visual C++ の long double
は double
と同じ 64bit です。
この方針でコンパイラ依存でない実装を行うためには、今年発行予定の C++23 で追加される float128_t
待ち、となるはずです。 2
まあ、競技プログラミングの文脈では実行環境を前提にして 3 、一段階高精度な実数演算をしたいときには long double
を使う、でだいたい大丈夫だと思います。 さらに gcc, clang で __float128
を使っても通せました。
- vector<double> ab;
+ vector<__float128> ab;
ab.reserve(n);
for (int i = 0; i < n; ++i) {
- long double a, b;
+ int a, b;
cin >> a >> b;
- ab.push_back(a / (a + b));
+ ab.push_back(static_cast<__float128>(a) / (a + b));
}
Rust の実数型は記事を書いている 2023/07/05 時点では double 相当の f64
が一番高精度で、この方針は行いづらいです。実装例を省略します。
実装5. 大きな整数を掛け算して単純なソートをする (AC)
$\frac{A_i}{A_i + B_i}$ の計算で誤差が出るという話でした。
分母が一番大きくなる数は $\frac{A_i}{10^9 + 10^9}$、その次は $\frac{A_i}{10^9 + 10^9 - 1}$ です。
ここで分子を最小値の 1 としてみます。分子最小と分母最大は同時には成り立ちませんけれど、値の範囲確認のためえいやっと。そうするとこの 2つの値の差は次のようになります。
$\dfrac{1}{10^9 + 10^9} - \dfrac{1}{10^9 + 10^9 - 1} ≒ \dfrac{1}{4 \times 10^{18}}$
これより小さな差は現れません。だとすると、$4 \times 10^{18}$ 以上の値 $K$ を事前に掛け算しておけば良いです。 $\frac{K A_i}{A_i + B_i}$ の比較で解けます。
$K$ はある程度以上大きければなんでも良いです。 $10^{20}$ や $2^{64}$ など。
K だけで 64bit 使いますので、この方法で解くには桁あふれを防ぐために 128bit 整数型や多倍長整数型が必要になります。
Rust 実装例 5
use itertools::Itertools;
use proconio::input;
fn main() {
input! {
n: usize,
ab: [(i128, i128); n],
}
const FACTOR: i128 = (2 * 10i128.pow(9)).pow(2);
let mut v = Vec::with_capacity(n);
for (i, &(a, b)) in ab.iter().enumerate() {
v.push((i + 1, FACTOR * a / (a + b)));
}
v.sort_by_key(|(_, x)| -x);
let result = v.iter().map(|(i, _)| i).join(" ");
println!("{}", result);
}
C++ は省略します。
実装6. 有理数で解く (AC)
いかにも実数の誤差について考えましょうという問題です。有理数が使えるなら誤差は現れません。
Rust では num クレートにお任せすれば終わります。
Rust 実装例 6
use itertools::Itertools;
use num_rational::Ratio;
use proconio::input;
fn main() {
input! {
n: usize,
ab: [(i64, i64); n],
}
let mut v = Vec::with_capacity(n);
for (i, &(a, b)) in ab.iter().enumerate() {
v.push((i + 1, Ratio::new(a, a + b)));
}
v.sort_by(|(_, l), (_, r)| r.cmp(l));
let result = v.iter().map(|(i, _)| i).join(" ");
println!("{}", result);
}
C++ は省略します。
最後に
楽しい問題でした。ありがとうございました。
本記事の C++ のコードは clang-format BasedOnStyle: LLVM
で、Rust のコードは rustfmt それぞれ標準設定で書いたつもりです。
実装例が参考になれば幸いです。
-
IEEE 754 - Wikipedia には、桁のビット部が 52bit + ケチ表現と書いています。 52bit と正確には知らなくても、足りなそうということは分かると思います。 ↩
-
float128_t
が気軽に使えると、ABC308-C の難易度が C++ では簡単、他言語では大変みたいなことにもなりそうです ↩ -
競技プログラミングで実行環境を前提におくことは他にもあります。たとえば 64bit ビルドでないと Rust の
usize
が 64bit になりません。32bit ビルドでは同じコードでも精度が足りずに WA になる、ということが考えられます。 ↩