0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

駆け出しエンジニアの機械学習メモ その2

Posted at

#はじめに
「ゼロから作るDeep-Learning」の学習メモその2です。

#2層ニューラルネットワーク(4章)

  • mnist.pyよりデータの読み込みを行う。正規化、one_hot配列化、データの1次元配列化を行う。
train_neuralnet
# データの読み込み
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True,flatten=True, one_hot_label=True)
  • two_layer_net.pyより重みの初期化を行う。辞書型でparams{W1:,b1:,W2:,b2:}のkeyを生成。
    print(network.params)で中身を見ることができる。
train_neuralnet
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)
  • 各初期値の設定、60000枚の画像データを100枚ずつ処理していく。
train_neuralnet
iters_num = 10000  # 繰り返しの回数を適宜設定する
train_size = x_train.shape[0] # 60000
batch_size = 100
learning_rate = 0.1

train_loss_list = []
train_acc_list = []
test_acc_list = []
# 1エポックあたりの繰り返し処理 60000 / 100
iter_per_epoch = max(train_size / batch_size, 1)
  • 訓練データから無造作に一部のデータを取り出し学習を行う。ここでは60000枚のデータから100枚取り出す。
train_neuralnet
for i in range(iters_num): #10000
    # ミニバッチの取得
    batch_mask = np.random.choice(train_size, batch_size) # (100,)の形
    x_batch = x_train[batch_mask] # (100,784)の形
    t_batch = t_train[batch_mask] # (100,784)の形
    
    # 勾配の計算
    #grad = network.numerical_gradient(x_batch, t_batch)
    grad = network.gradient(x_batch, t_batch)
    
    # パラメータの更新
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
    
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)

    # 600で1エポックのため条件を満たしたらデータの保存
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))
  • リストに保存していたデータを利用してグラフの描画を行う。
# グラフの描画
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list,'o', label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

* 描画結果
image.png

参考

ゼロから作るDeep Learning

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?