3
4

More than 1 year has passed since last update.

Pytorchで新しくRNNを作るには?

Last updated at Posted at 2020-08-30

研究でRNNを使う時に,デフォルトのtorch.nn.RNN()等を使うのではなくて,
自作したRNNを使用したいと思っていたのですが,やり方に結構悩んだので,備忘録的にまとめておきます.

今回扱うのはデフォルトのRNNCellやLSTM,GRUなどを使ってモデルを組む段階の話ではなく,
RNNCellそのものを自作したり,内部のリンクを変更したいと思った時にどうすればいいかというお話です.

結論から言ってしまえば,nn.Moduleを継承してRNNのクラスを作るだけでした.
※ゼロから作るDeep Learning 3 のp.475~477がとても参考になりました.
部分的なコードは以下のような感じです.

import torch
import torch.nn as nn

class RNNCell(nn.Module):
  def __init__(self,n_in,n_hid):
    super(RNNCell,self).__init__()
    self.i2h=nn.Linear(n_in,n_hid)
    self.h2h=nn.Linear(n_hid,n_hid)
    self.h=None

  def reset_state(self):
    self.h=None

  def forward(self,x):
    if self.h is None:
      h_new=self.i2h(x)
    else:
      h_new=self.i2h(x)+self.h2h(torch.tanh(self.h))
    self.h=h_new
    h_out=torch.tanh(h_new)
    h_out=h_out.detach()

    return h_out

ちなみに上のコードのh_newは隠れ状態ではありません.
その一歩手前の(tanhに通す前の)内部状態を表しています.
また,outputのリンク部分は実際に回すモデルの方(RNNCellを使ってモデルを作る時に)で指定してます.

後は好きなように
h_new=self.i2h(x)+self.h2h(torch.tanh(self.h))
の部分などを変更すれば,任意のRNNが作れるはずです.

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4