0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

コンピュータとオセロ対戦55 ~強化学習~

Last updated at Posted at 2022-07-12

今回の目標

強化学習する

ここから本編

なぜ急にこんな話になったのか,順を追って説明します.

  1. 今までの方法では強いAIを作るのは無理そう(参考)
  2. MyNet2でのGD以外の最適化関数実装が難航している
  3. 深層強化学習への興味がある(でも初めてなのでまずは深層じゃない方から)
  4. そもそもMyNet2で畳み込み層を導入したのは,畳み込み層を利用した深層強化学習がやりたいという思いがあったから

こういう理由から,唐突ではありますが方針転換となりました.

ディレクトリ構成

以下の構成で行います.

.
├── source/osero  // オセロパッケージ
└── Run.java  // 実行クラス

オセロクラス

若干の変更が入ってはいますが,こちらの記事で作成したものとほぼ同じなのでdetailsにします.
大きな変更点としてはコメントを追加したことと,こちらの記事で逆効果と判明した1重for文を廃止したことです.

プログラム
OseroBase.java

package source.osero;

/**
 * Basic osero class.
 */
public class OseroBase {
    /** Size of board. */
    public static final int SIZE = 8;
    /** log2(SIZE): {@code row << SHIFTNUM} is the bit index of the first square of that row. */
    protected static final int SHIFTNUM = 3;
    /**
     * BitBoard: bw[0] holds black stones, bw[1] holds white stones.
     * Square (row, col) is bit {@code (row << SHIFTNUM) + col}.
     */
    protected long[] bw = new long[2];
    /** Side to move: false is black (index 0), true is white (index 1). */
    protected boolean turn = false;

    /**
     * Constructor for child class.
     */
    protected OseroBase(){
        // Nothing to initialize: subclasses call setup() themselves.
    }

    /**
     * Find out whether the side to move has a legal move on this board.
     * @return Can put?
     */
    protected boolean checkAll(){
        return OseroBase.checkAll(this.bw, this.turn);
    }

    /**
     * Find out whether {@code turn} has a legal move on {@code board}.
     * @param board Now board.
     * @param turn Now turn.
     * @return Can put?
     */
    protected static boolean checkAll(long board[], boolean turn){
        for (int i = 0; i < OseroBase.SIZE; i++){
            for (int j = 0; j < OseroBase.SIZE; j++){
                if (OseroBase.check(i, j, board, turn)){
                    return true;
                }
            }
        }

        return false;
    }

    /**
     * Count standing bits of a number.
     * @param now Bits to count.
     * @return Number of standing bits (0 to 64).
     */
    protected int popCount(long now){
        return Long.bitCount(now);
    }

    /**
     * Print this board ("@" is black, "O" is white).
     */
    public void printBoard(){
        int num = 0;
        long place = 1;

        System.out.print("\n  ");
        for (int i = 0; i < 8; i++) System.out.printf(" %d ", i + 1);

        System.out.println("\n -------------------------");
        while (place != 0){
            if (num % 8 == 0) System.out.printf("%d", (num >> 3) + 1);
            if ((this.bw[0] & place) != 0) System.out.print("|@ ");
            else if ((this.bw[1] & place) != 0) System.out.print("|O ");
            else System.out.print("|  ");
            if (num % 8 == 7) System.out.println("|\n -------------------------");
            num++;
            place = place << 1;
        }
    }

    /**
     * Initialize this board to the starting position with black to move.
     */
    public void setup(){
        this.turn = false;
        this.bw[0] = 0x810000000L;
        this.bw[1] = 0x1008000000L;
    }

    /**
     * Print the stone count of both sides and the winner.
     */
    public void countLast(){
        int black = this.popCount(this.bw[0]);
        int white = this.popCount(this.bw[1]);

        System.out.printf("black: %d, white: %d\n", black, white);

        if (black > white){
            System.out.println("black win!");
        }else if (white > black){
            System.out.println("white win!");
        }else{
            System.out.println("draw!");
        }
    }

    /**
     * Whether (row, col) lies on the board.
     */
    private static boolean isOnBoard(int row, int col){
        return 0 <= row && row < OseroBase.SIZE && 0 <= col && col < OseroBase.SIZE;
    }

    /**
     * Single-bit mask of square (row, col). Only meaningful for on-board squares:
     * Java takes shift counts modulo 64, so off-board coordinates alias other squares.
     */
    private static long bitAt(int row, int col){
        return 1L << ((row << OseroBase.SHIFTNUM) + col);
    }

    /**
     * Opponent stones that a stone at (row, col) would flip in direction (dRow, dCol).
     * @return Mask of the flipped stones, or 0 when that direction flips nothing.
     */
    private static long flankMask(int row, int col, int dRow, int dCol, long mine, long theirs){
        long captured = 0;
        int r = row + dRow;
        int c = col + dCol;
        while (isOnBoard(r, c) && (theirs & bitAt(r, c)) != 0){
            captured |= bitAt(r, c);
            r += dRow;
            c += dCol;
        }
        // The run of opponent stones only counts when one of our stones closes it.
        return isOnBoard(r, c) && (mine & bitAt(r, c)) != 0 ? captured : 0L;
    }

