LoginSignup
0
0

More than 1 year has passed since last update.

コンピュータとオセロ対戦47 ~Javaで勝敗予測~

Last updated at Posted at 2022-03-16

前回

今回の目標

自作した深層学習ライブラリを用いて、Javaでデータ集めと学習を行う。

使用バージョン

  • Java 17.0.1
  • MyNet 1.2.0

ここから本編

行なう学習の概要について、時間も空いてしまったので再度説明します。

私が行っていたのは、「盤面から最終結果を予測する」というものです。
その前は、マスの位置ごとに設定した評価値を用いた方法を考えていましたが、固定的な考え方では柔軟な対応が難しいのではないかと考え上記の方法を思いつきました。
最終結果というのは具体的に、「最終的な自分の駒の数-最終的な相手の駒の数」で、マイナスもあり得る数でした。しかし45 ~最適化関数~でReLu関数の有効性が示されたので、ここでは最終結果を「最終的な自分の駒の数」とすることにします。これで必ず正の数となりますし、マスの数は64固定ですので、一定の基準で勝ち具合を推し量ることができます。マスがすべて埋まらないままゲームが終わることもありますが、めったにないのでここでは無視することにします。

探索と予測を用いて、最終結果が良い方へ石を置いていき、「相手の取れる手数を減らしていく」アルゴリズムを搭載した「nleast」の打倒を目標にします。
また、nleastとこのAIを組み合わせることも以前は考えていました。
nleastについては22 ~遺伝的アルゴリズム、改善~#nleastで、組み合わせAIについては38 ~係数、局所探索~で詳しく解説しています。nleastの有用性は34 ~性能評価~で示されました。

Javaでオセロ

まず、JavaでBitBoardを使用したオセロができるようにします。一度26 ~JavaでBitBoard~で似たようなことをしていますが、拙いプログラムなので作り直します。
OseroBaseクラスで、オセロに必要な基本的なメソッドを実装し、Oseroクラスで各思考方法、そしてplayメソッドを実装することにしました。

OseroBase.java

これまでのオセロプログラムと比べて画期的な点はあまりないですが、最後にこういったオセロクラスを作成してから日もたっていますので再度解説します。

フィールドとコンストラクタ

OseroBase.java
public class OseroBase {
    public static final int SIZE = 8;      // オセロ盤面サイズ
    public static final int SHIFTNUM = 3;  // 8倍するときに使う(3ビットシフト)
    public long bw[] = new long[2];        // 盤面
    public boolean turn = false;           // ターン

    public OseroBase(){
        ;
    }

printBoard

盤面表示メソッド。
黒い石を「@」で、白い石を「O」で表示します。

OseroBase.java
    public void printBoard(){
        int num = 0;
        long place = 1;

        System.out.print("\n  ");
        for (int i = 0; i < 8; i++) System.out.printf(" %d ", i + 1);

        System.out.println("\n -------------------------");
        while (place != 0){
            if (num % 8 == 0) System.out.printf("%d", (num >> 3) + 1);
            if ((this.bw[0] & place) != 0) System.out.print("|@ ");
            else if ((this.bw[1] & place) != 0) System.out.print("|O ");
            else System.out.print("|  ");
            if (num % 8 == 7) System.out.println("|\n -------------------------");
            num++;
            place = place << 1;
        }
    }

popCount

立っているビット数を数えるメソッド。「population count」と呼ばれるアルゴリズムを使用しています。

OseroBase.java
    protected int popCount(long now){
        now = now - ((now >> 1) & 0x5555555555555555L);
        now = (now & 0x3333333333333333L) + ((now >> 2) & 0x3333333333333333L);
        now = (now + (now >> 4)) & 0x0f0f0f0f0f0f0f0fL;
        now = now + (now >> 8);
        now = now + (now >> 16);
        now = now + (now >> 32);

        return (int)now & 0x7f;
    }

setup

盤面を、最初の状態にするメソッド。

OseroBase.java
    public void setup(){
        this.turn = false;
        this.bw[0] = 0x810000000L;
        this.bw[1] = 0x1008000000L;
    }

checkAll

盤面に、置ける場所が一か所でもあるかどうか確認するメソッド。
一重ループに変更。

OseroBase.java
    protected boolean checkAll(){
        int i;
        int row = 0, col = 0;

        for (i = 0; i < OseroBase.SIZE << OseroBase.SHIFTNUM; i++){
            if (OseroBase.check(row, col, this.bw, this.turn)) return true;
            row++;
            if (row >= OseroBase.SIZE){
                row = 0;
                col++;
            }
        }

        return false;
    }

countLast

試合の最後に勝敗を判定するメソッド。

OseroBase.java
    public void countLast(){
        int black = this.popCount(this.bw[0]);
        int white = this.popCount(this.bw[1]);

        System.out.printf("black: %d, white: %d\n", black, white);

        if (black > white){
            System.out.println("black win!");
        }else if (white > black){
            System.out.println("white win!");
        }else{
            System.out.println("draw!");
        }
    }

check

指定した位置に置けるかどうかを判定するメソッド。一重ループに変更。
動作は以下の通り。

