Posted at

多腕バンディット on Java

More than 5 years have passed since last update.

以前書いていてupするのを忘れていたのでupします。主にUCB1とその仲間達を使っています。

今の職場では使用機会ほぼないだろうな...


参考文献


多腕バンディットアルゴリズム概要


  • どの腕を選択するかは、知識活用項と探索項で算出される値により判断する

  • 知識は現在までに得られた獲得報酬の期待値

  • 探索は期待値の信頼性を高めるために実施する
     (当該腕の使用回数が少ない場合は信頼性が低いので優先的に使用する、など)

  • 知識活用と探索のバランスが各アルゴリズムにより異なる


参考コード


腕に関するコード


Arm.java

package CappedUCB;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;

public class Arm {
private double p; // 勝利確率
private long X; //今までの報酬の合計
private long X2; //今までの報酬の2乗の合計
private int n; //これまでの総プレイ回数
private double price;
private double max_price;

private int BENEFIT; //報酬の値
private final double GAP=0.99999998; // 忘却係数 memory gap
private final double VAR=5;
private double AVG;

public Arm(int BENEFIT, double price, double max_price, double avg){
this.BENEFIT = BENEFIT;
while(true){
this.p = Math.random();
if(this.p<0.4) break;
}
this.price=price;
this.max_price = max_price;
this.AVG = avg;
}

public Map<String, Double> play(){
Map<String, Double> res = new HashMap<>();

this.n++;
if(p >= Math.random()){//当たった場合
this.X += this.BENEFIT;
this.X2 += this.BENEFIT * this.BENEFIT;
double pprice = Math.floor( this.price*suvRate() ); // 提案価格
res.put("BENEFIT", (double)this.BENEFIT);
res.put("pprice", pprice);
return res;
}
double pprice = Math.floor( this.price*suvRate() ); // 提案価格
res.put("BENEFIT", (double)0);
res.put("pprice", pprice);

return res; //外れた場合
}

public Map<String, Double> vs(){
Map<String, Double> res = new HashMap<>();
long seed = System.currentTimeMillis();
Random rand = new Random(seed);
double cprice = rand.nextGaussian()*VAR+AVG;

this.n++;
if(this.price >= cprice){//当たった場合
this.X += this.BENEFIT;
this.X2 += this.BENEFIT * this.BENEFIT;

double pprice = 0.0;
if(suvRate()<0.95){
pprice = Math.floor( this.price*suvRate() ); // 提案価格
}else{
pprice = Math.floor( this.price*0.85 );
}
res.put("BENEFIT", (double)this.BENEFIT);
res.put("pprice", pprice);
return res;
}
double pprice = Math.floor( (this.max_price-this.price)*0.5 ); // 提案価格
res.put("BENEFIT", (double)0.0);
res.put("pprice", pprice);

return res; //外れた場合
}

// 事象発生は時刻に依存
public double UBC1(int t){
return (double)(this.X)/(this.n) + Math.sqrt(2*Math.log((double)t)/this.n);
}

// 時刻に関係なく当該事象が生じる確率が等しい
public double UBCpulas(int t){
return (double)(this.X)/(this.n+Math.pow(10, -6)) + Math.sqrt( 1/(this.n+Math.pow(10, -6)) );
}

// 時刻に関係なく当該事象が生じる確率が等しい+忘却係数付き
public double UBCpulasGap(int t){
return (double)(this.X)/(this.n+Math.pow(10, -9))*Math.pow(this.GAP, this.n)
+ Math.sqrt( 1/(this.n+Math.pow(10, -9)) ) - 0.1*this.price/this.max_price ;
}

public double UBC1Tuned(int t){
double aveX = (double)(this.X)/(this.n);
double Vjs = (double)(this.X2)/(this.n) - aveX * aveX + Math.sqrt(2*Math.log((double)t)/this.n);
return (double)(this.X)/(this.n) + Math.sqrt(Math.log((double)t)/this.n * Math.min(0.25, Vjs));
}

// ↓ここから入札用↓
public double index(){ // total revenue
return this.price*Math.min( this.X, this.n*(suvRate()+confRad()) );
}

public double confRad(){ //culcConfidenceRadius
double alph = Math.log(this.n+Math.pow(10, -6));
double denom = 1/(this.n+1);
return alph*denom + Math.pow(alph*suvRate()*denom, 0.5);
}

public double suvRate(){ //survival rate
return (double)(this.X/(this.n+Math.pow(10, -6)) );
}

public void print(int anum){
System.out.println("arm"+anum+": win_rate="+(double)this.X/this.n);
}
}