    /**
     * Find out whether {@code turn} can legally put a stone at (row, col).
     * The move is legal when the square is on the board and empty, and at least one straight
     * line from it runs over one or more opponent stones and ends on one of the mover's stones.
     * @param row Number of row (0-based).
     * @param col Number of column (0-based).
     * @param board The board.
     * @param turn The turn.
     * @return Can put? (false for off-board coordinates)
     */
    protected static boolean check(int row, int col, long[] board, boolean turn){
        if (!isOnBoard(row, col)) return false;
        if (((board[0] | board[1]) & bitAt(row, col)) != 0) return false;

        int my = turn ? 1 : 0;
        int opp = 1 - my;
        for (int dRow = -1; dRow <= 1; dRow++){
            for (int dCol = -1; dCol <= 1; dCol++){
                if (dRow == 0 && dCol == 0) continue;
                if (flankMask(row, col, dRow, dCol, board[my], board[opp]) != 0){
                    return true;
                }
            }
        }

        return false;
    }

    /**
     * Put a stone of {@code turn} at (row, col) and flip every flanked opponent stone.
     * The board is modified in place; callers are expected to have checked legality with check().
     * @param row Number of row (0-based).
     * @param col Number of column (0-based).
     * @param board The board.
     * @param turn The turn.
     * @throws IllegalArgumentException if (row, col) is off the board.
     */
    protected static void put(int row, int col, long[] board, boolean turn){
        if (!isOnBoard(row, col)){
            throw new IllegalArgumentException("off-board position: " + row + ", " + col);
        }
        int my = turn ? 1 : 0;
        int opp = 1 - my;

        // Rays from (row, col) are disjoint, so all flips can be computed on the original board.
        long flips = 0;
        for (int dRow = -1; dRow <= 1; dRow++){
            for (int dCol = -1; dCol <= 1; dCol++){
                if (dRow == 0 && dCol == 0) continue;
                flips |= flankMask(row, col, dRow, dCol, board[my], board[opp]);
            }
        }

        board[my] |= bitAt(row, col) | flips;
        board[opp] &= ~flips;
    }
}
Osero.java
package source.osero;

import java.util.Random;
import java.util.function.BiConsumer;
import java.io.*;
import java.util.ArrayList;

/**
 * Class for play osero.
 */
public class Osero extends OseroBase {
    /** Random source for tie-breaking and the random player; seeded with 0 for reproducible runs. */
    protected static Random rand = new Random(0);

    /** Reader wrapping standard input. */
    private static InputStreamReader isr = new InputStreamReader(System.in);
    /** Buffered reader used by the human player. */
    protected static BufferedReader br = new BufferedReader(isr);

    /** Playing methods: index 0 moves for black (turn == false), index 1 for white. */
    protected ArrayList<BiConsumer<long[], Boolean>> playMethods
        = new ArrayList<BiConsumer<long[], Boolean>>();
    /** Search depth of black (index 0) and white (index 1); static, so shared by every instance. */
    protected static int[] readGoals = new int[2];
    /** Weight of each square for countCustom, indexed by (row << SHIFTNUM) + col. */
    protected static double[] customScore = {
         1.0, -0.6,  0.6,  0.4,  0.4,  0.6, -0.6,  1.0,
        -0.6, -0.8,  0.0,  0.0,  0.0,  0.0, -0.8, -0.6,
         0.6,  0.0,  0.8,  0.6,  0.6,  0.8,  0.0,  0.6,
         0.4,  0.0,  0.6,  0.0,  0.0,  0.6,  0.0,  0.4,
         0.4,  0.0,  0.6,  0.0,  0.0,  0.6,  0.0,  0.4,
         0.6,  0.0,  0.8,  0.6,  0.6,  0.8,  0.0,  0.6,
        -0.6, -0.8,  0.0,  0.0,  0.0,  0.0, -0.8, -0.6,
         1.0, -0.6,  0.6,  0.4,  0.4,  0.6, -0.6,  1.0
    };

    /** Pass to play(boolean) to print the board after every move. */
    public static final boolean PRINT = true;
    /** Pass to play(boolean) to play without printing the board. */
    public static final boolean NOPRINT = false;

    /**
     * Constructor for child class: only initializes the board.
     */
    protected Osero(){
        this.setup();
    }

    /**
     * Constructor for play.
     * @param playMethods Playing methods of black and white.
     */
    public Osero(ArrayList<BiConsumer<long[], Boolean>> playMethods){
        this.setup();
        this.playMethods = playMethods;
    }

    /**
     * Constructor for play.
     * @param black Playing method of black.
     * @param white Playing method of white.
     */
    public Osero(BiConsumer<long[], Boolean> black, BiConsumer<long[], Boolean> white){
        this.setup();
        this.playMethods.clear();
        this.playMethods.add(black);
        this.playMethods.add(white);
    }

    /**
     * Play one game and print the result.
     * @param printMode is print? (also prints the board after every move)
     */
    public void play(boolean printMode){
        this.runGame(printMode);
        this.countLast();
    }

    /**
     * Play one game without printing anything.
     */
    public void play(){
        this.runGame(Osero.NOPRINT);
    }

