LoginSignup
29
24

More than 5 years have passed since last update.

[学習用実装]層別サンプリング(Stratified Sampling)をPythonで実装する(1)

Last updated at Posted at 2017-09-06

エントリ概要

 層別サンプリング(stratified sampling)は、母集団の分布を良く維持してサンプリングするための手法です。pythonでは、scikit-learn の StratifiedShuffleSplit および train_test_split で実装されています。普段、機械学習モデルを交差検証(cross validation)をする際によくお世話になっています。
 ツールとして使うのにとどまるのではなく、理解を深めるため、学習用のサンプルコードをpythonで実装し、無作為抽出(random sampling)と比較をしてみました。
 その結果、抽出元のサンプル数が少ない場合ほど、無作為サンプリングと比べてよい精度で抽出できる傾向が確認できました。(当たり前の結果ではありますが、自分の手で再現したことで理解を深めました)

層別サンプリング

 簡潔に言うと、偏ったサンプル構成の母集団からサンプリングする際に役立つ手法です。
 母集団を、「層」という小集団に分けます。
 その際、層ごとの分散はなるべく小さく、層間の分散はなるべく大きくなるように分けます。
 つまり、同じ属性を持ったサンプル同士でグループ化するわけです。

 たとえば、0 ~ 9 のいずれかの数字が書かれたカードが100枚あるとします。
 これをシャッフルします。

0 1 9 ・・・ 5 3 7 1

 これを同じ数字ごとにまとめます。数字なのでソートすればOKですね。

0 0 0 ・・・ 5 5 5 5 ・・・ 9 9 9 9

 このように同じ数字でグループ化したら、各グループから同じ比率でサンプリングします。
 たとえば、抽出率を10%とする場合、数字0~9の各グループから10%確率で無作為抽出をします。
 仮に0~9の各数字が10枚ずつある場合、次のように無作為抽出します。

0 0 0 ・・・ 0 → 1枚を無作為に抽出
1 1 1 ・・・ 1 → 1枚を無作為に抽出
2 2 2 ・・・ 2 → 1枚を無作為に抽出
・・・
9 9 9 ・・・ 9 → 1枚を無作為に抽出

 この層別サンプリングが効果を発揮するのは、偏ったサンプル構成の場合です。

偏ったサンプル構成からの抽出

分かりやすい例として、カード20枚のうち、「0」が2枚、残る18枚が「1」としましょう。

0 0 1 1 ・・・ 1 1 1 1 1

(1) 全体から無作為サンプリングする場合

 全体から無作為抽出をして、10枚選ぶとしましょう。抽出率は50%です。
 無作為なので、18枚の「1」の中から10枚を選び、「0」を抽出しない場合が起きます。
 そうなると、抽出後のサンプルには、「0」が1枚も含まれていないため、分布を維持できなくなります。

0 0 1 1 ・・・ 1 1 1 1 1

     ↓     ↓     ↓  20枚の中から10枚抽出

1 1 1 1 1 1 1 1 1 1

「0」 が一枚もないため、母集団の分布とは明らかに異なる!

 そこで層別サンプリングをします。

(2) 層別サンプリングの場合

 層別サンプリングの場合、次のように抽出します。

  • 「0」2枚の中から1枚を抽出(=抽出率50%)
  • 「1」18枚の中から9枚を抽出(=抽出率50%)
0 0 1 1 ・・・ 1 1 1 1 1

     ↓     ↓     ↓  20枚の中から10枚抽出

0 1 1 1 1 1 1 1 1 1

「0」 および「1」の構成比率が母集団と同じ=分布が同じ!

こうして抽出されたサンプルは、全体から無作為に抽出した場合と比べて、母集団の分布をよりよく維持しています。

層別サンプリングを実装する

 概要が分かったので、層別サンプリングを実装します。
 処理のロジックは次のようになります。

  • 与えられたサンプルがK個の層から構成されているとします。
  • 抽出比率を r とします。( 0 < r < 1 )
  • 各層から少なくとも1個は抽出します。
  • 各層から1個を抽出した後、層ごとに無作為抽出を行います。
  • 各層の層番号を i とします。(i=1,2,3,..,K)
  • 各層のサンプル数をN(i) 個とします。
  • 各層で無作為抽出する際の抽出率をX(i)とします。( 0 < X(i) < 1 ; i=1,2,3,..,K)
  • 各層では必ず1個抽出するので、残る N(i) - 1 個に対して、抽出率 X(i) で無作為抽出をします。
  • 抽出個数の関係は次のようになります。
  • 1 + X(i) x (N(i) - 1) = N(i) x r
  • よって、各層の無作為抽出率 X(i) は次のようになります。
  • X(i) =( N(i) x r - 1) / ( N(i) - 1)

以上のロジックを踏まえてサンプル実装したコードが以下です。

import numpy as np
import random