main側コード


Main.java

package CappedUCB;

import java.math.*;
import java.util.*;
import java.util.Map.Entry;

public class Main {
//当たった場合の報酬
static int BENEFIT = 1; //※修正
static int ARM_NUM=10; //腕の数
static int MAX_STEP=1000000;
static int PRICE_CAP=50;
static int MIN = 100; // 入札勝利数の下限

static int money = 0; //報酬額合計
static int count = 0; //ゲームカウント

static int avg=20;

public static void main(String[] args){

// 運用しつつ最適な腕を探す
double bprice = 0.0; // best price
Map<Double, Arm> armmap = new HashMap<>();
double nprice = 0.0; // 提案される価格
for(int t=0; t<MAX_STEP; t++){
if(t>=100000) PRICE_CAP=100; avg=35; //予算上限変更に対するロバスト性チェック
if(money<=MIN){
double price = culcPrice()*PRICE_CAP;
//double price = culcPrice();
if(!armmap.containsKey(price)){
armmap.put(price, new Arm(BENEFIT, price, PRICE_CAP, avg));
money +=armmap.get(price).vs().get("BENEFIT");

}else{
bprice = explore(armmap, t).get("price");
}
}else{
Map<String, Double> res = explore(armmap, t);
bprice = res.get("price");
nprice = res.get("pprice");
if(!armmap.containsKey(nprice) && nprice != 0.0){
armmap.put(nprice, new Arm(BENEFIT, nprice, PRICE_CAP, avg));
}
}
//プレイ後の報酬額
//if(count%100 == 0 ) System.out.println( count+" "+money+"(played:"+bprice+")" );
if(count%100 == 0 ) System.out.println( bprice );

count++;
}

for(Entry<Double, Arm> e:armmap.entrySet()){
e.getValue().print(e.getKey().intValue());
}

}

public static Map<String, Double> explore(Map<Double,Arm> arms, int t){
double eval = -1.0;
double evali = -1;
Map<String, Double> res = new HashMap<>();
//すべての腕の中から評価値の一番高いものを選ぶ
for(Entry<Double, Arm> e: arms.entrySet()){
double tmp = e.getValue().UBCpulasGap(t); //UBC1Tuned(t) index UBCpulas(t)
//System.out.println("arm="+anum+": tmp="+tmp);
if(eval < tmp){
eval = tmp;
evali = e.getKey();
}
}
//当該マシンをプレイ
money += arms.get(evali).vs().get("BENEFIT");
double pprice = arms.get(evali).vs().get("pprice");
res.put("price", evali);
res.put("pprice", pprice); // 提案価格
return res;
}

public static double culcPrice(){
double delta = Math.pow(( 1/(money+Math.pow(10, -6))+Math.pow(10,-6) * Math.log(money+Math.pow(10, -6))), 0.25);
double eps = Math.pow(money+Math.pow(10, -6), -0.25);
double alph = Math.pow((double)money/(count+Math.pow(10, -6)), 1-delta);
double gamma = Math.min( alph, Math.pow(Math.exp(1), -1) );
double m=0;

int l=0; double Rmax=0; double Rl=0;
double pl=0; double Sl=0; double plmax=0;

do{
pl = Math.pow(1+delta, -l);
m=delta*count/Math.log(1/eps)*Math.log(1+delta);

Sl = Math.exp(m);
Rl=pl*Sl;

if( Sl>=Math.pow(1+delta, -1)*gamma && Rl>=Rmax ){
Rmax=Rl; plmax=pl;
}
l++;
}while(pl>eps && Sl<(1+gamma)*alph && Rl>Math.pow(1+gamma, -2)*Rmax);
//System.out.println("plmax="+plmax);
return plmax;
}
}


UCB1-Tunedに関しては1度全ての餌場を漁ることが前提条件なので、未知の腕が大量に生成される可能性のある場合は使用難しいと思われます。


まとめ

1年近く前に遊びで書いたコードなので、ちょっとおかしな数値が出てそうな気がします。。

参考程度に眺めて頂ければと思います。

お手数ですが、おかしな点があればご指摘頂けると助かります。