Edited at

ゼロから作るDeep Learning Java編 6.1 パラメータの更新

More than 1 year has passed since last update.


目次


6.1.2 SGD

まずOptimizerインタフェースを定義します。

public interface Optimizer {

void update(Params params, Params args);

}

SGDの実装は以下のようになります。

public class SGD implements Optimizer {

/** learning rate (学習係数) */
final double lr;

public SGD(double lr) {
this.lr = lr;
}

public SGD() {
this(0.01);
}

@Override
public void update(Params params, Params grads) {
params.update((p, g) -> p.subi(g.mul(lr)), grads);
}
}


6.1.4 Momentum

public class Momentum implements Optimizer {

final double lr, momentum;
Params v;

public Momentum(double lr, double momentum) {
this.lr = lr;
this.momentum = momentum;
this.v = null;
}

public Momentum(double lr) {
this(lr, 0.9);
}

public Momentum() {
this(0.01);
}

@Override
public void update(Params params, Params grads) {
if (v == null)
v = Params.zerosLike(params);
v.update((v, g) -> v.muli(momentum), grads);
v.update((v, g) -> v.subi(g.mul(lr)), grads);
params.update((p, v) -> p.addi(v), v);
}
}


6.1.5 AdaGrad

public class AdaGrad implements Optimizer {

final double lr;
Params h;

public AdaGrad(double lr) {
this.lr = lr;
}

public AdaGrad() {
this(0.01);
}

@Override
public void update(Params params, Params grads) {
if (h == null)
h = Params.zerosLike(params);
h.update((h, g) -> h.addi(g.mul(g)), grads);
params.update((p, g, h) -> p.subi(g.mul(lr).div(Transforms.sqrt(h).add(1e-7))), grads, h);
}
}


6.1.6 Adam

public class Adam implements Optimizer {

final double lr, beta1, beta2;
int iter;
Params m, v;

public Adam(double lr, double beta1, double beta2) {
this.lr = lr;
this.beta1 = beta1;
this.beta2 = beta2;
this.iter = 0;
}

public Adam(double lr) {
this(lr, 0.9, 0.999);
}

public Adam() {
this(0.001);
}

@Override
public void update(Params params, Params grads) {
if (m == null) {
m = Params.zerosLike(params);
v = Params.zerosLike(params);
}
++iter;
double lr_t = lr * Math.sqrt(1.0 - Math.pow(beta2, iter)) / (1.0 - Math.pow(beta1, iter));
m.update((m, g) -> m.addi(g.sub(m).mul(1 - beta1)), grads);
v.update((v, g) -> v.addi(g.mul(g).sub(v).mul(1 - beta2)), grads);
params.update((p, m, v) -> p.subi(m.mul(lr_t).div(Transforms.sqrt(v).add(1e-7))), m, v);
}

}


6.1.7 どの更新手法を用いるか?

グラフを作成するために簡単なクラスGraphImageを作成しました。

// ch06/optimizer_compare_naive.py の java版です。

// GraphImageを使ってグラフを作成します。
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) outdir.mkdirs();
// BinaryOperator<INDArray> f = (x, y) ->
// x.mul(x).div(y.mul(y).add(20.0));
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
// 初期値の(0, 0)からの距離です。
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
Params params = new Params()
.put("x", Nd4j.create(new double[] {init_pos[0]}))
.put("y", Nd4j.create(new double[] {init_pos[1]}));
Params grads = new Params()
.put("x", Nd4j.create(new double[] {0}))
.put("y", Nd4j.create(new double[] {0}));

Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (String key : optimizers.keySet()) {
Optimizer optimizer = optimizers.get(key);
params.put("x", Nd4j.create(new double[] {init_pos[0]}))
.put("y", Nd4j.create(new double[] {init_pos[1]}));
double min_distance = Double.MAX_VALUE;
double last_distance = 0.0;
double prevX = init_pos[0];
double prevY = init_pos[1];
try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
// グラフのタイトルを描画します。
image.text(key, -2, 7);
// 最初の点をプロットします。
image.plot(prevX, prevY);
for (int i = 0; i < 30; ++i) {
INDArray temp = df.apply(params.get("x"), params.get("y"));
grads.put("x", temp.getColumn(0));
grads.put("y", temp.getColumn(1));
optimizer.update(params, grads);
double x = params.get("x").getDouble(0);
double y = params.get("y").getDouble(0);
last_distance = Math.hypot(x, y);
if (last_distance < min_distance)
min_distance = last_distance;
// 直前の点から線を引きます。
image.line(prevX, prevY, x, y);
// 値をプロットします。
image.plot(x, y);
prevX = x;
prevY = y;
}
// 初期値よりも最適化されていることを確認します。
assertTrue(last_distance < init_distance);
assertTrue(min_distance < init_distance);
// グラフをファイル出力します。
image.writeTo(new File(outdir, key + ".png"));
}
}

結果のグラフは以下のようになりました。

SGD
Momentum
AdaGrad
Adam

SGD.png
Momentum.png
AdaGrad.png
Adam.png


6.1.8 MNISTデータセットによる更新手法の比較

// ch06/optimizer_compare_mnist.py の Java版です。

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int train_size = x_train.size(0);
int batch_size = 128;
int max_iterations = 2000;

// 1.実験の設定
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
networks.put(key, new MultiLayerNet(
784, new int[] {100, 100, 100, 100}, 10));
train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);

// 2.訓練の開始
for (int i = 0; i < max_iterations; ++i) {
// バッチデータを抽出します。
DataSet sample = dataset.sample(batch_size);
INDArray x_batch = sample.getFeatureMatrix();
INDArray t_batch = sample.getLabels();
for (String key : optimizers.keySet()) {
MultiLayerNet network = networks.get(key);
Params grads = network.gradicent(x_batch, t_batch);
optimizers.get(key).update(network.params, grads);
double loss = network.loss(x_batch, t_batch);
train_loss.get(key).add(loss);
}
if (i % 100 == 0) {
System.out.println("===========" + "iteration:" + i + "===========");
for (String key : optimizers.keySet()) {
double loss = networks.get(key).loss(x_batch, t_batch);
System.out.println(key + ":" + loss);
}
}
}

// 3.グラフの描画
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
Map<String, Color> colors = new HashMap<>();
colors.put("SGD", Color.GREEN);
colors.put("Momentum", Color.BLUE);
colors.put("AdaGrad", Color.RED);
colors.put("Adam", Color.ORANGE);
double w = 1300;
double h = 0.7;
for (String key : train_loss.keySet()) {
List<Double> loss = train_loss.get(key);
graph.color(colors.get(key));
graph.text(key, w, h);
h += 0.05;
graph.plot(0, loss.get(0));
int step = 10;
for (int i = step, size = loss.size(); i < size; i += step) {
graph.line(i - step, loss.get(i - step), i, loss.get(i));
graph.plot(i, loss.get(i));
}
}
graph.color(Color.BLACK);
graph.text("横=繰り返し回数(0,2000) 縦=損失関数の値(0,1)", w, h);
h += 0.05;
graph.text("MNISTデータセットに対する4つの更新手法の比較", w, h);
if (!Constants.OptimizerImages.exists())
Constants.OptimizerImages.mkdirs();
graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png"));
}

結果のグラフは以下のようになりました。

compare_mnist.png