Help us understand the problem. What is going on with this article?

ゼロから作るDeep Learning Java編 6.1 パラメータの更新

More than 1 year has passed since last update.

目次

6.1.2 SGD

まずOptimizerインタフェースを定義します。

/**
 * Strategy for turning gradients into a parameter update.
 *
 * <p>The implementations shown in this article (SGD, Momentum, AdaGrad, Adam) modify
 * {@code params} in place via ND4J in-place operations, and the stateful ones keep
 * per-parameter buffers between calls.
 */
public interface Optimizer {

    /**
     * Applies one update step to {@code params}.
     *
     * @param params the parameters to update in place
     * @param grads  the gradients of the loss with respect to {@code params};
     *               presumably keyed the same way as {@code params}
     */
    void update(Params params, Params grads);

}

SGDの実装は以下のようになります。

/**
 * Plain stochastic gradient descent: each parameter is moved by {@code -lr * grad}.
 */
public class SGD implements Optimizer {

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LR = 0.01;

    /** Learning rate (step size). */
    final double lr;

    /** Creates an SGD optimizer with a learning rate of 0.01. */
    public SGD() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an SGD optimizer.
     *
     * @param lr the learning rate
     */
    public SGD(double lr) {
        this.lr = lr;
    }

    /** Subtracts {@code lr * grad} from every parameter array in place. */
    @Override
    public void update(Params params, Params grads) {
        params.update((param, grad) -> param.subi(grad.mul(lr)), grads);
    }
}

6.1.4 Momentum

/**
 * SGD with momentum: {@code v = momentum * v - lr * grad}, then {@code param += v}.
 */
public class Momentum implements Optimizer {

    /** Learning rate used when none is given. */
    private static final double DEFAULT_LR = 0.01;
    /** Momentum coefficient used when none is given. */
    private static final double DEFAULT_MOMENTUM = 0.9;

    /** Learning rate (step size). */
    final double lr;
    /** Factor applied to the previous velocity on each step. */
    final double momentum;

    /** Velocity per parameter; allocated on the first call to {@link #update}. */
    Params v;

    /** Creates a Momentum optimizer with lr = 0.01 and momentum = 0.9. */
    public Momentum() {
        this(DEFAULT_LR);
    }

    /**
     * Creates a Momentum optimizer with momentum = 0.9.
     *
     * @param lr the learning rate
     */
    public Momentum(double lr) {
        this(lr, DEFAULT_MOMENTUM);
    }

    /**
     * Creates a Momentum optimizer.
     *
     * @param lr       the learning rate
     * @param momentum the momentum coefficient
     */
    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
    }

    @Override
    public void update(Params params, Params grads) {
        if (v == null) {
            v = Params.zerosLike(params);
        }
        // Decay the old velocity, then subtract the scaled gradient (both in place).
        v.update((vel, grad) -> vel.muli(momentum), grads);
        v.update((vel, grad) -> vel.subi(grad.mul(lr)), grads);
        // Move each parameter by its velocity.
        params.update((param, vel) -> param.addi(vel), v);
    }
}

6.1.5 AdaGrad

/**
 * AdaGrad: accumulates the squared gradients in {@code h} and divides each step by
 * {@code sqrt(h) + 1e-7}, so frequently-updated elements take smaller steps.
 */
public class AdaGrad implements Optimizer {

    /** Learning rate used when none is given. */
    private static final double DEFAULT_LR = 0.01;
    /** Added to the denominator to avoid division by zero. */
    private static final double EPSILON = 1e-7;

    /** Learning rate (step size). */
    final double lr;

    /** Running sum of squared gradients per parameter; allocated on the first update. */
    Params h;

    /** Creates an AdaGrad optimizer with a learning rate of 0.01. */
    public AdaGrad() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an AdaGrad optimizer.
     *
     * @param lr the learning rate
     */
    public AdaGrad(double lr) {
        this.lr = lr;
    }

