ゼロから作るDeep Learning Java編 6.1 パラメータの更新

Posted at 2018-03-19


6.1.2 SGD


public interface Optimizer {

    void update(Params params, Params args);



public class SGD implements Optimizer {

    /** learning rate (学習係数) */
    final double lr;

    public SGD(double lr) {
        this.lr = lr;

    public SGD() {

    public void update(Params params, Params grads) {
        params.update((p, g) -> p.subi(g.mul(lr)), grads);

6.1.4 Momentum

public class Momentum implements Optimizer {

    final double lr, momentum;
    Params v;

    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
        this.v = null;

    public Momentum(double lr) {
        this(lr, 0.9);

    public Momentum() {

    public void update(Params params, Params grads) {
        if (v == null)
            v = Params.zerosLike(params);
        v.update((v, g) -> v.muli(momentum), grads);
        v.update((v, g) -> v.subi(g.mul(lr)), grads);
        params.update((p, v) -> p.addi(v), v);

6.1.5 AdaGrad

public class AdaGrad implements Optimizer {

    final double lr;
    Params h;

    public AdaGrad(double lr) {
        this.lr = lr;

    public AdaGrad() {

    public void update(Params params, Params grads) {
        if (h == null)
            h = Params.zerosLike(params);
        h.update((h, g) -> h.addi(g.mul(g)), grads);
        params.update((p, g, h) -> p.subi(g.mul(lr).div(Transforms.sqrt(h).add(1e-7))), grads, h);

6.1.6 Adam

public class Adam implements Optimizer {

    final double lr, beta1, beta2;
    int iter;
    Params m, v;

    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.iter = 0;

    public Adam(double lr) {
        this(lr, 0.9, 0.999);

    public Adam() {

    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        double lr_t = lr * Math.sqrt(1.0 - Math.pow(beta2, iter)) / (1.0 - Math.pow(beta1, iter));
        m.update((m, g) -> m.addi(g.sub(m).mul(1 - beta1)), grads);
        v.update((v, g) -> v.addi(g.mul(g).sub(v).mul(1 - beta2)), grads);
        params.update((p, m, v) -> p.subi(m.mul(lr_t).div(Transforms.sqrt(v).add(1e-7))), m, v);


6.1.7 どの更新手法を用いるか?


// Java port of ch06/optimizer_compare_naive.py.
// The path taken by each optimizer is drawn with GraphImage.
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) {
    outdir.mkdirs();
}
// df is the gradient of f(x, y) = x^2 / 20 + y^2, i.e. (x / 10, 2y),
// returned as one array with the two components concatenated along dimension 1.
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
// Distance of the starting point from the origin (0, 0).
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
Params params = new Params()
    .put("x", Nd4j.create(new double[] {init_pos[0]}))
    .put("y", Nd4j.create(new double[] {init_pos[1]}));
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));

// LinkedHashMap keeps the optimizers in insertion order.
Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (Map.Entry<String, Optimizer> entry : optimizers.entrySet()) {
    String key = entry.getKey();
    Optimizer optimizer = entry.getValue();
    // Every optimizer starts from the same initial position.
    params.put("x", Nd4j.create(new double[] {init_pos[0]}))
        .put("y", Nd4j.create(new double[] {init_pos[1]}));
    double min_distance = Double.MAX_VALUE;
    double last_distance = 0.0;
    double prevX = init_pos[0];
    double prevY = init_pos[1];
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        // Draw the graph title.
        image.text(key, -2, 7);
        // Plot the starting point.
        image.plot(prevX, prevY);
        for (int i = 0; i < 30; ++i) {
            INDArray temp = df.apply(params.get("x"), params.get("y"));
            grads.put("x", temp.getColumn(0));
            grads.put("y", temp.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            last_distance = Math.hypot(x, y);
            if (last_distance < min_distance) {
                min_distance = last_distance;
            }
            // Draw a line from the previous point.
            image.line(prevX, prevY, x, y);
            // Plot the new point.
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        // Check that the result is closer to the origin than the starting point.
        assertTrue(last_distance < init_distance);
        assertTrue(min_distance < init_distance);
        // Write the graph to a file.
        image.writeTo(new File(outdir, key + ".png"));
    }
}


| SGD | Momentum | AdaGrad | Adam |
|---|---|---|---|
| SGD.png | Momentum.png | AdaGrad.png | Adam.png |

6.1.8 MNISTデータセットによる更新手法の比較

// Java port of ch06/optimizer_compare_mnist.py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int train_size = x_train.size(0);
int batch_size = 128;
int max_iterations = 2000;

// 1. Experiment setup
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

// One network and one loss history per optimizer.
Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);

// 2. Training
for (int i = 0; i < max_iterations; ++i) {
    // Draw a mini-batch.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (String key : optimizers.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizers.get(key).update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        // Record the loss; the plotting section below reads train_loss and
        // would fail on loss.get(0) if the histories stayed empty.
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3. Plotting
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    // NOTE(review): this color table is built but never applied to the graph;
    // presumably a per-series color call on GraphImage belongs in the loop below — confirm.
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    // Position of the legend text; h advances by 0.05 per line.
    double w = 1300;
    double h = 0.7;
    for (String key : train_loss.keySet()) {
        List<Double> loss = train_loss.get(key);
        graph.text(key, w, h);
        h += 0.05;
        graph.plot(0, loss.get(0));
        // Only every 10th recorded loss is drawn.
        int step = 10;
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    graph.text("横=繰り返し回数(0,2000) 縦=損失関数の値(0,1)", w, h);
    h += 0.05;
    graph.text("MNISTデータセットに対する4つの更新手法の比較", w, h);
    // Create the output directory when it is missing, then always write the image.
    if (!Constants.OptimizerImages.exists()) {
        Constants.OptimizerImages.mkdirs();
    }
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png"));
}