  1. 指定した場所に、すでにどちらかの石が置かれていたらfalseを返す。
  2. 指定した位置から、8方向について以下のことを考える。
  3. 指定した位置のとなりに相手の石があり、行・列ともに盤面の外に出ない限り、さらに隣を調べていく。
  4. もしその先で自分の石にたどり着いたら、ひっくり返すことができるのでtrueを返す。
  5. もしその先に何もなかった場合、他の方向を調べる。
  6. 8方向全てでひっくり返せなかった場合、falseを返す。
OseroBase.java
    protected static boolean check(int row, int col, long[] board, boolean turn){
        long place = 1L << (row << OseroBase.SHIFTNUM) + col;
        if ((board[0] & place) != 0) return false;
        if ((board[1] & place) != 0) return false;
        
        int my, opp;
        if (turn){
            my = 1; opp = 0;
        }else{
            my = 0; opp = 1;
        }

        int x = -1, y = -1;
        int focusRow, focusCol;
        for (int i = 0; i < 8; i++){
            focusRow = row + x; focusCol = col + y;
            place = 1L << (focusRow << OseroBase.SHIFTNUM) + focusCol;

            while ((board[opp] & place) != 0
                   && 0 <= x + focusRow && x + focusRow < OseroBase.SIZE
                   && 0 <= y + focusCol && y + focusCol < OseroBase.SIZE){
                focusRow += x;
                focusCol += y;
                place = 1L << (focusRow << OseroBase.SHIFTNUM) + focusCol;
                if ((board[my] & place) != 0){
                    return true;
                }
            }

            y++;
            if (y > 1){
                x++; y = -1;
            }else if (x == 0 && y == 0){
                y++;
            }
        }

        return false;
    }

put

指定した位置に置くメソッド。一重ループに変更。
このメソッドは、checkメソッドを利用し、指定した位置に置けることが確定している場合のみ呼び出す。
アルゴリズムは以下の通り。

  1. 指定した位置から、8方向について以下のことを考える。
  2. 指定した位置の隣に相手の石があり、行・列ともに盤面の外に出ない限り、さらに隣を調べていく。この時、調べた位置をinver変数に記録しておく。
  3. もしその先で自分の石にたどり着いたら、ひっくり返すことができるのでひっくり返す。
  4. もしその先に何もなかった場合、他の方向を調べる。
  5. 8方向全てについて行う。
OseroBase.java
    protected static void put(int row, int col, long[] board, boolean turn){
        int my, opp;
        if (turn){
            my = 1; opp = 0;
        }else{
            my = 0; opp = 1;
        }

        int x = -1, y = -1;
        int focusRow, focusCol;
        long inver, place;
        board[my] += 1L << (row << OseroBase.SHIFTNUM) + col;
        for (int i = 0; i < 8; i++){
            inver = 0;
            focusRow = row + x; focusCol = col + y;
            place = 1L << (focusRow << OseroBase.SHIFTNUM) + focusCol;

            while ((board[opp] & place) != 0
                   && 0 <= x + focusRow && x + focusRow < OseroBase.SIZE
                   && 0 <= y + focusCol && y + focusCol < OseroBase.SIZE){
                inver += place;
                focusRow += x;
                focusCol += y;
                place = 1L << (focusRow << OseroBase.SHIFTNUM) + focusCol;
                if ((board[my] & place) != 0){
                    board[my] += inver;
                    board[opp] -= inver;
                }
            }

            y++;
            if (y > 1){
                x++; y = -1;
            }else if (x == 0 && y == 0){
                y++;
            }
        }
    }

FourFunction.java

Osero.javaで使用。
これまでは、各思考方法で似たような探索を別のプログラムで記述していましたが、今回は共通部分である探索を一つのメソッドで行い、思考ごとに違うスコア計算部分を別で作ることにしました。
そこでメソッドをクラスに入れることになりました。
こちらはそのためのクラスです。
Javaにも関数ポインタ欲しいです。

FourFucntion.java
@FunctionalInterface
public interface FourFunction {
    public abstract double getScore(long board[], boolean nowTurn, boolean turn, int num);
}

Osero.java

OseroBaseクラスに思考方法を追加し、オセロができるようにしたクラス。なぜ分けたかというと、思考方法を別にまとめたほうが分かりやすくなると思ったからです。

フィールドとコンストラクタ

Osero.java
import java.util.Random;
import java.util.function.BiConsumer;
import java.io.*;
import java.util.ArrayList;
import java.util.Set;

public class Osero extends OseroBase {
    protected static Random rand = new Random(0);

