44
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

要約

多腕バンディット問題を Thompson Sampling で解いてみたよ。

多腕バンディット問題とは

(ベルヌーイバンディット (Bernoulli Bandit) の場合)
複数のスロットマシンがあって、それらをプレイすると、当たりか外れが出る。
スロットごとに当たりが出る確率は異なっているが、その値はわからない。
このとき、決められた回数のゲームプレイで、多く当たりを引きたい。

これがベルヌーイバンディットと呼ばれるのは、
確率 p で 1、 p-1 で 0 をとる離散分布はベルヌーイ分布 だからである。

解くイメージとしては、

  • 当たりがたくさん出るスロットをたくさんプレイしたい。
  • でも他のスロットもプレイして当たりやすさを調べておきたい
    みたいなことを同時に達成したい。

Thompson Sampling とは

多腕バンディット問題をとくアルゴリズムの一つ。

  • スロットマシンごとに当たり・外れの回数を数えておく。
  • スロットマシンごとにベータ分布 Be(a+1, b+1) から乱数を引く。
    (a はそのスロットマシンの当たり回数、b は外れ回数)
  • 引いた値が最大になるスロットマシンをプレイする。

ここでベータ分布なのは、このように確率分布がベルヌーイ分布の場合、
事前分布をベータ分布にすると、事後確率もベータ分布になる、
共役事前分布であるかららしい [2]

こんな感じで実装してみた。

self.results はSat, Fat だけでいいけど、後で使うかもしれないので保持している。

import scipy
import scipy.stats
import numpy.random
import collections
import pprint


# Banding machine simulator.
def getResult(prob):
    return scipy.stats.bernoulli.rvs(prob, size=1)[0]

# Multiarm Bandit Problem solver using Thompson Sampling.
class BernoulliBandit:
    def __init__ (self, numBandits):
        self.numBandits = numBandits
        self.results = dict([(i, []) for i in range(self.numBandits)])
        
    def getBandit(self):
        posteriors = []
        for b in range(self.numBandits):
            Sat = len([r for r in self.results[b] if r == 1])
            Fat = len([r for r in self.results[b] if r == 0])    
            posteriors.append(numpy.random.beta(Sat+1, Fat+1))
        return numpy.array(posteriors).argmax()
    
    def feed(self, selectedBandit, result):
        self.results[selectedBandit].append(result)
        
    def __str__(self):
        out = ""
        for b in range(self.numBandits):
            Sat = len([r for r in self.results[b] if r == 1])
            Fat = len([r for r in self.results[b] if r == 0])    
            out += "Bandit[%d]   Fails: %4d\t  Successes: %6d\n" % (b, Sat, Fat)
        return out

if __name__ == "__main__":
    # set parameters
    numBandits = 3
    numTrials = 5000
    rewards = [0.9, 0.9, 0.4]   # expected rewards (hidden)

    bandit = BernoulliBandit(numBandits)
    for t in range(numTrials):
        if t % 100 == 0:
            print t,
        b = bandit.getBandit()
        r = getResult(rewards[b])
        bandit.feed(b, r)
    print
    print bandit
    print "Rewards", rewards

実行結果はこの通り。なんか真のリワード期待値に会ってないんだが、
ナイーブな実装はこんなもんなんだろうか。あるいは実装をミスってるんだろうか。

0 100 200 300 400 500 600 700 800 900 1000 1100 1200 1300 1400 1500 1600 1700 1800 1900 2000 2100 2200 2300 2400 2500 2600 2700 2800 2900 3000 3100 3200 3300 3400 3500 3600 3700 3800 3900 4000 4100 4200 4300 4400 4500 4600 4700 4800 4900
Bandit[0]   Fails:  250	  Successes:     40
Bandit[1]   Fails: 4261	  Successes:    446
Bandit[2]   Fails:    0	  Successes:      3

Rewards [0.9, 0.9, 0.4]

jupyter の gist

参考文献

  1. http://ibisml.org/archive/ibis2014/ibis2014_bandit.pdf
  2. http://jmlr.org/proceedings/papers/v23/agrawal12/agrawal12.pdf
44
39
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
44
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?