    /**
     * Play one game from the initial position.
     * The game ends when neither the side to move nor the previous side could move.
     * @param printMode Whether to print the board before the game and after every move.
     */
    private void runGame(boolean printMode){
        boolean can;
        boolean oldCan = true;
        this.setup();

        if (printMode) this.printBoard();

        while ((can = this.checkAll()) || oldCan){
            if (can){
                this.playMethods.get(this.turn ? 1 : 0).accept(this.bw, this.turn);
                if (printMode) this.printBoard();
            }

            this.turn = !this.turn;
            oldCan = can;
        }
    }

    /**
     * Playing method (human): read a 1-based row and column from standard input until a legal move is given.
     * @param board The board.
     * @param turn The turn.
     * @throws UncheckedIOException if standard input is closed before a legal move is read.
     */
    public static void human(long board[], boolean turn){
        int row = -1, col = -1;
        boolean chosen = false;

        while (!chosen){
            try{
                System.out.print("row: ");
                String rowS = br.readLine();
                System.out.print("col: ");
                String colS = br.readLine();
                if (rowS == null || colS == null){
                    // End of input: retrying would loop forever.
                    throw new UncheckedIOException(new EOFException("standard input was closed"));
                }

                row = Integer.parseInt(rowS.trim()) - 1;
                col = Integer.parseInt(colS.trim()) - 1;
            }catch (IOException | NumberFormatException e){
                // Never fall through to check() with a half-parsed (row, col).
                System.out.println("error. once choose.");
                continue;
            }
            chosen = 0 <= row && row < OseroBase.SIZE
                && 0 <= col && col < OseroBase.SIZE
                && OseroBase.check(row, col, board, turn);
        }

        OseroBase.put(row, col, board, turn);
    }

    /**
     * Playing method (random): put on a uniformly drawn square, redrawing until it is legal.
     * Must only be called when {@code turn} has a legal move, otherwise it never returns.
     * @param board The board.
     * @param turn The turn.
     */
    public static void random(long board[], boolean turn){
        int row, col;

        do {
            row = Osero.rand.nextInt(OseroBase.SIZE);
            col = Osero.rand.nextInt(OseroBase.SIZE);
        }while (!OseroBase.check(row, col, board, turn));

        OseroBase.put(row, col, board, turn);
    }

