Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

多腕バンディット問題

More than 5 years have passed since last update.

要約

多腕バンディット問題を Thompson Sampling で解いてみたよ。

多腕バンディット問題とは

(ベルヌーイバンディット (Bernoulli Bandit) の場合)
複数のスロットマシンがあって、それらをプレイすると、当たりか外れが出る。
スロットごとに当たりが出る確率は異なっているが、その値はわからない。
このとき、決められた回数のゲームプレイで、多く当たりを引きたい。

これがベルヌーイバンディットと呼ばれるのは、
確率 p で 1、 p-1 で 0 をとる離散分布はベルヌーイ分布 だからである。

解くイメージとしては、
* 当たりがたくさん出るスロットをたくさんプレイしたい。
* でも他のスロットもプレイして当たりやすさを調べておきたい
みたいなことを同時に達成したい。

Thompson Sampling とは

多腕バンディット問題をとくアルゴリズムの一つ。

  • スロットマシンごとに当たり・外れの回数を数えておく。
  • スロットマシンごとにベータ分布 Be(a+1, b+1) から乱数を引く。 (a はそのスロットマシンの当たり回数、b は外れ回数)
  • 引いた値が最大になるスロットマシンをプレイする。

ここでベータ分布なのは、このように確率分布がベルヌーイ分布の場合、
事前分布をベータ分布にすると、事後確率もベータ分布になる、
共役事前分布であるかららしい [2]

こんな感じで実装してみた。

self.results はSat, Fat だけでいいけど、後で使うかもしれないので保持している。

import scipy
import scipy.stats
import numpy.random
import collections
import pprint


# Banding machine simulator.
def getResult(prob):
    return scipy.stats.bernoulli.rvs(prob, size=1)[0]

# Multiarm Bandit Problem solver using Thompson Sampling.
class BernoulliBandit:
    def __init__ (self, numBandits):
        self.numBandits = numBandits
        self.results = dict([(i, []) for i in range(self.numBandits)])

    def getBandit(self):
        posteriors = []
        for b in range(self.numBandits):
            Sat = len([r for r in self.results[b] if r == 1])
            Fat = len([r for r in self.results[b] if r == 0])    
            posteriors.append(numpy.random.beta(Sat+1, Fat+1))
        return numpy.array(posteriors).argmax()

    def feed(self, selectedBandit, result):
        self.results[selectedBandit].append(result)

    def __str__(self):
        out = ""
        for b in range(self.numBandits):
            Sat = len([r for r in self.results[b] if r == 1])
            Fat = len([r for r in self.results[b] if r == 0])    
            out += "Bandit[%d]   Fails: %4d\t  Successes: %6d\n" % (b, Sat, Fat)
        return out

if __name__ == "__main__":
    # set parameters
    numBandits = 3
    numTrials = 5000
    rewards = [0.9, 0.9, 0.4]   # expected rewards (hidden)

    bandit = BernoulliBandit(numBandits)
    for t in range(numTrials):
        if t % 100 == 0:
            print t,
        b = bandit.getBandit()
        r = getResult(rewards[b])
        bandit.feed(b, r)
    print
    print bandit
    print "Rewards", rewards

実行結果はこの通り。なんか真のリワード期待値に会ってないんだが、
ナイーブな実装はこんなもんなんだろうか。あるいは実装をミスってるんだろうか。

0 100 200 300 400 500 600 700 800 900 1000 1100 1200 1300 1400 1500 1600 1700 1800 1900 2000 2100 2200 2300 2400 2500 2600 2700 2800 2900 3000 3100 3200 3300 3400 3500 3600 3700 3800 3900 4000 4100 4200 4300 4400 4500 4600 4700 4800 4900
Bandit[0]   Fails:  250   Successes:     40
Bandit[1]   Fails: 4261   Successes:    446
Bandit[2]   Fails:    0   Successes:      3

Rewards [0.9, 0.9, 0.4]

jupyter の gist

参考文献

  1. http://ibisml.org/archive/ibis2014/ibis2014_bandit.pdf
  2. http://jmlr.org/proceedings/papers/v23/agrawal12/agrawal12.pdf
mzmttks
ハイラブル株式会社の代表取締役。博士(情報学) カエルの研究をしてました。会話の定量化や分析をしてます。
https://www.mzmttks.com/
hylable
対面とオンラインの話し合いを可視化するクラウドサービスを開発・運営するスタートアップです。
https://www.hylable.com
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away