0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ミニバッチ学習

Last updated at Posted at 2023-11-27

バッチ学習

学習用データを一度にすべて入力し、重みを更新する学習方法

ミニバッチ学習

学習用データを小さなバッチに分けて、その小さなバッチ毎に重みを更新していく学習方法。

特徴
・バッチサイズが大きくなると勾配推定が正確になる

バッチ学習はデータが大きくなると計算に時間がかかるため、あまり使われない。

非復元抽出を理解するために復元抽出をやってみた。
np.random.seed(1234) #seedは乱数を固定
train_size = train_labels.shape[0]
batch_size = 32
max_iter = 10  #ループの回数
network = None #ダミー

for i in range(max_iter):
    batch_mask = np.random.choice(train_size, batch_size) #random.choiceで復元抽出
    print("i=%s, "%i, "batch_mask=%s"%batch_mask[:10])
    x_batch = train[batch_mask]
    y_batch = train_labels[batch_mask]

    trainer(network, x_batch, y_batch)

非復元抽出

np.random.seed(1234)
train_size = train_labels.shape[0]
batch_size = 32
epochs = 10
network = None #ダミー
minibatch_num = np.ceil(train_size/batch_size).astype(int) # ミニバッチの個数
    
for epoch in range(epochs):
    
    # indexを定義し、シャッフルする
    index = np.arange(train_size)
    np.random.shuffle(index)
    
    for mn in range(minibatch_num):
        """
        非復元抽出によるループ
        """
        batch_mask = index[batch_size*mn:batch_size*(mn+1)]        
        print("epoch=%s, "%epoch, "batch_mask=%s"%batch_mask[:10])
        x_batch = train[batch_mask]
        y_batch = train_labels[batch_mask]

        trainer(network, x_batch, y_batch)
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?