目次

6.4.1 過学習

// Overfitting experiment: train a 6-hidden-layer network on only the first
// 300 MNIST training samples, with weight decay disabled, and record the
// accuracy on the training subset and on the full test set once per epoch.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Restrict the training data to rows [0, 300).
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));

// Weight decay setting ===========
double weight_decay_lambda = 0; // 0 means weight decay is not used
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

// Number of mini-batch updates that make up one epoch (at least 1).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
// Keep updating until max_epochs accuracy measurements have been recorded.
// A measurement is taken on every iter_per_epoch-th iteration, starting with
// iteration 0, right after that iteration's parameter update.
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// 3. Draw the graph =============
// Training accuracy is drawn in blue, test accuracy in red; the x coordinate
// of each point is the epoch index.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("過学習における認識精度", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
// Make sure the output directory exists before writing the image.
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
graph.writeTo(new File(dir, "overfit.png"));

overfit.png

6.4.2 Weight decay

// Weight decay experiment: same setup as the plain overfitting run (first 300
// MNIST training samples, 6 hidden layers of 100 units) but with
// weight_decay_lambda = 0.1. Accuracy on the training subset and on the full
// test set is recorded once per epoch.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Restrict the training data to rows [0, 300).
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));

// Weight decay setting ===========
// Set this to 0 to run without weight decay.
double weight_decay_lambda = 0.1;
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

// Number of mini-batch updates that make up one epoch (at least 1).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
// Keep updating until max_epochs accuracy measurements have been recorded.
// A measurement is taken on every iter_per_epoch-th iteration, starting with
// iteration 0, right after that iteration's parameter update.
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// 3. Draw the graph =============
// Training accuracy is drawn in blue, test accuracy in red; the x coordinate
// of each point is the epoch index.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Weight decayを用いた過学習における認識精度", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
// Make sure the output directory exists before writing the image.
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
graph.writeTo(new File(dir, "overfit_weight_decay.png"));

overfit_weight_decay.png

6.4.3 Dropout

// Dropout experiment: train a 6-hidden-layer network with Dropout on the first
// 300 MNIST training samples via Trainer, then plot the per-epoch accuracy
// lists the trainer exposes.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Reduce the training data in order to reproduce overfitting.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
// Dropout on/off and ratio ====================
boolean use_dropout = true; // set to false to run without Dropout
double dropout_ratio = 0.2;
// =============================================
MultiLayerNetExtend network = new MultiLayerNetExtend(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu",
    /*weight_init_std*/"relu",
    /*weight_decay_lambda*/0,
    /*use_dropout*/use_dropout, /*dropout_ratio*/dropout_ratio,
    /*use_batchnorm*/false);
Trainer trainer = new Trainer(network, x_train, t_train, x_test, t_test,
    /*epochs*/301,
    /*mini_batch_size*/100,
    /*optimizer*/() -> new SGD(0.01),
    /*evaluate_sample_num_per_epoch*/0,
    /*verbose*/true);

trainer.train();
List<Double> train_acc_list = trainer.train_acc_list;
List<Double> test_acc_list = trainer.test_acc_list;

// 3. Draw the graph =============
// The trainer is configured for 301 epochs, so a fixed upper x bound of 200
// would leave the later epochs outside the plotted range. Size the x-axis from
// the recorded data instead, never going below the previous bound of 200.
// NOTE(review): assumes the GraphImage arguments are
// (width, height, minX, minY, maxX, maxY), as the axis label below suggests —
// confirm against GraphImage.
int max_x = Math.max(200, train_acc_list.size() - 1);
GraphImage graph = new GraphImage(640, 480, -40, -0.1, max_x, 1.0);
graph.color(Color.BLACK);
graph.textInt("Dropoutにおける認識精度", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Training accuracy is drawn in blue, test accuracy in red; the x coordinate
// of each point is its index in the accuracy lists.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
// Make sure the output directory exists before writing the image.
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
graph.writeTo(new File(dir, "dropout.png"));

dropout.png

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.