    private static InputStreamReader isr = new InputStreamReader(System.in);
    protected static BufferedReader br = new BufferedReader(isr);

    protected ArrayList<BiConsumer<long[], Boolean>> playMethod = new ArrayList<BiConsumer<long[], Boolean>>();   // 黒と白の思考方法
    protected static int[] readGoal = new int[2];  // 何手先まで読むか
    protected static double[] customScore = {
         1.0, -0.6,  0.6,  0.4,  0.4,  0.6, -0.6,  1.0,
        -0.6, -0.8,  0.0,  0.0,  0.0,  0.0, -0.8, -0.6,
         0.6,  0.0,  0.8,  0.6,  0.6,  0.8,  0.0,  0.6,
         0.4,  0.0,  0.6,  0.0,  0.0,  0.6,  0.0,  0.4,
         0.4,  0.0,  0.6,  0.0,  0.0,  0.6,  0.0,  0.4,
         0.6,  0.0,  0.8,  0.6,  0.6,  0.8,  0.0,  0.6,
        -0.6, -0.8,  0.0,  0.0,  0.0,  0.0, -0.8, -0.6,
         1.0, -0.6,  0.6,  0.4,  0.4,  0.6, -0.6,  1.0
    };   // 評価値

    public static final boolean PRINT = true;
    public static final boolean NOPRINT = false;

    public Osero(){
        this.setup();
    }

    public Osero(ArrayList<BiConsumer<long[], Boolean>> playMethod){
        this.setup();
        this.playMethod = playMethod;
    }

    public Osero(BiConsumer<long[], Boolean> black, BiConsumer<long[], Boolean> white){
        this.setup();
        this.playMethod.add(black);
        this.playMethod.add(white);
    }

フィールド変更メソッド

Osero.java
    public void setPlayMethod(ArrayList<BiConsumer<long[], Boolean>> p){
        if (p.size() != 2){
            System.out.println("playMethod size is wrong.");
            System.exit(-1);
        }
        this.playMethod = p;
    }

    public void setPlayMethod(BiConsumer<long[], Boolean> black, BiConsumer<long[], Boolean> white){
        this.playMethod.clear();
        this.playMethod.add(black);
        this.playMethod.add(white);
    }

    public void setReadGoal(int[] r){
        if (r.length != 2){
            System.out.println("readGoal size is wrong.");
            System.exit(-1);
        }
        this.readGoal = r;
    }

    public void setReadGoal(int black, int white){
        this.readGoal[0] = black;
        this.readGoal[1] = white;
    }

    public void setCustomScore(double[] c){
        if (c.length != 64){
            System.out.println("customScore's length is wrong.");
            System.exit(-1);
        }
        this.customScore = c;
    }

    public void setRandom(long seed){
        this.rand = new Random(seed);
    }

exploreAssist

探索のひな型となるメソッド。スコア獲得メソッドを引数として迎え入れ、使用しています。
もちろん、checkAllメソッドにより、最低一か所は置ける場所があることが確認されたうえでしか呼ばれないメソッドです。
以下、流れを説明します。