    @Override
    public void update(Params params, Params grads) {
        if (h == null) {
            h = Params.zerosLike(params);
        }
        // h += grad * grad
        h.update((sumSq, grad) -> sumSq.addi(grad.mul(grad)), grads);
        // param -= lr * grad / (sqrt(h) + eps)
        params.update(
            (param, grad, sumSq) -> param.subi(grad.mul(lr).div(Transforms.sqrt(sumSq).add(EPSILON))),
            grads, h);
    }
}

6.1.6 Adam

/**
 * Adam: keeps exponential moving averages of the gradients ({@code m}) and of the
 * squared gradients ({@code v}) and uses a bias-corrected step size.
 */
public class Adam implements Optimizer {

    /** Learning rate used when none is given. */
    private static final double DEFAULT_LR = 0.001;
    /** Decay rate for {@code m} used when none is given. */
    private static final double DEFAULT_BETA1 = 0.9;
    /** Decay rate for {@code v} used when none is given. */
    private static final double DEFAULT_BETA2 = 0.999;
    /** Added to the denominator to avoid division by zero. */
    private static final double EPSILON = 1e-7;

    /** Learning rate (step size before bias correction). */
    final double lr;
    /** Decay rate of the gradient moving average. */
    final double beta1;
    /** Decay rate of the squared-gradient moving average. */
    final double beta2;

    /** Number of update steps performed so far. */
    int iter;

    /** Moving average of the gradients; allocated on the first update. */
    Params m;
    /** Moving average of the squared gradients; allocated on the first update. */
    Params v;

    /** Creates an Adam optimizer with lr = 0.001, beta1 = 0.9, beta2 = 0.999. */
    public Adam() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an Adam optimizer with beta1 = 0.9 and beta2 = 0.999.
     *
     * @param lr the learning rate
     */
    public Adam(double lr) {
        this(lr, DEFAULT_BETA1, DEFAULT_BETA2);
    }

    /**
     * Creates an Adam optimizer.
     *
     * @param lr    the learning rate
     * @param beta1 decay rate for the gradient average
     * @param beta2 decay rate for the squared-gradient average
     */
    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.iter = 0;
    }

    @Override
    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        }
        iter++;
        double correction1 = 1.0 - Math.pow(beta1, iter);
        double correction2 = 1.0 - Math.pow(beta2, iter);
        double stepSize = lr * Math.sqrt(correction2) / correction1;
        // m += (1 - beta1) * (grad - m)
        m.update((first, grad) -> first.addi(grad.sub(first).mul(1 - beta1)), grads);
        // v += (1 - beta2) * (grad^2 - v)
        v.update((second, grad) -> second.addi(grad.mul(grad).sub(second).mul(1 - beta2)), grads);
        // param -= stepSize * m / (sqrt(v) + eps)
        params.update(
            (param, first, second) -> param.subi(first.mul(stepSize).div(Transforms.sqrt(second).add(EPSILON))),
            m, v);
    }

}

6.1.7 どの更新手法を用いるか?

グラフを作成するために簡単なクラスGraphImageを作成しました。

// Java port of ch06/optimizer_compare_naive.py.
// Runs each optimizer for 30 steps from (-7, 2) on the function whose gradient is df
// below, and draws each trajectory into a PNG with GraphImage.
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) outdir.mkdirs();
// The objective itself is not needed, only its gradient. For reference it is
// f(x, y) = x^2 / 20 + y^2 (the gradient df below is exactly (x / 10, 2y)):
// BinaryOperator<INDArray> f = (x, y) ->
// x.mul(x).div(20.0).add(y.mul(y));
// Gradient of f: concatenates df/dx = x / 10 and df/dy = 2y along dimension 1;
// the loop below reads them back with getColumn(0) and getColumn(1).
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
// Distance of the starting point from the minimum at (0, 0).
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
Params params = new Params()
    .put("x", Nd4j.create(new double[] {init_pos[0]}))
    .put("y", Nd4j.create(new double[] {init_pos[1]}));
// Gradient holder; its entries are overwritten before every update below.
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));

