0
0

More than 1 year has passed since last update.

コンピュータとオセロ対戦39 ~ライブラリ、ノード作成~

Last updated at Posted at 2022-01-28

前回

今回の目標

Javaで深層学習ライブラリを作ります。

ここから本編

前回は、「相手の手数を少なくする」考え方と「予測される最終結果がより良い位置に置く」考え方を総合的に用いるAIを考えました。
そして、それらにつける重みについて局所探索法を用いて探しましたがうまくいきませんでした。

ここで、少し話は変わりますが今まで深層学習や学習済みモデルの使用は全てPythonを使って行っていました。理由は単純にライブラリや情報が豊富だからです。これまでの深層学習ではデータ量はせいぜい20試合ぶん程度でしたので実行速度の遅いPythonでも問題なく行えていました。
しかし局所探索法となると話は別で、膨大な試合数が必要となります。前回は最大で8000試合行いましたが、Pythonで実行したため非常に時間がかかりました。

そこで、JavaのライブラリTribuoを使うことで実行速度の改善を図りましたがMavenの環境構築が全く進んでいません。二か月くらい苦戦しています。調べたところc#及びc++にも深層学習ライブラリはありましたが(それぞれKelpNetとtiny-dnn)、こちらの環境構築も全くできておりません。

なので苦肉の策として、試合の実行をc++を呼び出して行うことで時短ができないか考えました。
しかし、試合中の探索をc++に任せるにしても、ニューラルネットワークの使用についてはやはりPythonが必要になります。つまり、一つの試合の中でc++とPythonを何度も往復することになり、手間もかかるし期待するほどの時短が得られるかも怪しいです。

じゃあもう高速な言語で深層学習ライブラリを自作すればいいのでは? という発想になりました。
Pythonのライブラリであっても中身はc++で動いていたりするそうなので、自作ライブラリではたとえJavaやc++を使っても既製品ほどの高速度は得られないと思います。しかし、オセロの試合を実行しながらニューラルネットワークも動かすとなると、「Pythonで低速に探索し、高速に予測する」よりも「Javaなどで高速に探索し、低速に予測する」ほうが総合的な時間では短くなるかもしれません。

こういった理由で、Javaで深層学習ライブラリを自作することにしました。
Javaを選んだ理由は、package機能でソース管理が楽そうであることと、単に使ってみたかったからです。

今回はノードの作成を行います。
なお、今度のことを考えての作成は現状難しそうですので後から作り直す可能性はあります。

フォルダ構造

作成するライブラリ名を仮に「MyNet」とし、そのディレクトリ構造を考えました。
ぱっと思いついたものを書いただけですし、変更する可能性はあります。

MyNet
├── layer     // 層に関するファイルを入れる
├── nodes     // ノードに関するファイルを入れる
│   └── out_function  // 活性化関数に関するファイルを入れる
└── optimzer  // 最適化関数に関するファイルを入れる

nodes

ノードに関するクラスなどを入れます。

out_function

活性化関数を入れます。
もっとうまい方法もあるかもしれませんが、できるだけ高速化することを考え、最初に設定したクラスのインスタンスを作成する方法を採用しました。

AF.java

使用する活性化関数を列挙します。

package nodes.out_function;

/**
 * enum class for designating activation function.
 * AF is a word omitted "Activation Function".
 */
public enum AF {
    RELU,
    SIGMOID,
    TANH
};

ActivatonFunction.java

全ての活性化関数の親クラスです。

package nodes.out_function;

/**
 * Activation function's base class.
 * All activation functions must extend this class.
 */
public class ActivationFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public ActivationFunction(){
        ;
    }

    /**
     * Actiation function execution.
     * @param sum sum of after linear transformation.
     * @return output array's element.
     */
    public double calcurate(double sum){
        return sum;
    }
}

Sigmoid.java

シグモイド関数。

package nodes.out_function;

import java.lang.Math;

/**
 * Sigmoid function.
 */
public class Sigmoid extends ActivationFunction {
    /**
     * Constructor for this class.
     * Nothing to do.
     */
    public Sigmoid(){
        ;
    }

    /**
     * Sigmoid function execution.
     * @param sum sum of after linear transformation.
     * @return output array's element.
     */
    public double calcurate(double sum){
        return 1 / (1 + Math.exp(sum));
    }
}

ReLu.java

ReLu関数。プログラムはSigmoid関数とほぼ同じですので省略します。

Tanh.java

tanh関数。ReLu同様、省略します。

Node.java

一つのノードを表すクラスです。
ここでは順方向伝播のみ実装しています。
層の定義もまだ考えていないので変更する可能性大です。

package nodes;

import java.util.LinkedList;
import nodes.out_function.*;

/**
 * One node class.
 */
public class Node{
    /** The number of inputs for this node. */
    int in = 1;
    /** The number of outputs for this node. */
    int out = 0;
    /** The list of weight for this node. */
    LinkedList<Double> w = new LinkedList<Double>();
    /** Activation function of this node. */
    ActivationFunction function;
    /** Array of output. */
    LinkedList<Double> output = new LinkedList<Double>();

    /**
     * Constructor for this class.
     * Number of inputs includes bias.
     * @param input_num Number of inputs
     * @param output_num Number of outputs
     * @param type Type of activation function.
     * @exception System.exit The specified activation function
     *                        does not exist or misspecified.
     */
    public Node(int input_num, int output_num, AF type){
        in += input_num;
        out = output_num;
        for (int i = 0; i < in; i++){
            w.add(1.);
        }

        switch (type) {
            case RELU:
                function = new ReLu();
                break;
            case SIGMOID:
                function = new Sigmoid();
                break;
            case TANH:
                function = new Tanh();
                break;
            default:
                System.out.println("ERROR: The specified activation function is wrong");
                System.exit(0);
        }
    }

    /**
     * Doing forward propagation.
     * @param input Array of inputs.
     * @return output array's element.
     */
    public double forward(LinkedList<Double> input){
        double sum = 0.;

        for (int i = 0; i < input.size(); i++){
            sum += input.get(i) * w.get(i);
        }

        return function.calcurate(sum);
    }
}

次回は

深層学習の知識もJavaの知識もまだまだ浅いので、とんでもないことを始めてしまったと若干後悔しています。
次回は層を作りたいです。

次回

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0