  1. すべての場所について、その位置に置けるか調べる。
  2. もしその位置に置けるなら、仮においてみて、スコアを得る。
  3. それが最大スコアとなるなら、その場所を記憶しておく。
  4. もしそれが最大スコアと同じになるなら、候補として記憶しておく。
  5. 最大スコアをとる場所が複数になる場合、その中からランダムで選ぶ。
  6. 置く。
Osero.java
    protected static void exploreAssist(long[] board, boolean turn, FourFunction func){
        double maxScore = -100.0;
        double score;
        int[] rowAns = new int[OseroBase.SIZE << 1];
        int[] colAns = new int[OseroBase.SIZE << 1];
        int placeNum = 0;
        long[] boardLeaf = new long[2];

        int row = -1, col = 0;
        for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
            row++;
            if (row >= OseroBase.SIZE){
                row = 0;
                col++;
            }
            if (!OseroBase.check(row, col, board, turn)) continue;System.out.printf("%d, %d\n", row+1, col+1);
            boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
            OseroBase.put(row, col, boardLeaf, turn);
            score = func.getScore(
                boardLeaf,
                turn,
                !turn,
                1
            );
            if (score > maxScore){
                maxScore = score;
                placeNum = 0;
                rowAns[0] = row;
                colAns[0] = col;
            }else if (score == maxScore){
                placeNum++;
                rowAns[placeNum] = row;
                colAns[placeNum] = col;
            }
        }

        if (placeNum > 1){
            int place = rand.nextInt(placeNum+1);
            rowAns[0] = rowAns[place];
            colAns[0] = colAns[place];
        }

        OseroBase.put(rowAns[0], colAns[0], board, turn);
    }

human

人が置くメソッド。

Osero.java
    public static void human(long board[], boolean turn){
        int row, col;
        String rowS, colS;

        do {
            try{
                System.out.print("row: ");
                rowS = br.readLine();
                System.out.print("col: ");
                colS = br.readLine();

                row = Integer.parseInt(rowS);
                col = Integer.parseInt(colS);
                row--; col--;
            }catch (Exception e){
                System.out.println("error. once choose.");
                continue;
            }
        }while (!OseroBase.check(row, col, board, turn));

        OseroBase.put(row, col, board, turn);
    }

random

ランダムに置くメソッド。

Osero.java
    public static void random(long board[], boolean turn){
        int row, col;

        do {
            row = Osero.rand.nextInt(OseroBase.SIZE);
            col = Osero.rand.nextInt(OseroBase.SIZE);
        }while (!OseroBase.check(row, col, board, turn));

        OseroBase.put(row, col, board, turn);
    }

nHand, nHandCustom

nHandは、nターン先での自分の石の数がより多くなる位置に置く。
nHandCustomは、nターン先での自分の評価値がより高くなる位置に置く。

Osero.java
    public static void nHand(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNHand);
    }

    public static void nHandCustom(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNHandCustom);
    }

exploreNHand, exploreNHandCustom

探索を行うメソッド。過去作で言うところのboard_addにあたる。
これらのメソッドは、exploreAssistによって石が置かれた盤面を受け取ります。
そして、探索ゴールまでターン数が進んでいればそのままその盤面のスコアを返します。
もしまだなら、再び探索を行います。
また、二手以上先は置ける場所がない可能性もあるので、placeNum変数がメソッドの終わりまで0だったとき(つまり置ける場所がなかった時)はその盤面のスコアをそのまま返します。
この二つのメソッドも似通っているので、exploreAssistのようにまとめようかとも思いましたが複雑になって余計分かりにくくなりそうなので止めました。

Osero.java
    protected static double exploreNHand(long[] board, boolean nowTurn, boolean turn, int num){
        if (num >= Osero.readGoal[(nowTurn ? 1:0)]) return Osero.count(board, nowTurn);

        int score = 0, placeNum = 0;
        int row = -1, col = 0;
        long[] boardLeaf = new long[2];
        //                          64 = 8 << 3
        for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
            row++;
            if (row >= OseroBase.SIZE){
                row = 0;
                col++;
            }
            if (!OseroBase.check(row, col, board, turn)) continue;
            placeNum += 1;
            boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
            OseroBase.put(row, col, boardLeaf, turn);
            score += Osero.exploreNHand(
                boardLeaf,
                nowTurn,
                !turn,
                num + 1    
            );
        }

        if (placeNum > 0) return (double)score / placeNum;
        else              return (double)Osero.count(board, turn);
    }

    protected static double exploreNHandCustom(long[] board, boolean nowTurn, boolean turn, int num){
        if (num >= Osero.readGoal[(nowTurn ? 1:0)]) return Osero.countCustom(board, nowTurn);

        double score = 0, placeNum = 0;
        int row = -1, col = 0;
        long[] boardLeaf = new long[2];
        for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
            row++;
            if (row >= OseroBase.SIZE){
                row = 0;
                col++;
            }
            if (!OseroBase.check(row, col, board, turn)) continue;
            placeNum += 1;
            boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
            OseroBase.put(row, col, boardLeaf, turn);
            score += Osero.exploreNHandCustom(
                boardLeaf,
                nowTurn,
                !turn,
                num + 1    
            );
        }

        if (placeNum > 0) return score / placeNum;
        else              return Osero.countCustom(board, turn);
    }