    /**
     * Playing method (n hand).
     * @param board The board.
     * @param turn The turn.
     */
    public static void nHand(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNHand);
    }

    /**
     * Playing method (n hand custom).
     * @param board The board.
     * @param turn The turn.
     */
    public static void nHandCustom(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNHandCustom);
    }

    /**
     * Try every legal move, score the resulting board with {@code func} and play the best one.
     * Ties are broken uniformly at random. Does nothing when {@code turn} has no legal move.
     * @param board The board (modified in place).
     * @param turn The turn.
     * @param func Scorer called as getScore(boardAfterMove, turn, !turn, 1).
     */
    protected static void exploreAssist(long[] board, boolean turn, FourFunction func){
        double maxScore = Double.NEGATIVE_INFINITY;
        // At most one entry per square, so ties can never overflow the buffers.
        int[] rowAns = new int[OseroBase.SIZE * OseroBase.SIZE];
        int[] colAns = new int[OseroBase.SIZE * OseroBase.SIZE];
        int tieCount = 0;
        long[] boardLeaf = new long[2];

        for (int row = 0; row < OseroBase.SIZE; row++){
            for (int col = 0; col < OseroBase.SIZE; col++){
                if (!OseroBase.check(row, col, board, turn)) continue;
                boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
                OseroBase.put(row, col, boardLeaf, turn);
                double score = func.getScore(
                    boardLeaf,
                    turn,
                    !turn,
                    1
                );
                if (score > maxScore){
                    maxScore = score;
                    tieCount = 0;
                }
                if (score == maxScore){
                    rowAns[tieCount] = row;
                    colAns[tieCount] = col;
                    tieCount++;
                }
            }
        }

        if (tieCount == 0) return;

        int pick = tieCount > 1 ? rand.nextInt(tieCount) : 0;
        OseroBase.put(rowAns[pick], colAns[pick], board, turn);
    }

    /**
     * Explore (n hand): average stone difference, seen from nowTurn, over all continuations.
     * @param board The board.
     * @param nowTurn The turn whose point of view is scored.
     * @param turn The now exploring turn.
     * @param num Number of now exploring.
     * @return score of the place.
     */
    protected static double exploreNHand(long[] board, boolean nowTurn, boolean turn, int num){
        if (num >= Osero.readGoals[nowTurn ? 1 : 0]) return Osero.count(board, nowTurn);

        // double, not int: "+=" on an int would silently truncate every child average.
        double score = 0;
        int placeNum = 0;
        long[] boardLeaf = new long[2];
        for (int row = 0; row < OseroBase.SIZE; row++){
            for (int col = 0; col < OseroBase.SIZE; col++){
                if (!OseroBase.check(row, col, board, turn)) continue;
                placeNum++;
                boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
                OseroBase.put(row, col, boardLeaf, turn);
                score += Osero.exploreNHand(
                    boardLeaf,
                    nowTurn,
                    !turn,
                    num + 1
                );
            }
        }

        // No legal move: evaluate as a leaf, still from nowTurn's point of view.
        return placeNum > 0 ? score / placeNum : Osero.count(board, nowTurn);
    }

    /**
     * Explore (n hand custom): like exploreNHand but scores leaves with countCustom.
     * @param board The board.
     * @param nowTurn The turn whose point of view is scored.
     * @param turn The now exploring turn.
     * @param num Number of now exploring.
     * @return score of the place.
     */
    protected static double exploreNHandCustom(long[] board, boolean nowTurn, boolean turn, int num){
        if (num >= Osero.readGoals[nowTurn ? 1 : 0]) return Osero.countCustom(board, nowTurn);

        double score = 0;
        int placeNum = 0;
        long[] boardLeaf = new long[2];
        for (int row = 0; row < OseroBase.SIZE; row++){
            for (int col = 0; col < OseroBase.SIZE; col++){
                if (!OseroBase.check(row, col, board, turn)) continue;
                placeNum++;
                boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
                OseroBase.put(row, col, boardLeaf, turn);
                score += Osero.exploreNHandCustom(
                    boardLeaf,
                    nowTurn,
                    !turn,
                    num + 1
                );
            }
        }

        // No legal move: evaluate as a leaf, still from nowTurn's point of view.
        return placeNum > 0 ? score / placeNum : Osero.countCustom(board, nowTurn);
    }

    /**
     * Count score: own stones minus opponent stones.
     * @param board The board (black and white bitboards are expected to be disjoint).
     * @param turn The turn.
     * @return score.
     */
    protected static int count(long board[], boolean turn){
        int my = turn ? 1 : 0;
        return Long.bitCount(board[my]) - Long.bitCount(board[1 - my]);
    }

    /**
     * Count score based on customScore: sum of own square weights minus opponent square weights.
     * @param board The board.
     * @param turn The turn.
     * @return score.
     */
    protected static double countCustom(long board[], boolean turn){
        int my = turn ? 1 : 0;
        int opp = 1 - my;

        double score = 0;
        for (int i = 0; i < OseroBase.SIZE * OseroBase.SIZE; i++){
            long place = 1L << i;
            if      ((board[my] & place) != 0)  score += Osero.customScore[i];
            else if ((board[opp] & place) != 0) score -= Osero.customScore[i];
        }

        return score;
    }

    /**
     * Playing method (n least): prefer moves that leave the opponent few moves.
     * @param board The board.
     * @param turn The turn.
     */
    public static void nLeast(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNLeast);
    }

    /**
     * Playing method (n most): prefer moves that keep many moves for ourselves.
     * @param board The board.
     * @param turn The turn.
     */
    public static void nMost(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNMost);
    }

    /**
     * Explore (n least).
     * @param board The board.
     * @param nowTurn The turn.
     * @param turn The now exploring turn.
     * @param num Number of now exploring.
     * @return score of the place (the negated mobility of {@code turn} at the leaves).
     */
    public static double exploreNLeast(long board[], boolean nowTurn, boolean turn, int num){
        if (num >= Osero.readGoals[nowTurn ? 1 : 0]){
            return -(double)Osero.countMoves(board, turn);
        }

        int nextNum = (nowTurn == turn) ? num + 1 : num;
        int placeNum = 0;
        double score = 0.;
        long[] boardLeaf = new long[2];
        for (int row = 0; row < OseroBase.SIZE; row++){
            for (int col = 0; col < OseroBase.SIZE; col++){
                if (!Osero.check(row, col, board, turn)) continue;
                boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
                // Actually play the move; without it every child explored the same board.
                Osero.put(row, col, boardLeaf, turn);
                placeNum++;
                score += Osero.exploreNLeast(
                    boardLeaf,
                    nowTurn,
                    !turn,
                    nextNum
                );
            }
        }

        return placeNum > 0 ? -score / placeNum : 0;
    }

    /**
     * Explore (n most).
     * @param board The board.
     * @param nowTurn The turn.
     * @param turn The now exploring turn.
     * @param num Number of now exploring.
     * @return score of the place (the mobility of nowTurn at the leaves).
     */
    public static double exploreNMost(long board[], boolean nowTurn, boolean turn, int num){
        if (nowTurn == turn && num >= Osero.readGoals[nowTurn ? 1 : 0]){
            return (double)Osero.countMoves(board, turn);
        }

        int nextNum = (nowTurn != turn) ? num + 1 : num;
        int placeNum = 0;
        double score = 0.;
        long[] boardLeaf = new long[2];
        for (int row = 0; row < OseroBase.SIZE; row++){
            for (int col = 0; col < OseroBase.SIZE; col++){
                if (!Osero.check(row, col, board, turn)) continue;
                boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
                Osero.put(row, col, boardLeaf, turn);
                placeNum++;
                // Recurse into exploreNMost itself (previously it called exploreNLeast).
                score += Osero.exploreNMost(
                    boardLeaf,
                    nowTurn,
                    !turn,
                    nextNum
                );
            }
        }

        return placeNum > 0 ? score / placeNum : 0;
    }

    /**
     * Number of legal moves of {@code turn} on {@code board}.
     */
    private static int countMoves(long[] board, boolean turn){
        int moves = 0;
        for (int row = 0; row < OseroBase.SIZE; row++){
            for (int col = 0; col < OseroBase.SIZE; col++){
                if (OseroBase.check(row, col, board, turn)){
                    moves++;
                }
            }
        }
        return moves;
    }

    /**
     * Set playing methods.
     * @param p Playing methods of black and white (exits the program if the size is not 2).
     */
    public void setPlayMethods(ArrayList<BiConsumer<long[], Boolean>> p){
        if (p.size() != 2){
            System.out.println("playMethods size is wrong.");
            System.exit(-1);
        }
        this.playMethods = p;
    }

    /**
     * Set playing methods.
     * @param black Playing method of black.
     * @param white Playing method of white.
     */
    public void setPlayMethods(BiConsumer<long[], Boolean> black, BiConsumer<long[], Boolean> white){
        // Use a fresh list: the current one may belong to a caller (constructor or
        // setPlayMethods(ArrayList)), and clearing it in place would change their list too.
        ArrayList<BiConsumer<long[], Boolean>> methods = new ArrayList<BiConsumer<long[], Boolean>>();
        methods.add(black);
        methods.add(white);
        this.playMethods = methods;
    }

    /**
     * Set read goals (shared by all instances).
     * @param r Read goals of black and white (exits the program if the length is not 2).
     */
    public void setReadGoal(int[] r){
        if (r.length != 2){
            System.out.println("readGoals size is wrong.");
            System.exit(-1);
        }
        Osero.readGoals = r;
    }

    /**
     * Set read goals (shared by all instances).
     * @param black Read goal of black.
     * @param white Read goal of white.
     */
    public void setReadGoal(int black, int white){
        Osero.readGoals[0] = black;
        Osero.readGoals[1] = white;
    }

    /**
     * Set custom score (shared by all instances).
     * @param c Custom score, one weight per square (exits the program if the length is not 64).
     */
    public void setCustomScore(double[] c){
        if (c.length != 64){
            System.out.println("customScore's length is wrong.");
            System.exit(-1);
        }
        Osero.customScore = c;
    }

    /**
     * Set Number of random seed.
     * @param seed Number of seed.
     */
    public void setRandom(long seed){
        Osero.rand = new Random(seed);
    }
}