// LinkedHashMap keeps insertion order, so the optimizers run in the order listed.
Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (String key : optimizers.keySet()) {
    Optimizer optimizer = optimizers.get(key);
    // Reset the position to the starting point for every optimizer.
    params.put("x", Nd4j.create(new double[] {init_pos[0]}))
        .put("y", Nd4j.create(new double[] {init_pos[1]}));
    double min_distance = Double.MAX_VALUE;
    double last_distance = 0.0;
    double prevX = init_pos[0];
    double prevY = init_pos[1];
    // 700x700 image; the remaining arguments presumably define the plotted range
    // x, y in [-10, 10] (TODO confirm against GraphImage's constructor).
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        // Draw the graph title.
        image.text(key, -2, 7);
        // Plot the starting point.
        image.plot(prevX, prevY);
        for (int i = 0; i < 30; ++i) {
            INDArray temp = df.apply(params.get("x"), params.get("y"));
            grads.put("x", temp.getColumn(0));
            grads.put("y", temp.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            last_distance = Math.hypot(x, y);
            if (last_distance < min_distance)
                min_distance = last_distance;
            // Draw a line from the previous point.
            image.line(prevX, prevY, x, y);
            // Plot the new point.
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        // Check that the optimizer ended, and at some point was, closer to the minimum than the start.
        assertTrue(last_distance < init_distance);
        assertTrue(min_distance < init_distance);
        // Write the graph to <outdir>/<key>.png.
        image.writeTo(new File(outdir, key + ".png"));
    }
}

結果のグラフは以下のようになりました。

| SGD | Momentum | AdaGrad | Adam |
|:---:|:---:|:---:|:---:|
| ![SGD](SGD.png) | ![Momentum](Momentum.png) | ![AdaGrad](AdaGrad.png) | ![Adam](Adam.png) |

6.1.8 MNISTデータセットによる更新手法の比較

// Java port of ch06/optimizer_compare_mnist.py.
// Trains one MultiLayerNet per optimizer on the same mini-batches and records the
// training loss after every update, for comparison in the graph below.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int batch_size = 128;
int max_iterations = 2000;

// 1. Experiment setup
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

// One freshly initialized network and one loss history per optimizer.
Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);

// 2. Training
for (int i = 0; i < max_iterations; ++i) {
    // Sample one mini-batch and use it for every optimizer in this iteration.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (String key : optimizers.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizers.get(key).update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            // Each history gains exactly one entry per iteration, so index i is the loss
            // just computed on this batch with the current parameters. Reusing it avoids
            // a second, identical forward pass.
            double loss = train_loss.get(key).get(i);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3. Draw the loss curves of all optimizers into one image.
// 1000x800 image; the remaining arguments presumably define the plotted range
// (x: iterations -100..2000, y: loss -0.1..1.0) — TODO confirm against GraphImage.
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    // Legend position in graph coordinates; y advances by 0.05 per label.
    double w = 1300;
    double h = 0.7;
    for (String key : train_loss.keySet()) {
        List<Double> loss = train_loss.get(key);
        graph.color(colors.get(key));
        graph.text(key, w, h);
        h += 0.05;
        graph.plot(0, loss.get(0));
        // Connect every step-th recorded loss value with line segments.
        int step = 10;
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    // Axis description and caption.
    graph.color(Color.BLACK);
    graph.text("横=繰り返し回数(0,2000) 縦=損失関数の値(0,1)", w, h);
    h += 0.05;
    graph.text("MNISTデータセットに対する4つの更新手法の比較", w, h);
    if (!Constants.OptimizerImages.exists())
        Constants.OptimizerImages.mkdirs();
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png"));
}

結果のグラフは以下のようになりました。
![compare_mnist](compare_mnist.png)

Why not register and get more from Qiita?
  1. We will deliver articles that match your interests
    By following users and tags, you can catch up on information about the technical fields you are interested in as a whole
  2. You can efficiently read useful information later
    By "stocking" the articles you like, you can find them again right away