count, countCustom

石の数や評価値の値を計算し返す。

Osero.java
    protected static int count(long board[], boolean turn){
        int my, opp;
        if (turn){
            my = 1; opp = 0;
        }else{
            my = 0; opp = 1;
        }

        long place = 1;
        int score;
        while (place != 0){
            if      ((board[my] & place) != 0)  score++;
            else if ((board[opp] & place) != 0) score--;
            place = place << 1;
        }

        return score;
    }

    protected static double countCustom(long board[], boolean turn){
        int my, opp;
        if (turn){
            my = 1; opp = 0;
        }else{
            my = 0; opp = 1;
        }

        long place = 1;
        double score = 0;
        int i = 0;
        while (place != 0){
            if      ((board[my] & place) != 0)  score += Osero.customScore[i];
            else if ((board[opp] & place) != 0) score -= Osero.customScore[i];
            place = place << 1;
            i++;
        }

        return score;
    }

nLeast, nMost

nLeastは、n回先の相手のターンで相手のとれる手数が最も少ない位置に置く。
nMostは、n回先の自分のターンで自分のとれる手数が最も多い位置に置く。

Osero.java
    public static void nLeast(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNLeast);
    }

    public static void nMost(long board[], boolean turn){
        Osero.exploreAssist(board, turn, Osero::exploreNMost);
    }

exploreNLeast, exploreNMost

取れる手数を返すメソッド。過去作で言うところのcheck_placeにあたる。
これまでは一つのメソッドにまとめていましたが、かなり分かりづらいプログラムになっていたので分けました。それでもまだかなり分かりづらいですが・・・。
一応、一つ一つのメソッド自体はこれまでより短くすることができました。

まず、exploreNLeastから見ていきます。
nLeastはn回先の相手のターンで取れる手数が最小になることを目指す思考方法です。
このメソッドは、exploreAssistによって石が置かれた盤面を受け取ります。
つまり、現在のターン目線で言うと、自分が置いたそのあとの相手のターンです。
このとき、探索ゴールまで進んでいれば置ける数を数えて返します。この時マイナスにしているのは、exploreAssistメソッドではスコアが最大値になるところを目指すからです。相手のとれる手数は少ないほうがいいので、マイナスしておきます。
探索ゴールまで進んでいない場合は、さらに盤面に石を置いて探索を進めます。
このとき、調べているターンが自分のターンなら数字を足します。nLeastはn回先の相手のターンでとれる手数を考えますから、次が相手のターンとなるタイミングでカウントアップします。

Osero.java
    public static double exploreNLeast(long board[], boolean nowTurn, boolean turn, int num){
        int row = -1, col = 0;
        int placeNum = 0;
        double score = 0.;
        long[] boardLeaf = new long[2];

        if (num >= Osero.readGoal[(nowTurn ? 1:0)]){
            for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
                row++;
                if (row >= OseroBase.SIZE){
                    row = 0;
                    col++;
                }
                if (Osero.check(row, col, board, turn)) placeNum++;
            }

            return -(double)placeNum;
        }

        int nextNum;

        if (nowTurn == turn){
            nextNum = num + 1;
        }else{
            nextNum = num;
        }
        for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
            row++;
            if (row >= OseroBase.SIZE){
                row = 0;
                col++;
            }
            if (!Osero.check(row, col, board, turn)) continue;
            boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
            placeNum++;
            score += Osero.exploreNLeast(
                boardLeaf,
                nowTurn,
                !turn,
                nextNum
            );
        }

        return placeNum > 0 ? -score / placeNum : 0;
    }

次に、exploreNMostメソッド。
nMostはn回先の自分のターンで取れる手数が最大になることを目指す思考方法です。
もし、調べているのが自分のターンで、なおかつ探索ゴールまで探索できていた場合、取れる手数を数えて返します。
もしそうでなかった場合、さらに盤面に石を置いて探索を進めます。
このとき、調べているターンが自分のターンでなければ数字を足します。nMostはn回先の自分のターンでとれる手数を考えますから、次が自分のターンとなるタイミングでカウントアップします。

