1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ゼロから作るDeep Learning Java編 6.1 パラメータの更新

Last updated at Posted at 2018-03-19

目次

6.1.2 SGD

まずOptimizerインタフェースを定義します。

/**
 * Parameter-update strategy used during training.
 *
 * <p>Implementations may keep per-parameter state (for example a velocity or a squared-gradient
 * accumulator) that is allocated lazily from the first {@code params} they see, so one instance
 * should be used with one parameter set.
 */
public interface Optimizer {

    /**
     * Applies one optimization step.
     *
     * @param params the parameters to update
     * @param grads  the gradients of the loss with respect to {@code params}, stored under the
     *               same keys as {@code params}
     */
    void update(Params params, Params grads);

}

SGDの実装は以下のようになります。

/**
 * Plain stochastic gradient descent: {@code p <- p - lr * g} for every parameter.
 */
public class SGD implements Optimizer {

    /** Learning rate. */
    final double lr;

    /** Creates an SGD optimizer with learning rate 0.01. */
    public SGD() {
        this(0.01);
    }

    /**
     * Creates an SGD optimizer.
     *
     * @param lr learning rate
     */
    public SGD(double lr) {
        this.lr = lr;
    }

    @Override
    public void update(Params params, Params grads) {
        // In-place subtraction of the scaled gradient from each parameter.
        params.update((param, grad) -> param.subi(grad.mul(lr)), grads);
    }
}

6.1.4 Momentum

/**
 * Momentum SGD.
 *
 * <p>Each step computes {@code v <- momentum * v - lr * g} and then {@code p <- p + v}.
 * The velocity {@code v} is created on the first call with the same layout as {@code params}.
 */
public class Momentum implements Optimizer {

    /** Learning rate. */
    final double lr;
    /** Factor applied to the previous velocity on every step. */
    final double momentum;
    /** Per-parameter velocity; {@code null} until the first update. */
    Params v;

    /** Creates a Momentum optimizer with lr = 0.01 and momentum = 0.9. */
    public Momentum() {
        this(0.01);
    }

    /**
     * Creates a Momentum optimizer with momentum = 0.9.
     *
     * @param lr learning rate
     */
    public Momentum(double lr) {
        this(lr, 0.9);
    }

    /**
     * Creates a Momentum optimizer.
     *
     * @param lr       learning rate
     * @param momentum factor applied to the previous velocity
     */
    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
        this.v = null;
    }

    @Override
    public void update(Params params, Params grads) {
        if (v == null) {
            v = Params.zerosLike(params);
        }
        // Both operations are in place on the same velocity array:
        // first decay it, then subtract the scaled gradient.
        v.update((vel, grad) -> vel.muli(momentum).subi(grad.mul(lr)), grads);
        params.update((param, vel) -> param.addi(vel), v);
    }
}

6.1.5 AdaGrad

/**
 * AdaGrad optimizer.
 *
 * <p>Accumulates squared gradients per parameter ({@code h <- h + g * g}) and scales each step
 * by their inverse square root: {@code p <- p - lr * g / (sqrt(h) + 1e-7)}.
 */
public class AdaGrad implements Optimizer {

    /** Learning rate. */
    final double lr;
    /** Per-parameter sum of squared gradients; {@code null} until the first update. */
    Params h;

    /** Creates an AdaGrad optimizer with learning rate 0.01. */
    public AdaGrad() {
        this(0.01);
    }

    /**
     * Creates an AdaGrad optimizer.
     *
     * @param lr learning rate
     */
    public AdaGrad(double lr) {
        this.lr = lr;
    }

    @Override
    public void update(Params params, Params grads) {
        if (h == null) {
            h = Params.zerosLike(params);
        }
        // h <- h + g * g (in place)
        h.update((acc, grad) -> acc.addi(grad.mul(grad)), grads);
        // 1e-7 keeps the denominator away from zero while h is still 0.
        params.update((param, grad, acc) -> param.subi(grad.mul(lr).div(Transforms.sqrt(acc).add(1e-7))), grads, h);
    }
}

6.1.6 Adam

/**
 * Adam optimizer.
 *
 * <p>Keeps a first-moment estimate {@code m} and a second-moment estimate {@code v} per
 * parameter and updates {@code p <- p - lr_t * m / (sqrt(v) + 1e-7)}. Here
 * {@code lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)}, so the bias corrections of both
 * estimates are folded into the step size.
 */
public class Adam implements Optimizer {

    /** Base learning rate. */
    final double lr;
    /** Decay rate of the first-moment estimate. */
    final double beta1;
    /** Decay rate of the second-moment estimate. */
    final double beta2;
    /** Number of update calls made so far (t in the bias-correction terms). */
    int iter;
    /** Moment estimates; both {@code null} until the first update. */
    Params m, v;

    /** Creates an Adam optimizer with lr = 0.001, beta1 = 0.9, beta2 = 0.999. */
    public Adam() {
        this(0.001);
    }

    /**
     * Creates an Adam optimizer with beta1 = 0.9 and beta2 = 0.999.
     *
     * @param lr base learning rate
     */
    public Adam(double lr) {
        this(lr, 0.9, 0.999);
    }