def extract_stratified_sampling_result(ratio, base_samples):
    u"""
    抽出比率を指定して、有限母集団から層別サンプリングを実施する。
    :param ratio: 抽出比率 0 ~ 1.0
    :param base_sample: 抽出元集団
    :return:
    """
    # 各数字のグループから、まず1個ずつ取り出す。
    # その後、各数字グループから無作為に抽出して、構成比率を母集団に近づける。
    # 各数字グループから1個取り出した後の抽出率を X(i) とする。 iはグループ番号である。
    # 各数字グループの個数を N(i) とする。
    # 抽出すべき個数は、 ratio x N(i) である。
    # すでに1個取り出してあるので、残る ( N(i) - 1 )個の中から抽出率 X(i) で無作為に取り出す。
    # よって、すでに取り出した1個とあわせると、ratio x N(i) 個になる。
    # X(i) x (N(i) - 1) + 1 = ratio x N(i)
    # X(i) = (ratio x N(i) - 1 )/(N(i) - 1) である。

    block_count = np.bincount(base_samples)
    x = (ratio * block_count - 1) / (block_count - 1)

    # サンプリングする際の乱数の閾値を計算する。
    # 閾値 = 1.0 - 各グループの抽出率
    # 乱数が閾値を超えたら抽出する。
    # あるグループの抽出率が 0.3 ならば、 1.0 - 0.3 = 0.7 、乱数が0.7以上であれば抽出することになる。
    # 各数字グループごとの抽出率を格納した配列 x を、
    # 各数字の個数分だけならべる。
    threshold = np.repeat(1.0 - x, block_count)

    # 元集合をソートしたときの、各要素のインデックスリスト
    # この順番に samples から取り出すと、ソートした結果になる。
    sorted_pos = np.argsort(base_samples)

    # 各数字グループの開始位置
    block_start = np.concatenate(([0], np.cumsum(block_count)[:-1]))

    # 発生させた乱数が 閾値 threshold を超えたら抽出される。
    threshold[block_start] = 0  # 各数字グループの最初の要素はかならず抽出
    extracted = []
    for i in range(len(base_samples)):
        each_rand = random.random()
        if each_rand > threshold[i]:
            pos = sorted_pos[i]
            extracted.append(base_samples[pos])
    extracted = np.array(extracted, dtype=np.int64)
    return extracted

 引数 base_samples には、整数のリストが格納されていることを想定しています。
 では各コードを見ていきましょう。
 まず、抽出元の base_samples のグループの構成を把握しましょう。
 そこで np.bincount() の出番です。リストを構成する各数字の個数を集計してくれます。


    block_count = np.bincount(base_samples)

 
 たとえば、 base_samples に、0 が 9 個、1 が91個入っている場合、block_countには次のような結果が返ってきます。

[ 9 91] 

 つまり、
block_coount[0] = 数字 0 の個数、
block_count[1] = 数字 1 の個数、
というわけです。
この block_count が、先ほどのロジックにおける各層の数字の個数 N(i) にあたります。
では、続いて各層の抽出率 X(i) を求めます。
抽出率 r は、引数 ratio にあたりますので、コードは次のようになります。

    x = (ratio * block_count - 1) / (block_count - 1)

block_count が numpy配列なので、計算結果もnumpy配列になります。
つまり、x の中身は次のようになります。

x[0] = 数字 0 の層の抽出率
x[1] = 数字 1 の層の抽出率

x が計算できたので、次は乱数の閾値を求めます。
乱数がこの値以上であれば、サンプルを取り出すというわけです。
よって、

各層の乱数の閾値 = 1.0 - X(i)  

となります。
たとえば、数字 0 の抽出率が 0.10 の場合、閾値は、

1.0 - 0.10 = 0.90

となります。

つまり、発生した乱数(0~1.0)が 0.90 以上の場合だけ、サンプルを取り出すというわけです。

こうして求めた各層の乱数の閾値を、各層のサンプルの数だけ繰り返して並べます。

    # 各数字グループごとの抽出率を格納した配列 x を、
    # 各数字の個数分だけならべる。
    threshold = np.repeat(1.0 - x, block_count)

x[0] = 0.20 , block_count[0] = 2
x[1] = 0.10 , block_count[1] = 18 と仮定すると、乱数閾値のリスト threshold は次のようになります。

threshold = [ 0.80, 0.80, 0.90, 0.90, 0.90, ... 0.90]

 次に、base_samples をソートした後のインデックスリストを取得しておきます。

    # 元集合をソートしたときの、各要素のインデックスリスト
    # この順番に samples から取り出すと、ソートした結果になる。
    sorted_pos = np.argsort(base_samples)

threshold と sorted_pos と base_samples を組み合わせて使うと、層別サンプリングができます。
たとえば、

  • threshold[0] : 層0 の先頭の数字の乱数閾値
  • sorted_pos[0] : 層0 の先頭の数字が base_samples 上で格納されている位置(=インデックス)

 という関係にあるので、最初に発生させた乱数が threshold[0] 以上であれば、層0 の先頭の要素が抽出されるというわけです。
 threshold,sorted_pos の走査位置を動かしていけば、層別サンプリングができるというわけです。

 後は、各層から必ず1個取り出す、というロジックのための処理をします。

    # 各数字グループの開始位置
    block_start = np.concatenate(([0], np.cumsum(block_count)[:-1]))

    # 発生させた乱数が 閾値 threshold を超えたら抽出される。
    threshold[block_start] = 0  # 各数字グループの最初の要素はかならず抽出

 block_start には、各層の先頭要素の位置が入っています。
 たとえば、0 が10個、1が90個の場合、次のようになります。

block_start = [0,10]

 数字 0 の層は、block_start[0] 番目から始まり、
 数字 1 の層は、block_start[1] 番目から始まるという意味です。

 この block_start を使い、各層の先頭要素の乱数閾値を 0 にします。
 乱数閾値が0 ということは、かならず抽出されるということです。

 そして、各層の先頭以降の乱数閾値は、 (1 - X(i)) になっているので、計算して求めた抽出率にしたがって無作為に抽出されることになります。

 エントリが長くなりそうなので、いったんここで締めくくります。

 次のエントリでは、単純な無作為抽出と層別サンプリングの性能を比較するサンプルコードを紹介します。

29
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
24