強化学習クラス

Oseroクラスを継承するOseroQLearningクラスを作成し,ここに強化学習しながら対戦を行う思考方法を追加しようと思います.

属性やコンストラクタなど

OseroQLearning.java
package source.osero;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.function.BiConsumer;

/**
 * Class for Q Learning.
 */
public class OseroQLearning extends Osero {
    /** Value of eta (learning rate of the Q update). */
    protected static final double ETA = 0.01;
    /** Value of epsilon (probability of playing a random move instead of the best one). */
    protected static final double EPSILON = 0.05;
    /** Value of gamma (discount factor applied to the next Q value). */
    protected static final double GAMMA = 0.99;
    /** Q value returned for keys that have no entry yet. */
    protected static final double DEFAULT_Q = 0.5;
    /**
     * Table for Q values, keyed by {mover's stones, opponent's stones, row, col}.
     * Java arrays use identity equals/hashCode, so a plain HashMap never finds a freshly built
     * Long[] key. The constructors therefore install a ContentKeyedTable, which matches keys by
     * content. (A cleaner long-term fix is a List&lt;Long&gt; or record key type.)
     */
    protected static HashMap<Long[], Double> qTable;

    /**
     * Constructor without playing methods (set them with setPlayMethods).
     * Note: resets the shared (static) qTable.
     */
    public OseroQLearning() {
        this.setup();
        this.playMethods = new ArrayList<BiConsumer<long[], Boolean>>();
        OseroQLearning.resetQTable();
    }

    /**
     * Constructor for play. Note: resets the shared (static) qTable.
     * @param playMethods Playing methods of black and white.
     */
    public OseroQLearning(ArrayList<BiConsumer<long[], Boolean>> playMethods) {
        this.setup();
        this.playMethods = playMethods;
        OseroQLearning.resetQTable();
    }

    /**
     * Constructor for play. Note: resets the shared (static) qTable.
     * @param black Playing method of black.
     * @param white Playing method of white.
     */
    public OseroQLearning(BiConsumer<long[], Boolean> black, BiConsumer<long[], Boolean> white) {
        this.setup();
        this.playMethods.clear();
        this.playMethods.add(black);
        this.playMethods.add(white);
        OseroQLearning.resetQTable();
    }

    /**
     * Replace the shared Q table with an empty, content-keyed one.
     */
    private static void resetQTable() {
        OseroQLearning.qTable = new ContentKeyedTable();
    }

    /**
     * HashMap whose Long[] keys are matched by content instead of by array identity.
     *
     * <p>get, getOrDefault, containsKey, put, remove and clear are content-aware. The first array
     * stored for given contents is kept as a private copy and reused as the real map key, so
     * keySet/entrySet expose one array per distinct key. Removing through the keySet/entrySet
     * views, or other mutators (putAll, compute, merge, ...), bypasses this index and is not supported.
     */
    private static final class ContentKeyedTable extends HashMap<Long[], Double> {
        private static final long serialVersionUID = 1L;

        /** Contents of every stored key, mapped to the array instance used as the real key. */
        private final HashMap<List<Long>, Long[]> canonical = new HashMap<>();

        /**
         * Stored key array with the same contents as {@code key}, or null if there is none.
         */
        private Long[] storedKey(Object key) {
            if (!(key instanceof Long[])) {
                return null;
            }
            return this.canonical.get(Arrays.asList((Long[]) key));
        }