    /**
     * Creates an Adam optimizer.
     *
     * @param lr    base learning rate
     * @param beta1 decay rate of the first-moment estimate
     * @param beta2 decay rate of the second-moment estimate
     */
    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.iter = 0;
    }

    @Override
    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        }
        iter += 1;
        double sqrtOneMinusBeta2T = Math.sqrt(1.0 - Math.pow(beta2, iter));
        double oneMinusBeta1T = 1.0 - Math.pow(beta1, iter);
        final double lr_t = lr * sqrtOneMinusBeta2T / oneMinusBeta1T;
        // m <- m + (1 - beta1) * (g - m)
        m.update((mi, grad) -> mi.addi(grad.sub(mi).mul(1 - beta1)), grads);
        // v <- v + (1 - beta2) * (g * g - v)
        v.update((vi, grad) -> vi.addi(grad.mul(grad).sub(vi).mul(1 - beta2)), grads);
        params.update((param, mi, vi) -> param.subi(mi.mul(lr_t).div(Transforms.sqrt(vi).add(1e-7))), m, v);
    }

}

6.1.7 どの更新手法を用いるか?

グラフを作成するために簡単なクラスGraphImageを作成しました。

// Java port of ch06/optimizer_compare_naive.py.
// Draws one optimization trajectory per optimizer with GraphImage.
File outdir = Constants.OptimizerImages;
if (!outdir.exists() && !outdir.mkdirs())
    throw new IllegalStateException("cannot create " + outdir);
// Objective (only its gradient is needed below):
//   f(x, y) = x^2 / 20 + y^2
// Gradient (df/dx, df/dy) = (x / 10, 2 * y), concatenated along dimension 1 so that
// column 0 holds df/dx and column 1 holds df/dy.
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
// Distance of the starting point from the minimum at (0, 0).
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
int steps = 30;
// The parameters are reset to init_pos for each optimizer inside the loop below.
Params params = new Params();
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));

// LinkedHashMap keeps the optimizers in insertion order.
Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (Map.Entry<String, Optimizer> entry : optimizers.entrySet()) {
    String key = entry.getKey();
    Optimizer optimizer = entry.getValue();
    params.put("x", Nd4j.create(new double[] {init_pos[0]}))
        .put("y", Nd4j.create(new double[] {init_pos[1]}));
    double min_distance = Double.MAX_VALUE;
    double last_distance = 0.0;
    double prevX = init_pos[0];
    double prevY = init_pos[1];
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        // Draw the graph title.
        image.text(key, -2, 7);
        // Plot the starting point.
        image.plot(prevX, prevY);
        for (int i = 0; i < steps; ++i) {
            INDArray temp = df.apply(params.get("x"), params.get("y"));
            grads.put("x", temp.getColumn(0));
            grads.put("y", temp.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            last_distance = Math.hypot(x, y);
            if (last_distance < min_distance)
                min_distance = last_distance;
            // Connect the previous point to the new one, then mark the new point.
            image.line(prevX, prevY, x, y);
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        // The optimizer must have moved closer to the minimum than the starting point.
        assertTrue(last_distance < init_distance);
        assertTrue(min_distance < init_distance);
        // Write the graph to <outdir>/<optimizer name>.png.
        image.writeTo(new File(outdir, key + ".png"));
    }
}

結果のグラフは以下のようになりました。

| SGD | Momentum | AdaGrad | Adam |
|:---:|:---:|:---:|:---:|
| SGD.png | Momentum.png | AdaGrad.png | Adam.png |

6.1.8 MNISTデータセットによる更新手法の比較

// Java port of ch06/optimizer_compare_mnist.py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int batch_size = 128;
int max_iterations = 2000;

// 1. Experiment setup.
// LinkedHashMap keeps insertion order, so every loop below (and the printed log) visits the
// optimizers as SGD, Momentum, AdaGrad, Adam, like the naive comparison.
Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

// One separate network and one loss history per optimizer.
Map<String, MultiLayerNet> networks = new LinkedHashMap<>();
Map<String, List<Double>> train_loss = new LinkedHashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>(max_iterations));
}
DataSet dataset = new DataSet(x_train, t_train);

// 2. Training.
for (int i = 0; i < max_iterations; ++i) {
    // Sample one mini-batch; all optimizers train on the same batch in this iteration.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (Map.Entry<String, Optimizer> entry : optimizers.entrySet()) {
        MultiLayerNet network = networks.get(entry.getKey());
        Params grads = network.gradicent(x_batch, t_batch);
        entry.getValue().update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        train_loss.get(entry.getKey()).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            // Reuse the loss just recorded for this batch after the update,
            // instead of calling network.loss a second time.
            List<Double> history = train_loss.get(key);
            double loss = history.get(history.size() - 1);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3. Plot the loss curves.
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    // Legend position in graph coordinates; each legend line advances h by 0.05.
    double w = 1300;
    double h = 0.7;
    // Only every step-th recorded loss is drawn.
    int step = 10;
    for (Map.Entry<String, List<Double>> entry : train_loss.entrySet()) {
        String key = entry.getKey();
        List<Double> loss = entry.getValue();
        // Optimizers without an assigned color (e.g. RMSprop if enabled) are drawn in black
        // instead of passing null to graph.color.
        graph.color(colors.getOrDefault(key, Color.BLACK));
        graph.text(key, w, h);
        h += 0.05;
        if (loss.isEmpty())
            continue; // nothing recorded (e.g. zero iterations); loss.get(0) would throw
        graph.plot(0, loss.get(0));
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    graph.color(Color.BLACK);
    graph.text("横=繰り返し回数(0,2000) 縦=損失関数の値(0,1)", w, h);
    h += 0.05;
    graph.text("MNISTデータセットに対する4つの更新手法の比較", w, h);
    if (!Constants.OptimizerImages.exists() && !Constants.OptimizerImages.mkdirs())
        throw new IllegalStateException("cannot create " + Constants.OptimizerImages);
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png"));
}

結果のグラフは以下のようになりました。
compare_mnist.png

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?