やりたいこと
確率質量関数(PMF: Probability Mass Function)を与えて、
一様でなく不均一な分布を用いて乱数を生成したい。
ちなみにPython
で書くとnumpy
を使って下記の通りに簡単に書ける。
import numpy as np
pmf = [0.22, 0.5, 0.28]
# 22%の確率で0, 50%の確率で1, 28%の確率で2
np.random.choice(3, p = pmf)
実装
1. naiveな実装
筆者がC
を使っていた頃に用いてた方法。
use rand::Rng;
use rand::rngs::ThreadRng;
use rand::thread_rng;
fn naive_impl(rng: &mut ThreadRng, pmf: &[f64]) -> usize {
let mut rep = 0;
let mut sum = 0.0;
let v: f64 = rng.gen();
for p in pmf.iter() {
sum += p;
if v < sum {
break;
}
rep += 1;
}
rep
}
fn main() {
let mut rng = rand::thread_rng();
let pmf = [0.22, 0.5, 0.28]; // 任意
let num = naive_impl(&mut rng, &pmf);
}
2. WeightedIndex
を使う方法
use rand::rngs::ThreadRng;
use rand::distributions::{Distributions, WeightedIndex};
fn main() {
let mut rng = rand::thread_rng();
let pmf = [0.22, 0.5, 0.28]; // 任意
let dist = WeightedIndex::new(&pmf).unwrap();
let num = dist.sample(&mut rng);
}
実測
1億回実行して速度を測ってみる。
実際に望む通りに乱数が生成されているかもチェックする。
use rand::Rng;
use rand::rngs::ThreadRng;
use rand::distributions::Distribution;
use rand::distributions::WeightedIndex;
use std::time::Instant;
fn naive_impl(rng: &mut ThreadRng, pmf: &[f64]) -> usize {
let v: f64 = rng.gen();
let mut sum = 0.0;
let mut rep = 0;
for c in pmf.iter() {
sum += c;
if v < sum {
break;
}
rep += 1;
}
rep
}
fn main() {
let mut rng = rand::thread_rng();
let pmf = [0.22, 0.5, 0.28];
let trial = 100_000_000;
println!(“\n# Naive”);
let start = Instant::now();
for _ in 0..trial {
naive_impl(&mut rng, &pmf);
}
println!(“w/o insertion -> {:?}“, start.elapsed());
let start = Instant::now();
let mut c = [0; 3];
for _ in 0..trial {
c[naive_impl(&mut rng, &pmf)] += 1;
}
println!(“w/ insertion -> {:?} {:?}“, start.elapsed(), c);
println!(“\n# WeightedIndex”);
let start = Instant::now();
let dist = WeightedIndex::new(&pmf).unwrap();
for _ in 0..trial {
dist.sample(&mut rng);
}
println!(“w/o insertion -> {:?}“, start.elapsed());
let mut c = [0; 3];
let start = Instant::now();
let dist = WeightedIndex::new(&pmf).unwrap();
for _ in 0..trial {
c[dist.sample(&mut rng)] += 1;
}
println!(“w/ insertion -> {:?} {:?}“, start.elapsed(), c);
}
結果。
# Naive
w/o insertion -> 869.118501ms
w/ insertion -> 1.619900124s [21997730, 49998808, 28003462]
# WeightedIndex
w/o insertion -> 1.000642238s
w/ insertion -> 1.586865294s [21993674, 50010827, 27995499]
なぜ優位が逆転しているのかわからないが、
WeightedIndex
の方が簡潔なので推奨されると思う。