Osero.java
    public static double exploreNMost(long board[], boolean nowTurn, boolean turn, int num){
        int row = -1, col = 0;
        int placeNum = 0;
        double score = 0.;
        long[] boardLeaf = new long[2];

        if (nowTurn == turn && num >= Osero.readGoal[(nowTurn ? 1:0)]){
            for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
                row++;
                if (row >= OseroBase.SIZE){
                    row = 0;
                    col++;
                }
                if (Osero.check(row, col, board, turn)) placeNum++;
            }

            return (double)placeNum;
        }

        int nextNum;
        if (nowTurn != turn){
            nextNum = num + 1;
        }else{
            nextNum = num;
        }
        for (int place = 0; place < OseroBase.SIZE << OseroBase.SHIFTNUM; place++){
            row++;
            if (row >= OseroBase.SIZE){
                row = 0;
                col++;
            }
            if (!Osero.check(row, col, board, turn)) continue;
            boardLeaf[0] = board[0]; boardLeaf[1] = board[1];
            placeNum++;
            score += Osero.exploreNLeast(
                boardLeaf,
                nowTurn,
                !turn,
                nextNum
            );
        }

        return placeNum > 0 ? score / placeNum : 0;
    }

play

ゲームを行うメソッド。

Osero.java
    public void play(boolean printMode){
        boolean can = true, oldCan = true;

        if (printMode) this.printBoard();

        while ((can = this.checkAll()) || oldCan){
            if (can){
                this.playMethod.get(this.turn ? 1:0).accept(this.bw, this.turn);
                if (printMode) this.printBoard();
            }

            this.turn = !this.turn;
            oldCan = can;
        }

        this.countLast();
    }

    public void play(){
        boolean can = true, oldCan = true;

        while ((can = this.checkAll()) || oldCan){
            if (can){
                this.playMethod.get(this.turn ? 1:0).accept(this.bw, this.turn);
            }

            this.turn = !this.turn;
            oldCan = can;
        }
    }

勝敗予測

ここからやっと本題です。
試合を行うことでデータを集め、それをもとに学習を行います。

OseroData.java

データ集め用のクラス。
Osero.javaを継承し、データ集めを行います。

フィールドとコンストラクタ

OseroData.java
import java.util.function.BiConsumer;
import java.util.ArrayList;
import java.util.Random;

public class OseroData extends Osero {
    public ArrayList<ArrayList<Double>> history = new ArrayList<ArrayList<Double>>();  // 盤面履歴
    public ArrayList<ArrayList<Double>> result = new ArrayList<ArrayList<Double>>();   // 勝敗記録

    public OseroData(){
        this.setup();
    }

    public OseroData(ArrayList<BiConsumer<long[], Boolean>> playMethod){
        this.setup();
        this.playMethod = playMethod;
    }

