LoginSignup
1
0

More than 1 year has passed since last update.

コンピュータとオセロ対戦46 ~モデルの保存とロード~

Last updated at Posted at 2022-03-11

前回

今回の目標

ネットワークモデルの保存・ロードができるようにする。

ここから本編

現在のディレクトリ構成

バージョン情報追加。

MyNet
├── META-INF
└── org
    └── MyNet
        ├── costFunction
        ├── layer
        ├── matrix
        ├── network
        ├── nodes
        │   └── activationFunction
        ├── optimizer
        ├── version
        └── module-info.java

バージョン情報

バージョン情報のみを保持するクラスを作成。
MANIFEST.MFでバージョン情報を保存するやり方もありますが、ネットワーククラス内でもバージョン情報を利用したいため、独自にVersionクラスを作成しStringで保持します。

Version.java
package org.MyNet.version;

/** Class for information of version. */
public class Version {
    /** Information of version. */
    public static final String version = "1.0.0";
}

module-info.java

module-info.java
module MyNet {
    exports MyNet.costFunction;
    exports MyNet.layer;
    exports MyNet.matrix;
    exports MyNet.network;
    exports MyNet.nodes;
    exports MyNet.nodes.activationFunction;
    exports MyNet.optimizer;
    exports MyNet.versrion;
}

修正点

GD以外のfitメソッドは全く同じなので、Optimizerクラスにfitメソッドを移行しました。
GDクラスのみ個別にfitメソッドを定義しました。
またその都合上、GDクラスのbackメソッドをOptimizerクラスに移動させました。
中身は変更していないので省略します。

また、ネットワーククラスにバージョン情報を追加しました。
ネットワークのロードの際、バージョンが違うことによる不具合を防ぐためです。

Network.java
    /** Information of version. */
    public final String version = Version.version;

また、バージョン情報クラスでもちらっと出ましたが、ディレクトリ構成を変えたのでパッケージ名も各自変更してあります。

保存・ロード実装

これらのメソッドはネットワーククラスにつけることにしました。
もちろん、ノードクラスなどにSerializableをimplementsしています。
loadおよびversionHasProblemメソッドについて、少しわかりづらいですが、versionHasProblemメソッドはあくまでロードしたネットワークのバージョンと現在使用しているクラスのバージョンが同じかどうかを確認し、もし問題があれば続行するか尋ねるだけです(セーブ当時のバージョンと、ロード時のバージョンは異なる恐れがあります)。
versionHasProblemメソッドは、バージョンが異なることが確認され、さらにユーザがエラーを恐れてプログラムの実行中止を判断したときのみにtrueを返します。もしバージョンが同じ、またはバージョンが違っても、(互換性のあるバージョン違いなどの理由で)ユーザが問題ないと判断した場合はfalseを返す、つまり問題なしと判断します。
loadParameterメソッドはその名の通り、ロードしたネットワークのパラメータを使用しているネットワークに入れます。

Network.java
    /**
     * Save this network to one file.
     * @param name Name of save file.
     */
    public void save(String name){
        try (
            FileOutputStream fos = new FileOutputStream(name);
            ObjectOutputStream oos = new ObjectOutputStream(fos);
        ){
            oos.writeObject(this);
            oos.flush();
        }catch (IOException e){
            System.out.println("IOException");
            System.exit(-1);
        }
    }

    /**
     * Load a network from the file.
     * @param name Name of load file.
     */
    public void load(String name){
        try (
            FileInputStream fis = new FileInputStream(name);
            ObjectInputStream ois = new ObjectInputStream(fis);
        ){
            Network loadNet = (Network)ois.readObject();
            if (this.versionHasProblem(loadNet)){
                System.exit(-1);
            }else{
                ;
            }
            this.loadParameters(loadNet);
        }catch (IOException e){
            System.out.println("IOException");
            System.exit(-1);
        }catch (ClassNotFoundException e){
            System.out.println("ClassNotFoundException");
            System.exit(-1);
        }
    }

    /**
     * Check that the information of version has problem.
     * If the two networks have different versions, ask do you want to continue.
     * @param loaded Loaded network.
     * @return Do loaded network has problem?
     */
    protected boolean versionHasProblem(Network loaded){
        if (!this.version.equals(loaded.version)){
            System.out.println("The versions of the two Network classes do not match.");
            System.out.printf("The version of loaded is %s\n", loaded.version);
            System.out.printf("The version of this is %s\n", this.version);
            System.out.println("There is a risk of serious error.");
            System.out.print("Do you want to continue? [y/n] ");

            Scanner sc = new Scanner(System.in);
            String ans = this.nextLine(sc);

            while (!(ans.equals("y") || ans.equals("Y")
                     || ans.equals("n") || ans.equals("N"))){
                System.out.println("Choose 'y' or 'n'.");
                System.out.print("Do you want to continue? [y/n] ");
                ans = this.nextLine(sc);
            }

            if (ans.equals("y") || ans.equals("Y")){
                return false;
            }else{
                System.out.println("End.");
                return true;
            }
        }else{
            return false;
        }
    }

    /**
     * Load Network class's parameters.
     * @param loaded Loaded Network instance.
     */
    private void loadParameters(Network loaded){
        if (loaded.layers_num != this.layers.length
        || this.layers_num != loaded.layers.length){
            System.out.println("Layer's number error.");
            System.exit(-1);
        }else if (loaded.input_num != this.input_num){
            System.out.println("Input number error.");
            System.exit(-1);
        }
        this.input_num = loaded.input_num;
        this.layers_num = loaded.layers_num;
        this.layers = loaded.layers;
    }

    /**
     * Read next line.
     * If read String is Null, return "".
     * @param sc Scanner instance.
     * @return Readed String instance.
     */
    private String nextLine(Scanner sc) throws NoSuchElementException {
        String ans;
        try {
            ans = sc.nextLine();
        }catch (NoSuchElementException e){
            ans = "";
        }

        return ans;
    }

