LoginSignup
7
9

More than 5 years have passed since last update.

PyTorch + LSTM で Hello World 的プログラムを実行

Last updated at Posted at 2019-03-30

概要

RNN(LSTM) の使い方がどうしてもよくわからなかったので、ごく小さな toy program を書いてみて、使い方を確かめてみた。

コードの概要

RNNの使い方の基本は、時系列データを与えて、次に来るデータを予測する、というものだ。
ここでは究極単純に、

1,2,3,4

という数列を与えたら、次に

5

が来ると予測するようにモデルを学習させてみる。

学習

import torch
import numpy as np
from itertools import chain

dict_size = 10
depth = 3
hidden_size = 6

# モデル定義
embedding = torch.nn.Embedding(dict_size, depth)
lstm = torch.nn.LSTM(input_size=depth,
                            hidden_size=hidden_size,
                            batch_first=True)
linear = torch.nn.Linear(hidden_size, dict_size)
criterion = torch.nn.CrossEntropyLoss()
params = chain.from_iterable([
    embedding.parameters(),
    lstm.parameters(),
    linear.parameters(),
    criterion.parameters()
])
optimizer = torch.optim.SGD(params, lr=0.01)

# 訓練用データ
x = [[1,2, 3, 4]]
y = [5]

# 学習
for i in range(100):
    tensor_y = torch.tensor(y)
    input_ = torch.tensor(x)
    tensor = embedding(input_)
    output, (tensor, c_n) = lstm(tensor)
    tensor = tensor[0]
    tensor = linear(tensor)
    loss = criterion(tensor, tensor_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i + 1) % 10 == 0:
        print("{}: {}".format(i + 1, loss.data.item()))

実行すると、確かに loss が減少していくのが観察されるはずだ。
さすがにここまでモデルが単純だと LSTM でも爆速で実行できる。

10: 1.994105339050293
20: 1.893094778060913
30: 1.7957065105438232
40: 1.7019000053405762
50: 1.6116480827331543
60: 1.5249345302581787
70: 1.4417483806610107
80: 1.362081527709961
90: 1.2859251499176025
100: 1.2132656574249268

LSTM モデルの戻り値は怪しげである。

    output, (tensor, c_n) = lstm(tensor)
    tensor = tensor[0]

上の tensor の位置に、時系列の最後のセルの隠れ状態(hidden state)が入ることになる。
output, c_n に何が入るか、どうして、 tensor[0] するのか、は、 PyTorchの公式ドキュメントを確認されたい。

推論

上のコードに続けて実行する。
(Jupyter Notebook で実行すると良いかも)

tensor_y = torch.tensor(y)
input_ = torch.tensor(x)
tensor = embedding(input_)
output, (tensor, c_n) = lstm(tensor)
tensor = tensor[0]
tensor = linear(tensor)
print(tensor)
print(torch.argmax(tensor))

実行すると、

tensor([[ 0.0875, -0.4316,  0.0791,  0.0485, -0.1336,  1.5148, -0.4815,  0.1055,
         -0.2639, -0.2552]], grad_fn=<AddmmBackward>)
tensor(5)

となり、きちんと 5 と予測できている。素晴らしい!

感想

Toy program の実行を通じて、PyTorch で LSTM をどうやって使えるか感触がつかめた。上の for ループの内側で同じデータを投入し続けているが、実際には毎回異なるデータをミニバッチとして入れることになる。自然言語処理では、登場する語彙に対して単語番号を付け、文章を単語番号のリストとして表現するが、上の [1, 2, 3, 4] 等がそれに当たると考えてもらえればよい。

7
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
9