概要
RNN(LSTM) の使い方がどうしてもよくわからなかったので、ごく小さな toy program を書いてみて、使い方を確かめてみた。
コードの概要
RNNの使い方の基本は、時系列データを与えて、次に来るデータを予測する、というものだ。
ここでは究極単純に、
1,2,3,4
という数列を与えたら、次に
5
が来ると予測するようにモデルを学習させてみる。
学習
import torch
import numpy as np
from itertools import chain
dict_size = 10
depth = 3
hidden_size = 6
# モデル定義
embedding = torch.nn.Embedding(dict_size, depth)
lstm = torch.nn.LSTM(input_size=depth,
hidden_size=hidden_size,
batch_first=True)
linear = torch.nn.Linear(hidden_size, dict_size)
criterion = torch.nn.CrossEntropyLoss()
params = chain.from_iterable([
embedding.parameters(),
lstm.parameters(),
linear.parameters(),
criterion.parameters()
])
optimizer = torch.optim.SGD(params, lr=0.01)
# 訓練用データ
x = [[1,2, 3, 4]]
y = [5]
# 学習
for i in range(100):
tensor_y = torch.tensor(y)
input_ = torch.tensor(x)
tensor = embedding(input_)
output, (tensor, c_n) = lstm(tensor)
tensor = tensor[0]
tensor = linear(tensor)
loss = criterion(tensor, tensor_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % 10 == 0:
print("{}: {}".format(i + 1, loss.data.item()))
実行すると、確かに loss が減少していくのが観察されるはずだ。
さすがにここまでモデルが単純だと LSTM でも爆速で実行できる。
10: 1.994105339050293
20: 1.893094778060913
30: 1.7957065105438232
40: 1.7019000053405762
50: 1.6116480827331543
60: 1.5249345302581787
70: 1.4417483806610107
80: 1.362081527709961
90: 1.2859251499176025
100: 1.2132656574249268
LSTM モデルの戻り値は怪しげである。
output, (tensor, c_n) = lstm(tensor)
tensor = tensor[0]
上の tensor の位置に、時系列の最後のセルの隠れ状態(hidden state)が入ることになる。
output, c_n に何が入るか、どうして、 tensor[0]
するのか、は、 PyTorchの公式ドキュメントを確認されたい。
推論
上のコードに続けて実行する。
(Jupyter Notebook で実行すると良いかも)
tensor_y = torch.tensor(y)
input_ = torch.tensor(x)
tensor = embedding(input_)
output, (tensor, c_n) = lstm(tensor)
tensor = tensor[0]
tensor = linear(tensor)
print(tensor)
print(torch.argmax(tensor))
実行すると、
tensor([[ 0.0875, -0.4316, 0.0791, 0.0485, -0.1336, 1.5148, -0.4815, 0.1055,
-0.2639, -0.2552]], grad_fn=<AddmmBackward>)
tensor(5)
となり、きちんと 5
と予測できている。素晴らしい!
感想
Toy program の実行を通じて、PyTorch で LSTM をどうやって使えるか感触がつかめた。上の for ループの内側で同じデータを投入し続けているが、実際には毎回異なるデータをミニバッチとして入れることになる。自然言語処理では、登場する語彙に対して単語番号を付け、文章を単語番号のリストとして表現するが、上の [1, 2, 3, 4]
等がそれに当たると考えてもらえればよい。