10
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PyTorchでRNNを使う簡単な例

Last updated at Posted at 2018-12-05

TL; DR

  • Web上だと画像系のサンプルコードは多いけど,イマイチ系列系のサンプルコードが難しかったので簡単な例を作った
  • PyTorchのRNNを利用して,系列を受け取り系列の和を予測するモデルを作成した
  • PyTorchのversionは0.4.1
  • notebook: http://nbviewer.jupyter.org/gist/cocomoff/e0d415d0d89e793b0f84a3a13f47e604

リンクなど

データ

  • 例:
    • 入力 [1,0,1,2,0,0,0,0,0]
    • 出力 4
  • ある1つの系列(input_dimは適当に固定)から,ある1つの数値を予測するタスク

予測器

  • RNNを利用し,RNNの隠れ層から値を予測する層をnn.Linearで付与する
  • batch_firstをTrueにすると,**(seq_len, batch, input_size)と指定されている入力テンソルの型を(batch, seq_len, input_size)**にできる
  • ある1つの列が対象なので,seq_len=1とする
    • そのためテンソルの型をunsqueeze/squeezeでマニュアル通りに揃える
    • (補足): 実はseq_lenとinput_dimの使い方が間違っている気もする(正しい使い方が分かる人教えて)
  • 出力outputは**(batch, seq_len, num_directions * hidden_size)と指定されているが,双方向ではないためnum_directions=1**となる
  • 出力hpは**(num_layers * num_directions, batch, hidden_size)と指定されているが,層を積まないのかつ単方向なため,(1, batch, hidden_size)**という型で返る
class Predictor(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Predictor, self).__init__()
        self.rnn = nn.RNN(n_input, n_hidden, num_layers=1, batch_first=True)
        self.out = nn.Linear(n_hidden, n_output)
        
    def forward(self, x, h=None):
        output, hp = self.rnn(x.unsqueeze(1), h)
        output = self.out(output.squeeze(1))
        return output, hp

結果

  • 誤差をプロットしてみる
    • RNNはLSTMに比較して長い列が苦手だとよく言われている(教科書などを参考)
    • なので固定長100ぐらいまで長くしてあるので,RNNをLSTMに置き換えると,LSTMの方が誤差が小さく精度高く和を出力する傾向が見られる

loss.png

10
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?