コード
lstm_updater.py
# -*- coding: utf-8 -*-
import chainer
from chainer import training
class LSTMUpdater(training.StandardUpdater):
    """Updater that trains a stateful (recurrent) model one batch at a time.

    For every batch it resets the model's recurrent state, clears old
    gradients, runs a forward pass that returns the loss, back-propagates,
    cuts the computational graph and applies one optimizer step.
    """

    def __init__(self, data_iter, optimizer, device=None):
        """Create the updater.

        Args:
            data_iter: Iterator registered as the ``'main'`` iterator.
            optimizer: Optimizer registered as ``'main'``; its ``target`` is
                the model being trained.
            device: Device specifier forwarded to ``StandardUpdater`` and
                used when converting batches with
                ``chainer.dataset.concat_examples``.
        """
        # Forward the caller's device to the base class. Passing a hard-coded
        # ``device=None`` here meant the base class never saw the requested
        # device. StandardUpdater stores the device itself (readable as
        # ``self.device``), so it is not re-assigned afterwards.
        super(LSTMUpdater, self).__init__(data_iter, optimizer, device=device)

    def update_core(self):
        """Run one training iteration on the next batch of ``'main'``."""
        data_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')
        model = optimizer.target

        batch = next(data_iter)
        # Each example is unpacked as an (x, y) pair and stacked into arrays
        # on the configured device.
        x_batch, y_batch = chainer.dataset.concat_examples(batch, self.device)

        # Every batch is treated as an independent sequence: drop recurrent
        # state left over from the previous batch, then clear old gradients.
        model.reset_state()
        model.cleargrads()

        loss = model(x_batch, y_batch)
        loss.backward()
        # Cut the graph after back-prop so this batch's history can be freed.
        loss.unchain_backward()
        optimizer.update()
呼び出し側
# LSTMUpdater is defined in lstm_updater.py, so import it from that module
# (there is no ``updater`` module holding it).
from lstm_updater import LSTMUpdater
from chainer import training

# train_iter, optimizer and epoch are defined elsewhere in the script.
# NOTE(review): no device is passed, so batches are not transferred to any
# device; pass e.g. device=0 if training on a GPU is intended — confirm.
updater = LSTMUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (epoch, 'epoch'), out='result')