Help us understand the problem. What is going on with this article?

ゼロから作るDeep Learning Java編 6.4 正則化

More than a year has passed since the last update.

目次

6.4.1 過学習

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Keep only the first 300 training samples so that overfitting is easy to reproduce.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));
// Weight decay setting ===========
double weight_decay_lambda = 0; // 0 = weight decay is not used
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

// One "epoch" = iter_per_epoch mini-batch updates. Accuracy is measured after the update of
// every iteration that is a multiple of iter_per_epoch (starting with iteration 0), and
// training stops as soon as max_epochs measurements have been recorded.
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// 3. Plot the accuracy curves =============
// The x axis (epoch index) is sized from the number of recorded measurements instead of a
// hard-coded constant, so the curves always fit inside the plot area.
// NOTE(review): constructor arguments presumably are (width, height, minX, minY, maxX, maxY),
// matching the minX/minY/maxX/maxY fields printed below -- confirm against GraphImage.
int last_epoch = Math.max(train_acc_list.size() - 1, 1);
GraphImage graph = new GraphImage(640, 480, -40, -0.1, last_epoch, 1.0);
graph.color(Color.BLACK);
graph.textInt("過学習における認識精度", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect consecutive points: blue = training accuracy, red = test accuracy.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "overfit.png"));

overfit.png

6.4.2 Weight decay

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Keep only the first 300 training samples so that overfitting is easy to reproduce.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));
// Weight decay setting ===========
// Set weight_decay_lambda to 0 to train without weight decay.
double weight_decay_lambda = 0.1;
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

// One "epoch" = iter_per_epoch mini-batch updates. Accuracy is measured after the update of
// every iteration that is a multiple of iter_per_epoch (starting with iteration 0), and
// training stops as soon as max_epochs measurements have been recorded.
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// 3. Plot the accuracy curves =============
// The x axis (epoch index) is sized from the number of recorded measurements instead of a
// hard-coded constant, so the curves always fit inside the plot area.
// NOTE(review): constructor arguments presumably are (width, height, minX, minY, maxX, maxY),
// matching the minX/minY/maxX/maxY fields printed below -- confirm against GraphImage.
int last_epoch = Math.max(train_acc_list.size() - 1, 1);
GraphImage graph = new GraphImage(640, 480, -40, -0.1, last_epoch, 1.0);
graph.color(Color.BLACK);
graph.textInt("Weight decayを用いた過学習における認識精度", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect consecutive points: blue = training accuracy, red = test accuracy.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "overfit_weight_decay.png"));

overfit_weight_decay.png

6.4.3 Dropout

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Reduce the training data so that overfitting is reproduced.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
// Dropout on/off and ratio ====================
boolean use_dropout = true; // set to false to train without Dropout
double dropout_ratio = 0.2;
// =============================================
MultiLayerNetExtend network = new MultiLayerNetExtend(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu",
    /*weight_init_std*/"relu",
    /*weight_decay_lambda*/0,
    /*use_dropout*/use_dropout, /*dropout_ratio*/dropout_ratio,
    /*use_batchnorm*/false);
Trainer trainer = new Trainer(network, x_train, t_train, x_test, t_test,
    /*epochs*/301,
    /*mini_batch_size*/100,
    /*optimizer*/() -> new SGD(0.01),
    /*evaluate_sample_num_per_epoch*/0,
    /*verbose*/true);

trainer.train();
List<Double> train_acc_list = trainer.train_acc_list;
List<Double> test_acc_list = trainer.test_acc_list;

// 3. Plot the accuracy curves =============
// Training runs for 301 epochs, so a hard-coded x range of 200 pushed the last ~100 epochs
// outside the plot area. Size the x axis from the recorded measurements instead.
// NOTE(review): constructor arguments presumably are (width, height, minX, minY, maxX, maxY),
// matching the minX/minY/maxX/maxY fields printed below -- confirm against GraphImage.
int last_epoch = Math.max(train_acc_list.size() - 1, 1);
GraphImage graph = new GraphImage(640, 480, -40, -0.1, last_epoch, 1.0);
graph.color(Color.BLACK);
graph.textInt("Dropoutにおける認識精度", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect consecutive points: blue = training accuracy, red = test accuracy.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "dropout.png"));

dropout.png

Why not register and get more from Qiita?
  1. We will deliver articles that match your interests
    By following users and tags, you can keep up with information in the technical fields you are interested in
  2. You can efficiently read useful information later
    By "stocking" the articles you like, you can find them again right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why not register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match your interests
    By following users and tags, you can keep up with information in the technical fields you are interested in
  2. You can efficiently read useful information later
    By "stocking" the articles you like, you can find them again right away
ユーザーは見つかりませんでした