TL;DL;
強化学習の手法の一つであるQ学習を用いて、N=10 の秘書問題の最適戦略を再現できた。
秘書問題とは
Wikipedia より。
- 秘書を1人雇いたいとする。
- n 人が応募してきている。 n という人数は既知である。
- 応募者には順位が付けられ、複数の応募者が同じ順位になることはない(1位からn位まで重複無く順位付けできる)。
- 無作為な順序で1人ずつ面接を行う。次に誰を面接するかは常に同じ確率である。
- 毎回の面接後、その応募者を採用するか否かを即座に決定する。
- その応募者を採用するか否かは、それまで面接した応募者の相対的順位にのみ基づいて決定する。
- 不採用にした応募者を後から採用することはできない。
- このような状況で、最良の応募者を選択することが問題の目的である。
最良の応募者を選択できる確率を最大化する戦略とその効果は以下の通りである
- 最初の n/e 人の応募者は採用しない(注: e は自然対数の底)
- それ以降は、面接した応募者がそれまでの応募者の中で最良であった場合に採用する
- 最善の応募者を選択する確率は 1/e すなわち約 37% になり、 n に依存しない
例えば応募者が100人の場合、あてずっぽに採用すると 1% の確率でしか最良の人を採用できないが、上記の戦略で採用する場合、確率は 37% にまで跳ね上がる。自然対数の底が出てきたり、確率が n に依存しないという非自明な結果であるというのが非常に面白い。
機械学習ならではの考慮
秘書問題特有の考慮点があるのでまず指摘しよう。選考者は応募者の相対順位のみを知るが、その絶対値を知ることができない、という点だ。
これは当たり前で、例えば N = 10 の秘書問題を解こうとした際に、候補者それぞれが 0 〜 9 の順位を持つのだが、その数字をそのまま選考者に伝えるのはナンセンスである。もしそうだったら、選考者の最適戦略は順位0の応募者がきたらその人を採用するということになる。つまり、確率100%で最良の人を採用してしまうのだ。
機械学習以外の方法でプログラムをする際は、応募者の数値そのものは意図的に無視するようなアルゴリズムを組むのだが、機械学習の場合はそのような気の利いたことをしてくれず、与えられた情報全てを利活用してしまう。つまり機械学習ならではの問題として、如何に応募者の情報を伝えるか が挙げられる。
これに対する答えは。i 番目の応募者について選考者が知るのは、0 から i番目までの応募者の中で、何番目に良いか?という情報を作ればよい。例を以下に挙げる。
const rank = [2, 5, 3, 0, 4, 1]; // 実際の応募者の順位
const info = [0, 1, 1, 0, 3, 1]; // 選考者に渡す情報
最初の応募者はもちろん今までで一番良いので 0 になる。次の応募者は最初の応募者よりも順位が低いので 1 であり、その次の応募者は、自分を含めた 3人の候補者の中で真ん中の順位なので 1 である、といった具合だ。これを汎用的につくる関数の実装を以下に記載する。
const f = arr => arr.map((val, i) => arr.slice(0, i).filter(x => x < val).length);
Q学習
ここでは Q学習とは何かは記載しない。そうでなく実装時のTipsを記載する。
状態
状態は二つの情報で一意に指定できる。すなわち、何番目の選考かと、現在の応募者の相対順位である。この相対順位は前章で記載したヤツ。状態空間は N * N あれば十分。ただ、終状態を用意しておくのがなにかと便利なので、それを N+1 番目の状態に当てる。とるべきアクションは、採用か次に期待かの二つなので、Q-table の大きさは (N+1) * N * 2 くらいになる。以下のように三次元配列で Q-table を定義した。
const Q = _.range(N+1).map(()=>_.range(N).map(()=>[0, 0]));
これは若干無駄である。例えば Q[0][3] は絶対に0である。なぜかを少し考えてみてほしい。そう考えると必要なQ-value の空間量はもう少し減って、正確には N * ( Math.floor(N/2) + 1) / 2 + 1 だが、まあ、これでよしとしよう。N = 100 でも十分小さいのであまり最適化を求めても仕方ない。
報酬
報酬は、採用アクションを取った時に、その応募者が最良であった場合にのみ 1 を与え、それ以外の全てのアクションの報酬は 0 にする。
学習率と時間割引
本ゲームは、必ず 最大でも N 回の選択で終わる。そのため時間割引であるところのガンマは 1 にしてよいし、するべきだ。逆に時間割引をする理由はない。学習率アルファはまあ適当に。
結果と考察
N = 10 での Q-table の値をみて、最適戦略と Q-value が整合的であることを確認しよう。
数学的な最適戦略は、最初の Math.floor(e/N) 人は採用を見送り、それ以降は、一番の人が来たら採用するという物である。これをQ-table の言葉でいいかえると、 こうなる
a= 0, 1, 2, 3 の場合
全てのbに対して Q[a][b][0] > Q[a][b][1]
a= 4, 5, 6, 7, 8, 9 の場合
b= 0 なら Q[a][b][0] < Q[a][b][1]
b≠ 0 なら Q[a][b][0] > Q[a][b][1]
これを見てみよう。a = 0 の場合はこうなった。つまり、初手は見送るべきという判断を示しておりOK。Q[0][0]以外のQ-valueが更新されていないのも予想通りである。
[ [ 0.41892273583101775, 0.006451281920209778 ],
[ 0, 0 ],[ 0, 0 ],[ 0, 0 ],[ 0, 0 ],
[ 0, 0 ],[ 0, 0 ],[ 0, 0 ],[ 0, 0 ],[ 0, 0 ] ],
a = 1 の場合はこうなった。非ゼロのところは二つあるがどちらの場合もアクション0、つまり見送りを支持しており、これも最適戦略と整合的。
[ [ 0.3263895090355375, 0.001280082051113942 ],
[ 0.6916155581085179, 0 ],
[ 0, 0 ], [ 0, 0 ], [ 0, 0 ], [ 0, 0 ], [ 0, 0 ], [ 0, 0 ],
[ 0, 0 ],[ 0, 0 ] ],
同様な傾向が a=2, 3も見られた。よし。
傾向が変わるのは a=4 だ。Q[4][0] をみればわかるように、応募者の相対順位が 0の場合は、その応募者を採用することを示すように第二成分の方が大きくなる。その一方、応募者の相対順位がノンゼロの場合は、次の応募者に期待をして非採用を支持するようになっている。
[ [ 0.11300452798824881, 0.19231754158629458 ],
[ 0.8324836922436362, 0 ],
[ 0.5369728782611017, 0 ],
[ 0.23100726269451421, 0 ],
[ 0.5001731027337304, 0 ],
[ 0, 0 ], [ 0, 0 ], [ 0, 0 ], [ 0, 0 ], [ 0, 0 ] ],
同様の傾向がa =5 以降も続いた。
最後に、このQ-tableに基づいた最適戦略で、10万回実施した際に、最良の応募者を選べた確率は 39.8 % であった。これは数学的な予想と一致した。
ソースコード
Node.js で確認。サードパーティライブラリの lodash をインポートして使っている。
const _ = require('lodash');
// conceal some information
function information(arr, idx) {
if (idx > arr.length - 1) { // this is case for creating dummy state
return 0;
}
return arr.map((val, i) => arr.slice(0, i).filter(x => x < val).length)[idx];
}
class Env {
constructor(n) {
this.n = n;
}
reset() {
this.seq = _.shuffle(_.range(this.n));
this.cur = 0;
return {
availableActions: [0, 1],
state: [0, 0],
};
}
step(action) {
this.cur += 1;
if (action === 1) { // select this applicant
return {
state: [this.seq.length, 0], // final state whose Qval is always 0
finished: true,
reward: (this.seq[this.cur - 1] === 0) ? 1 : 0,
};
}
if (action === 0) { // try next applicant
return {
state: [this.cur, information(this.seq, this.cur)],
finished: false,
nextAction: (this.cur === this.seq.length - 1) ? [1] : [0, 1],
reward: 0,
};
}
throw new Error('cannnot reach here');
}
}
// choose action with boltzman
function selectActionB(q, state, actions) {
const picker = (prob) => {
const accum = prob.map((p, i) => prob.slice(0, i + 1).reduce((a, x) => a + x));
return () => {
const r = Math.random();
return accum.filter(x => x < r).length;
};
};
const myQ = q[state[0]][state[1]];
const summed = _.sum(actions.map(a => Math.E ** myQ[a]));
const prob = actions.map(a => (Math.E ** myQ[a]) / summed);
return actions[picker(prob)()];
}
// choose action of max-Q value
function selectActionOpt(q, state, actions) {
const values = actions.map(a => q[state[0]][state[1]][a]);
const idx = values.indexOf(Math.max(...values)); // find index of max Qvalus
// console.log(values, actions, actions[idx])
return actions[idx];
}
// update Q-values
function updateQ(q, before, action, after, reward) {
const alpha = 0.8;
const gamma = 0.999;
const nextMax = _.max(q[after[0]][after[1]]);
let target = q[before[0]][before[1]][action];
target = (1 - alpha) * target + alpha * (reward + gamma * nextMax);
q[before[0]][before[1]][action] = target;
}
// main
const N = 10;
// Q-table. shape of (N+1) * N * 2, note: N+1 th is final state
const Q = _.range(N + 1).map(() => _.range(N).map(() => [0, 0]));
const env = new Env(N);
// training
const nTrain = 1000000;
_.range(nTrain).map(() => {
let { availableActions, state } = env.reset();
for (;;) { // loop for one game.
const action = selectActionB(Q, state, availableActions);
const info = env.step(action);
updateQ(Q, state, action, info.state, info.reward);
availableActions = info.nextAction;
state = info.state;
if (info.finished) break;
}
});
console.log(Q);
// do it!
const nSample = 100000;
let stat = 0;
_.range(nSample).map(() => {
let { availableActions, state } = env.reset();
for (;;) { // loop for one game.
const action = selectActionOpt(Q, state, availableActions);
const info = env.step(action);
availableActions = info.nextAction;
state = info.state;
if (info.finished) {
stat += info.reward;
break;
}
}
});
console.log(stat / nSample);