一度ネットワーククラスのversionフィールドを変数にし、わざとバージョンを変えてテストしたところうまく動きました。

Test.java

前回作成したものに似たモデルを作成し、保存、ロードしてみます。

Test.java
import org.MyNet.optimizer.*;
import org.MyNet.network.*;
import org.MyNet.layer.*;
import org.MyNet.nodes.activationFunction.*;
import org.MyNet.costFunction.*;
import org.MyNet.matrix.*;

public class Test {
    public static void main(String str[]){
        // 学習用データ
        Matrix X = new Matrix(new double[10][2]);
        Matrix T = new Matrix(new double[10][1]);
        for (int i = 0; i < X.row; i++){
            X.matrix[i][0] = i / 3;
            X.matrix[i][1] = i / 2;
            T.matrix[i][0] = X.matrix[i][0] + X.matrix[i][1];
        }

        // ネットワークと最適化関数
        Network net = new Network(
            2,
            new Input(4, AF.RELU),
            new Output(1, AF.LINER)
        );
        SGD sgd = new SGD(net, new MeanSquaredError());

        // 学習と、学習済みモデルの保存
        sgd.fit(X, T, 5, 2);
        net.save("sgd.net");
        System.out.println();

        // 学習済みモデルをロードし、再び学習させる
        net.version = "1.0.1";  // わざとバージョンを変える
        net.load("sgd.net");
        Adam adam = new Adam(net, new MeanSquaredError());
        adam.fit(X, T, 5, 2);
        net.save("adam.net");
    }
}

実行結果はこちら。

Epoch 1/5
loss: 0.0598
Epoch 2/5
loss: 0.0384
Epoch 3/5
loss: 0.0340
Epoch 4/5
loss: 0.0307
Epoch 5/5
loss: 0.0279

The versions of the two Network classes do not match.
The version of loaded is 1.0.0
The version of this is 1.0.1
There is a risk of serious error.
Do you want to continue? [y/n] y
Epoch 1/5
loss: 0.2645
Epoch 2/5
loss: 0.2053
Epoch 3/5
loss: 0.1542
Epoch 4/5
loss: 0.1144
Epoch 5/5
loss: 0.0846

フルバージョン

次回は

これで、一通り深層学習ライブラリとしての機能は実装し終えました。
このライブラリの良い点としては、

  • 最低限の記述で深層学習ができる
  • ネットワークの設計変更が容易
  • Matrixという、共通するデータ型を全体で使っている
  • 製作記が読めるので、内容の理解が容易
  • 関数などの拡張が容易

などが挙げられると思います。
逆に、反省点として、

  • ノードクラスは必要なかったのではないか
  • 倍精度実数を使う必要はなかったのではないか
  • 二次元行列を一次元配列で表現すれば探索や計算がもっと早くできたのではないか
  • 行列同士の計算はstrassenアルゴリズムを使った方が速かったのではないか
  • 全体的に効率が悪そう
  • Matrixクラスのメソッドが、返り値があったりなかったりで分かりづらい

などが挙げられますが、あまり凝り始めると本来の目標を見失いそうなので、次回からはまたオセロに戻ろうと思います。
データの前処理や分割などの機能は、必要になった時にまた作ろうと思います。

次回

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0