        @Override
        public Double get(Object key) {
            Long[] stored = this.storedKey(key);
            return stored == null ? null : super.get(stored);
        }

        @Override
        public Double getOrDefault(Object key, Double defaultValue) {
            Long[] stored = this.storedKey(key);
            return stored == null ? defaultValue : super.getOrDefault(stored, defaultValue);
        }

        @Override
        public boolean containsKey(Object key) {
            Long[] stored = this.storedKey(key);
            return stored != null && super.containsKey(stored);
        }

        @Override
        public Double put(Long[] key, Double value) {
            // copyOf snapshots the contents, so later changes to the caller's array cannot corrupt the index.
            Long[] stored = this.canonical.computeIfAbsent(
                List.copyOf(Arrays.asList(key)),
                unused -> key.clone()
            );
            return super.put(stored, value);
        }

        @Override
        public Double remove(Object key) {
            Long[] stored = this.storedKey(key);
            if (stored == null) {
                return null;
            }
            this.canonical.remove(Arrays.asList(stored));
            return super.remove(stored);
        }

        @Override
        public void clear() {
            this.canonical.clear();
            super.clear();
        }
    }

qTableはQ値を格納する表です.現在,盤面を長さ2のlong配列であらわしているのですが,それを状態とし,置く位置(行と列)を行動としています.
属性DEFAULT_Qですが,こちらはqTableでまだ定義されていないキーにアクセスしようとしたときに返される値です.

getResult

試合結果を取得するメソッド.

OseroQLearning.java
    /**
     * Get the stone difference of the current (normally finished) game.
     * @return Number of black stones minus number of white stones.
     */
    public int getResult(){
        int black = this.popCount(this.bw[0]);
        int white = this.popCount(this.bw[1]);
        return black - white;
    }

getQValue

Q値を取得するメソッド.

OseroQLearning.java
    /**
     * Get the Q value of a key given as a list.
     * @param pos Key as built by getValidPositions; its first four elements are used.
     * @return Stored Q value, or DEFAULT_Q when qTable has no entry for the key.
     */
    protected static double getQValue(ArrayList<Long> pos) {
        Long[] key = {pos.get(0), pos.get(1), pos.get(2), pos.get(3)};
        return OseroQLearning.getQValue(key);
    }