    public OseroData(BiConsumer<long[], Boolean> black, BiConsumer<long[], Boolean> white){
        this.setup();
        this.playMethod.add(black);
        this.playMethod.add(white);
    }

dataClear

データ初期化メソッド。

OseroData.java
    public void dataClear(){
        this.history.clear();
        this.result.clear();
    }

writeHistory

与えられたeleに、現在の盤面情報を書き込むメソッド。

OseroData.java
    private void writeHistory(boolean turn, ArrayList<Double> ele){
        long place = 1;
        int my, opp;
        boolean myStone, oppStone;

        if (turn){
            my = 1; opp = 0;
        }else{
            my = 0; opp = 1;
        }

        while (place != 0){
            ele.add((myStone = ((this.bw[my] & place) != 0)) ? 1.0 : 0.0);    // 自分の石
            ele.add((oppStone = ((this.bw[opp] & place) != 0)) ? 1.0 : 0.0);  // 相手の石
            ele.add(myStone || oppStone ? 0.0 : 1.0);  // どちらの石も置かれていない時のみ1.0
            place = place << 1;
        }
    }

addData

指定した回数の試合を行い、その分のデータを蓄積するメソッド。

OseroData.java
    public void addData(int num){
        boolean can, oldCan;
        double blackScore, whiteScore;
        ArrayList<Boolean> turnHistory = new ArrayList<Boolean>();  // ターン履歴

        // 1ループで1試合
        for (int i = 0; i < num; i++){
            this.turn = false;
            can = true; oldCan = true;
            this.setup();
            turnHistory.clear();
            this.rand = new Random(i);

            while ((can = this.checkAll()) || oldCan){
                if (can){
                    this.playMethod.get(this.turn ? 1:0).accept(this.bw, this.turn);
                    turnHistory.add(this.turn);

                    // 盤面履歴追加
                    ArrayList<Double> element = new ArrayList<Double>();
                    this.writeHistory(this.turn, element);
                    this.history.add(element);
                }

                this.turn = !this.turn;
                oldCan = can;
            }

            // 最終結果
            blackScore = (double)this.popCount(this.bw[0]);
            whiteScore = (double)this.popCount(this.bw[1]);

            // 黒のターンなら黒の最終石数、逆なら逆の結果を格納していく。
            for (Boolean turn: turnHistory){
                ArrayList<Double> element = new ArrayList<Double>();
                element.add(turn ? whiteScore : blackScore);
                this.result.add(element);
            }
        }
    }

Run.java

データ集めと学習を行います。mainメソッドを含むクラスです。

現在、人がかかわる以外の思考方法が5種類あるので、5x5の総当たり戦、25試合行ってデータ集めしようと思います。とりあえずreadGoalは2にしておきました。

45 ~最適化関数~で、層やノード数をやたら増やしても逆効果であることが分かったので、今回は2層で、ノード数は192、1とします。活性化関数はそれぞれReLu、Linearとします。また最適化関数はAdamを使います。

Run.java
import java.util.ArrayList;
import java.util.function.BiConsumer;

import org.MyNet.matrix.*;
import org.MyNet.network.*;
import org.MyNet.layer.*;
import org.MyNet.nodes.activationFunction.*;
import org.MyNet.optimizer.*;
import org.MyNet.costFunction.*;

public class Run {
    public static void main(String str[]){
        // 思考方法
        ArrayList<BiConsumer<long[], Boolean>> playMethod = new ArrayList<BiConsumer<long[], Boolean>>();
        playMethod.add(Osero::random);
        playMethod.add(Osero::nHand);
        playMethod.add(Osero::nHandCustom);
        playMethod.add(Osero::nLeast);
        playMethod.add(Osero::nMost);

        OseroData run = new OseroData();
        run.setReadGoal(2, 2);
        run.dataClear();

        // データ集め
        for (BiConsumer<long[], Boolean> black: playMethod){
            for (BiConsumer<long[], Boolean> white: playMethod){
                run.setPlayMethod(black, white);
                run.addData(1);
            }
        }

        // Deep Learning
        // MyNetで扱えるデータ型に変換
        Matrix X = new Matrix(run.history);
        Matrix T = new Matrix(run.result);
        T.div(64.0);   // 一応割る

        Network net = new Network(
            192,
            new Input(192, AF.RELU),
            new Output(1, AF.LINER)
        );
        Adam opt = new Adam(net, new MeanSquaredError());
        opt.fit(X, T, 10, X.row / 20);  // エポック数10、バッチサイズは全データサイズの5%
        net.save("osero.net");
    }
}

実行結果はこちら。

Epoch 1/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 2/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 3/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 4/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 5/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 6/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 7/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 8/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 9/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000
Epoch 10/10
loss: 2660162840360359000000000000000000000000000000000000000000000000000.0000

64で割っているのにこのエラー。
ネットワークの設定があまり良くなかったのか、いろいろ試したところ以下のネットワークで次の結果が得られました。

Run.java
        Network net = new Network(
            192,
            new Input(100, AF.TANH),
            new Dense(50, AF.TANH),
            new Output(1, AF.SIGMOID)
        );
Epoch 1/10
loss: 0.2345
Epoch 2/10
loss: 0.2329
Epoch 3/10
loss: 0.2315
Epoch 4/10
loss: 0.2303
Epoch 5/10
loss: 0.2291
Epoch 6/10
loss: 0.2280
Epoch 7/10
loss: 0.2270
Epoch 8/10
loss: 0.2260
Epoch 9/10
loss: 0.2251
Epoch 10/10
loss: 0.2242

正解ラベルが0~1ですので、出力層の活性化関数をSigmoidに変更したところ精度が高くなりました。また、層の数を増やしなおかつ1層に含まれるノード数を減らしたところ精度が上がりました。
しかしほとんど学習しておりません。

フルバージョン

次回は

今回はデータの前処理などを全く行っていないので、学習が進むよう工夫したいと思います。
学習が進みやすいノード数や層の数も検討したいと思います。

次回

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0