    /**
     * Get the Q value of a key given as an array.
     * Whether a different array with equal contents matches depends on the map assigned to
     * qTable: a plain HashMap compares Long[] keys by identity.
     * @param pos Key {mover's stones, opponent's stones, row, col}.
     * @return Stored Q value, or DEFAULT_Q when qTable has no entry for the key.
     */
    protected static double getQValue(Long[] pos) {
        Double stored = OseroQLearning.qTable.get(pos);
        if (stored == null) {
            return OseroQLearning.DEFAULT_Q;
        }
        return stored;
    }

outputQTable

Q値をファイル出力するメソッド.

OseroQLearning.java
    /**
     * Output this Q table to a csv file.
     * NOTE(review): despite the header, the first two columns hold the stones of the side that
     * moved and of its opponent (key[0], key[1]), not black and white.
     * @param fileName File name to output.
     * @return Success? false when the file cannot be opened or any write failed.
     */
    public boolean outputQTable(String fileName) {
        try (PrintWriter fp = new PrintWriter(fileName)) {
            fp.write("black,white,row,col,qValue\n");
            for (var entry : OseroQLearning.qTable.entrySet()) {
                Long[] key = entry.getKey();
                // Locale.ROOT keeps '.' as the decimal separator so the CSV stays parseable everywhere.
                fp.write(String.format(
                    Locale.ROOT,
                    "%d,%d,%d,%d,%f\n",
                    key[0],
                    key[1],
                    key[2],
                    key[3],
                    entry.getValue()
                ));
            }
            // PrintWriter never throws on write errors; it only records them.
            return !fp.checkError();
        } catch (IOException e) {
            return false;
        }
    }

getValidPositions

置くことができる位置を列挙するメソッド.
なお,qTableに与えられるキーの形で受け取ります.

OseroQLearning.java
    /**
     * List every legal move of {@code turn} as a qTable key.
     *
     * <p>Each key is {mover's stones, opponent's stones, row, col}, where the stones are taken
     * from the board AFTER that move is played. updateQValue stores its value under the board as
     * it is after the move (qLearning calls it after put), so the lookup keys built here must
     * describe the same afterstate. Keys built from the pre-move board never met the stored ones.
     *
     * @param board Current board (not modified).
     * @param turn Side to move.
     * @return One key per legal move, in row-major order.
     */
    protected static ArrayList<ArrayList<Long>> getValidPositions(long board[], boolean turn) {
        var rtn = new ArrayList<ArrayList<Long>>();
        int my = turn ? 1 : 0;
        int opp = 1 - my;

        for (int i = 0; i < Osero.SIZE; i++) {
            for (int j = 0; j < Osero.SIZE; j++) {
                if (!Osero.check(i, j, board, turn)) {
                    continue;
                }
                long[] after = board.clone();
                Osero.put(i, j, after, turn);

                var rtnElement = new ArrayList<Long>(4);
                rtnElement.add(after[my]);
                rtnElement.add(after[opp]);
                rtnElement.add((long) i);
                rtnElement.add((long) j);

                rtn.add(rtnElement);
            }
        }

        return rtn;
    }

exploreMaxQValue

Q値が最大となる場所の行と列を返すメソッドです.

OseroQLearning.java
    /**
     * Explore the legal move having the highest Q value.
     * @param board Now board (not modified).
     * @param turn Now turn.
     * @return {row, col} of a best move (ties broken uniformly at random), or null when
     *         {@code turn} has no legal move.
     */
    protected static int[] exploreMaxQValue(long[] board, boolean turn){
        var rowAns = new ArrayList<Integer>();
        var colAns = new ArrayList<Integer>();
        double maxQValue = Double.NEGATIVE_INFINITY;

        for (var vp : OseroQLearning.getValidPositions(board, turn)) {
            double qValue = OseroQLearning.getQValue(vp);
            if (qValue > maxQValue) {
                maxQValue = qValue;
                rowAns.clear();
                colAns.clear();
            }
            if (qValue == maxQValue) {
                rowAns.add((int) (long) vp.get(2));
                colAns.add((int) (long) vp.get(3));
            }
        }

        if (rowAns.isEmpty()) {
            return null;
        }

        // Draw among ALL tied moves (the old "> 1" test never randomized between exactly two).
        int place = rowAns.size() > 1 ? rand.nextInt(rowAns.size()) : 0;
        return new int[]{rowAns.get(place), colAns.get(place)};
    }

updateQValue

Q値を更新するメソッド.
もしゲームが終了していたら以下の式でQ値を更新します.

\begin{align}
Q(s,a)_{new}&=Q(s,a)_{old}+\eta(reward-Q(s,a)_{old})
\end{align}

まだゴールしていない場合,通常のQ学習であれば以下の式でQ値を更新します.

\begin{align}
Q(s,a)_{new}&=Q(s,a)_{old}+\eta\left(reward+\gamma\max_{a'}Q(s_{next},a')_{old}-Q(s,a)_{old}\right)
\end{align}

しかしオセロの場合,自分が置いたその次は相手のターンです.相手のターンのQ値で自分のQ値を更新しても仕方がないのではないかと考えました.
そこで,ゲームがまだ終了していなかったときは相手が打てる手を全て調べて実際に打ってもらい,その次の状態,つまり次の自分のターンでもっとも高くなるQ値を集め,その平均を$\max_{a'}Q(s_{next},a')$として使用することにしました.
ここで調べる相手のターンがなかった場合はQ値の更新はなし,次の自分のターンで打てる手がなかった場合はその手は考慮しないということにしました.

OseroQLearning.java
    /**
     * Update the Q value of the move just played.
     *
     * <p>The key is {mover's stones, opponent's stones, row, col}, taken from {@code board}
     * as it is now, i.e. after the move at (row, col) was put.
     * <ul>
     * <li>If neither side can move, the game is over. The value moves toward the final reward
     *     for {@code turn}: +1 win, -1 loss, 0 draw.</li>
     * <li>Otherwise every opponent reply is tried and our best-Q answer to each reply is played.
     *     The average of those answers' Q values, discounted by GAMMA, is the target (no
     *     intermediate reward).</li>
     * <li>Replies after which we have no legal move are left out of the average. If the
     *     opponent has no reply, or no reply can be considered, the value is not updated.</li>
     * </ul>
     *
     * @param board Board after the move at (row, col) was played (not modified).
     * @param turn Side that played the move.
     * @param row The number of row.
     * @param col The number of column.
     */
    protected static void updateQValue(long[] board, boolean turn, int row, int col) {
        int my = turn ? 1 : 0;
        int opp = 1 - my;
        Long[] key = {board[my], board[opp], (long) row, (long) col};
        double oldQ = OseroQLearning.getQValue(key);
        double newQ;

        if (!Osero.checkAll(board, turn) && !Osero.checkAll(board, !turn)) {
            // Game over: reward is the sign of the final stone difference.
            double reward = Integer.signum(Osero.count(board, turn));
            newQ = oldQ + OseroQLearning.ETA * (reward - oldQ);
        } else {
            var replies = OseroQLearning.getValidPositions(board, !turn);
            if (replies.isEmpty()) {
                return;
            }

            double qValueSum = 0;
            int considered = 0;
            for (var reply : replies) {
                long[] boardLeaf = board.clone();
                Osero.put(
                    (int) (long) reply.get(2),
                    (int) (long) reply.get(3),
                    boardLeaf,
                    !turn
                );

                int[] rowCol = OseroQLearning.exploreMaxQValue(boardLeaf, turn);
                if (rowCol == null) {
                    // We cannot answer this reply: leave it out of the average.
                    continue;
                }
                Osero.put(rowCol[0], rowCol[1], boardLeaf, turn);
                qValueSum += OseroQLearning.getQValue(
                    new Long[]{
                        boardLeaf[my],
                        boardLeaf[opp],
                        (long) rowCol[0],
                        (long) rowCol[1]
                    }
                );
                considered++;
            }

            if (considered == 0) {
                return;
            }
            // Average over the replies actually considered, not over all replies.
            double maxQValueNext = qValueSum / considered;
            newQ = oldQ + OseroQLearning.ETA * (
                OseroQLearning.GAMMA * maxQValueNext - oldQ
            );
        }

        OseroQLearning.qTable.put(key, newQ);
    }

qLearning

Q値に従って置き,同時に学習も進めるメソッド.

OseroQLearning.java
    /**
     * Playing method (Q learning): play one move and learn from it.
     *
     * <p>With probability EPSILON a random legal move is played and no value is updated.
     * Otherwise the move with the highest Q value is played and its Q value is updated.
     * Does nothing when {@code turn} has no legal move.
     *
     * @param board The board (modified in place).
     * @param turn The turn.
     */
    public static void qLearning(long board[], boolean turn) {
        // Without a legal move Osero.random would never return and exploreMaxQValue returns null.
        if (!Osero.checkAll(board, turn)) {
            return;
        }

        // epsilon: explore with a random move
        if (Osero.rand.nextDouble() < OseroQLearning.EPSILON) {
            Osero.random(board, turn);
            return;
        }

        // think part
        int[] rowCol = OseroQLearning.exploreMaxQValue(board, turn);

        Osero.put(rowCol[0], rowCol[1], board, turn);

        // update quantity value (board now holds the position after the move)
        OseroQLearning.updateQValue(board, turn, rowCol[0], rowCol[1]);
    }

実行クラス

上で作成したQ学習に,先手後手を交代しながらrandom,nHand,nHandCustom,nLeast,nMostと1000回戦ってもらいます.
全ての対戦が終わったら,その時のQ値をファイル出力してもらいます.

Run.java
import java.util.function.BiConsumer;

import java.io.PrintWriter;
import java.io.IOException;
import java.util.ArrayList;
import source.osero.Osero;
import source.osero.OseroQLearning;

public class Run {
    /**
     * Train the Q-learning player against every fixed opponent, playing both colors.
     * Writes one CSV row per round (win balance and stone balance from the learner's point of
     * view), then dumps the Q table.
     * @param str Unused.
     */
    public static void main(String[] str) {
        final int PLAYNUM = 1000;
        final String FILENAME = "result.csv";
        final String QTABLE_FILENAME = "qTable.csv";
        BiConsumer<long[], Boolean> qLearning = OseroQLearning::qLearning;
        ArrayList<BiConsumer<long[], Boolean>> opps = new ArrayList<>();
        opps.add(Osero::random);
        opps.add(Osero::nHand);
        opps.add(Osero::nHandCustom);
        opps.add(Osero::nLeast);
        opps.add(Osero::nMost);

        var run = new OseroQLearning();
        run.setReadGoal(1, 1);

        try (PrintWriter fp = new PrintWriter(FILENAME)) {
            fp.write("number,win,stone\n");

            for (int i = 0; i < PLAYNUM; i++) {
                System.out.printf("\r%d/%d", i + 1, PLAYNUM);
                // win: +1 per win, -1 per loss, 0 per draw (for the learner, whichever color it plays)
                int winNum = 0;
                int stoneNum = 0;

                for (var opp : opps) {
                    // Learner plays black: getResult() (black - white) is already from its side.
                    run.setPlayMethods(qLearning, opp);
                    run.play();
                    int result = run.getResult();
                    winNum += Integer.signum(result);
                    stoneNum += result;

                    // Learner plays white: negate the black-minus-white result.
                    run.setPlayMethods(opp, qLearning);
                    run.play();
                    result = run.getResult();
                    winNum -= Integer.signum(result);
                    stoneNum -= result;
                }

                fp.write(String.format(
                    "%d,%d,%d\n", i, winNum, stoneNum
                ));
            }

            System.out.println();
            // PrintWriter records write errors instead of throwing them.
            if (fp.checkError()) {
                System.out.println("failed to write " + FILENAME);
            }
            if (!run.outputQTable(QTABLE_FILENAME)) {
                System.out.println("failed to write " + QTABLE_FILENAME);
            }
        } catch (IOException e) {
            System.out.println("\nIO Exception");
        }
    }
}

実行結果

いつも通りipynbでグラフを作ります.

import os

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set()

# Figures are saved under pic/; create it so savefig does not fail on a fresh checkout.
os.makedirs("pic", exist_ok=True)

################################

# result.csv is written by Run.java with columns: number, win, stone
df = pd.read_csv("result.csv")
df.head()

################################

plt.plot(df["win"])
plt.title("win num")
plt.xlabel("generation")
plt.ylabel("win num")
plt.savefig("pic/win_num.png")
plt.close()

################################

plt.plot(df["stone"])
plt.title("stone num")
plt.xlabel("generation")
plt.ylabel("stone num")
plt.savefig("pic/stone_num.png")
plt.close()

結果はこちら.

win_num.png
stone_num.png

勝率の改善は見られません.

考察

出力されたQ値を見ると,ファイルサイズはかなり大きいですがほとんど同じ値が入っています.
オセロで取りうる盤面は非常に多彩で,かつそこからとれる手数も多いです.つまり,getQValueで取ってきた値はほとんど初期値だったのではないでしょうか.
今回行った試合数は1000x5x2の一万試合ですが,それでも同じ盤面同じ行動にはめったに出会わないようです.
表を使うことによるオセロの強化学習はかなり無理があることが分かりました.

フルバージョン

次回は

表での強化学習はあまり意味がないと分かったので,深層強化学習に挑戦します.

次回

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?