LoginSignup
8
7

More than 1 year has passed since last update.

はじめに

ボンジュール、マドモアゼル&ジェントルメン 💩
サッカーワールドカップが終わった今、興味があることと言ったらLSTM(Long Short Term Memory)とGRU(Gated Recurrent Unit)の性能比較くらいですよね。
 というわけで、そんなマドモアゼルとジェントルメンのために、今日はRNN(Recurrent Neural Network)の一種であるLSTMとGRUを、自然言語処理性能で比較してみましょう。
 え、今更?transformerでいいんじゃないの?と思われたあなた、ギョギョギョ、おっしゃるとおりでギョザイます。
まあ今どきGRUとか使うのは、ちょっと小一時間で開発してテストしたいときや勉強のためくらいでしょうかね。
なので今日はそんなノリで行きます。

結論

お急ぎの方のために、結論を先に書きます。
今回のテキスト解析のタスクでは、学習速度も分類性能も大差ありませんでした。
というわけで、小規模なデータでちょっとテストしたいときだけの時には、よりシンプルな構造のGRUを用いて、より大規模なデータで、もうちょっと性能を出したいけど、transformerを使うほどではないな、というときは、一般的にGRUより性能がいいと言われているLSTMを使えばいいのではないでしょうか。

背景

もともとGRUはLSTMの構造をシンプルにするために提案されたようなもので、元論文でも"we found GRU to be comparable to LSTM"と書いてあり、LSTMを上回るわけではないが、シンプルな構造で同程度の性能といったものです。また、wikipediaのGRUの項目にもあるように、基本的にはLSTM方が性能は上回るようです。
 ではGRUはどれくらいLSTMと同程度の性能なのでしょうか?あまり変わらないのであれば、シンプルでパラメータも少ないGRUを使う理由になると思います。
 そこで今回は、具体的なテキスト解析の事例を使って試してみます。

目的

実際のテキスト解析タスクでの、LSTMとGRUの、学習時間と識別性能の比較です。

LSTMとGRUのアルゴリズムの解説

 LSTMとGRUのアルゴリズムについては、私が新たに解説するよりも、こちらに非常に分かり易い解説があり、絵も豊富にあるので、アルゴリズムを理解したい方は、こちらを読むことをおすすめします。
 ちなみに今回の記事のベースとなった「Mastering PyTorch: Build powerful neural network architectures using advanced PyTorch 1.x features」の4章のLSTMの絵は一部誤りがあるので、その意味でも上記解説記事を読むのが良いでしょう。

比較方法

  • タスク
    • ユーザーが投稿した映画レビューを、positive or negativeの2値にクラス分類するタスクです。
  • 使用データ

評価コード説明

 この記事で説明するコードは、 Ashish Ranjan Jha著の「Mastering PyTorch: Build powerful neural network architectures using advanced PyTorch 1.x features」の4章に基づいています。
 ただし、本家のコードの4章はバグがあり、そのままでは動かなかったのと、GRUについては実装がなかったので、Forkして修正したものをこちらにあげてあります。

ですので、ここではエッセンスだけ挙げておきます。

LSTM

必要なライブラリをインポートした後、データセットを持ってきて、テキストから数字へ変換して、LSTMで取り扱えるようにします。
そして次のようにLSTMのクラスを定義します。ここで、元のコードではlayerが1でwarninngが出ていたので、layerは2に修正しています。

class LSTM(nn.Module):
    def __init__(self, vocabulary_size, embedding_dimension, hidden_dimension, output_dimension, dropout, pad_index):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_dimension, padding_idx = pad_index)
        # num_layers should be > 1. otherwise, it cause below warning
        # UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1
        self.lstm_layer = nn.LSTM(embedding_dimension, 
                           hidden_dimension, 
                           num_layers=2, 
                           bidirectional=True, 
                           dropout=dropout)
        self.fc_layer = nn.Linear(hidden_dimension * 2, output_dimension)
        self.dropout_layer = nn.Dropout(dropout)
        
    def forward(self, sequence, sequence_lengths=None):
        if sequence_lengths is None:
            sequence_lengths = torch.LongTensor([len(sequence)])
        
        # sequence := (sequence_length, batch_size)
        embedded_output = self.dropout_layer(self.embedding_layer(sequence))
        
        
        # embedded_output := (sequence_length, batch_size, embedding_dimension)
        if torch.cuda.is_available():
            packed_embedded_output = cuda_pack_padded_sequence(embedded_output, sequence_lengths)
        else:
            packed_embedded_output = nn.utils.rnn.pack_padded_sequence(embedded_output, sequence_lengths)
        
        packed_output, (hidden_state, cell_state) = self.lstm_layer(packed_embedded_output)
        # hidden_state := (num_layers * num_directions, batch_size, hidden_dimension)
        # num_directions = 2 if bidirectional LSTM.
        # cell_state := (num_layers * num_directions, batch_size, hidden_dimension)
        
        op, op_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
        # op := (sequence_length, batch_size, hidden_dimension * num_directions)
        
        hidden_output = torch.cat((hidden_state[-2,:,:], hidden_state[-1,:,:]), dim = 1)        
        # hidden_output := (batch_size, hidden_dimension * num_directions)
        
        return self.fc_layer(hidden_output)

次に、学習のための関数を定義し、学習をはじめていきます。エポックは10とします。

num_epochs = 10
best_validation_loss = float('inf')

for ep in range(num_epochs):

    time_start = time.time()
    
    training_loss, train_accuracy = train(lstm_model, train_data_iterator, optim, loss_func)
    validation_loss, validation_accuracy = validate(lstm_model, valid_data_iterator, loss_func)
    
    time_end = time.time()
    time_delta = time_end - time_start 
    
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        torch.save(lstm_model.state_dict(), 'lstm_model.pt')
    
    print(f'epoch number: {ep+1} | time elapsed: {time_delta}s')
    print(f'training loss: {training_loss:.3f} | training accuracy: {train_accuracy*100:.2f}%')
    print(f'validation loss: {validation_loss:.3f} |  validation accuracy: {validation_accuracy*100:.2f}%')
    print()

結果出力は次章に載せました。
次に、この学習結果を使ってテストを行います。

lstm_model.load_state_dict(torch.load('../../Mastering-PyTorch/Chapter04/lstm_model.pt'))

test_loss, test_accuracy = validate(lstm_model, test_data_iterator, loss_func)

print(f'test loss: {test_loss:.3f} | test accuracy: {test_accuracy*100:.2f}%')

こちらも結果は、後でGRUと比較して載せておきました。
では、この学習結果を使って、推論を行える関数を定義しましょう。

def sentiment_inference(model, sentence):
    model.eval()
    
    # text transformations
    tokenized = data.get_tokenizer("basic_english")(sentence)
    tokenized = [TEXT_FIELD.vocab.stoi[t] for t in tokenized]
    
    # model inference
    model_input = torch.LongTensor(tokenized).to(device)
    model_input = model_input.unsqueeze(1)
    
    pred = torch.sigmoid(model(model_input))
    
    return pred.item()

これを使って、独自の入力に対して推論を行います。

GRU

基本的にはLSTMとそれほど変わらないので、クラス定義だけ載せておきます。

class GRU(nn.Module):
    def __init__(self, vocabulary_size, embedding_dimension, hidden_dimension, output_dimension, dropout, pad_index):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_dimension, padding_idx = pad_index)
        # num_layers should be > 1. otherwise, it cause below warning
        # UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1
        self.gru_layer = nn.GRU(embedding_dimension, hidden_dimension, num_layers=2,
dropout=0.8, bidirectional=True)
        self.fc_layer = nn.Linear(hidden_dimension * 2, output_dimension)
        self.dropout_layer = nn.Dropout(dropout)
        
    def forward(self, sequence, sequence_lengths=None):
        if sequence_lengths is None:
            sequence_lengths = torch.LongTensor([len(sequence)])
        
        # sequence := (sequence_length, batch_size)
        embedded_output = self.dropout_layer(self.embedding_layer(sequence))
        
        
        # embedded_output := (sequence_length, batch_size, embedding_dimension)
        if torch.cuda.is_available():
            packed_embedded_output = cuda_pack_padded_sequence(embedded_output, sequence_lengths)
        else:
            packed_embedded_output = nn.utils.rnn.pack_padded_sequence(embedded_output, sequence_lengths)
        
        packed_output, hidden_state = self.gru_layer(packed_embedded_output)
        # hidden_state := (num_layers * num_directions, batch_size, hidden_dimension)
        # num_directions = 2 if bidirectional.
        # cell_state := (num_layers * num_directions, batch_size, hidden_dimension)
        
        op, op_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
        # op := (sequence_length, batch_size, hidden_dimension * num_directions)
        
        hidden_output = torch.cat((hidden_state[-2,:,:], hidden_state[-1,:,:]), dim = 1)        
        # hidden_output := (batch_size, hidden_dimension * num_directions)
        
        return self.fc_layer(hidden_output)

性能比較

学習時間の比較

前章の学習結果は、LSTM,GRUで、それぞれ以下のようになりました。

LSTM
epoch number: 1 | time elapsed: 71.49335956573486s
training loss: 0.682 | training accuracy: 55.14%
validation loss: 0.611 |  validation accuracy: 66.93%

epoch number: 2 | time elapsed: 138.23603439331055s
training loss: 0.610 | training accuracy: 66.95%
validation loss: 0.534 |  validation accuracy: 73.91%

epoch number: 3 | time elapsed: 118.86396598815918s
training loss: 0.566 | training accuracy: 71.13%
validation loss: 0.537 |  validation accuracy: 73.18%

epoch number: 4 | time elapsed: 134.55875992774963s
training loss: 0.609 | training accuracy: 67.31%
validation loss: 0.537 |  validation accuracy: 75.13%

epoch number: 5 | time elapsed: 128.78285789489746s
training loss: 0.535 | training accuracy: 74.46%
validation loss: 0.524 |  validation accuracy: 74.36%

epoch number: 6 | time elapsed: 118.4251012802124s
training loss: 0.454 | training accuracy: 79.86%
validation loss: 0.385 |  validation accuracy: 83.62%

epoch number: 7 | time elapsed: 125.99788117408752s
training loss: 0.409 | training accuracy: 82.08%
validation loss: 0.349 |  validation accuracy: 85.09%

epoch number: 8 | time elapsed: 121.55267071723938s
training loss: 0.364 | training accuracy: 84.51%
validation loss: 0.395 |  validation accuracy: 82.73%

epoch number: 9 | time elapsed: 124.1032452583313s
training loss: 0.335 | training accuracy: 85.94%
validation loss: 0.327 |  validation accuracy: 86.14%

epoch number: 10 | time elapsed: 120.42302107810974s
training loss: 0.310 | training accuracy: 87.20%
validation loss: 0.377 |  validation accuracy: 84.64%
GRU
epoch number: 1 | time elapsed: 120.04946160316467s
training loss: 0.693 | training accuracy: 52.76%
validation loss: 0.678 |  validation accuracy: 56.87%

epoch number: 2 | time elapsed: 20.851502180099487s
training loss: 0.681 | training accuracy: 56.01%
validation loss: 0.646 |  validation accuracy: 62.26%

epoch number: 3 | time elapsed: 6.450242519378662s
training loss: 0.636 | training accuracy: 63.32%
validation loss: 0.615 |  validation accuracy: 64.90%

epoch number: 4 | time elapsed: 6.636687994003296s
training loss: 0.575 | training accuracy: 70.08%
validation loss: 0.523 |  validation accuracy: 74.00%

epoch number: 5 | time elapsed: 6.49686336517334s
training loss: 0.509 | training accuracy: 75.56%
validation loss: 0.432 |  validation accuracy: 80.97%

epoch number: 6 | time elapsed: 7.006083726882935s
training loss: 0.489 | training accuracy: 76.97%
validation loss: 0.490 |  validation accuracy: 77.09%

epoch number: 7 | time elapsed: 6.958749294281006s
training loss: 0.412 | training accuracy: 81.63%
validation loss: 0.548 |  validation accuracy: 78.66%

epoch number: 8 | time elapsed: 7.189222097396851s
training loss: 0.367 | training accuracy: 84.28%
validation loss: 0.396 |  validation accuracy: 84.77%

epoch number: 9 | time elapsed: 7.317926645278931s
training loss: 0.333 | training accuracy: 85.69%
validation loss: 0.413 |  validation accuracy: 85.20%

epoch number: 10 | time elapsed: 7.510709762573242s
training loss: 0.307 | training accuracy: 87.34%
validation loss: 0.365 |  validation accuracy: 87.38%

エポック1こそ時間がLSTMの10エポック目と同程度ですが、2エポック目からかなり早くなり、エポック10で比較すると、LSTMが120secだったのに対してGRUは7sec。圧倒的にGRUの学習時間が短かったです。
最終的なaccuracyも、LSTM、GRUで似たようなものでした。
しかし、なぜこうなったのでしょうか?もうちょっと調査してみます。

学習時間追加評価

GRUでエポック1が他のエポックに比べて異常に時間がかかっていた理由がどうも理解できなかったので、trainとvalで時間を分けて測るようにして、再度学習を行ってみました。結果がこちらです。

GRU
epoch number: 1 | time elapsed: 10.142573118209839s
train time = 9.305720090866089, validation time = 0.8368189334869385
training loss: 0.693 | training accuracy: 52.23%
validation loss: 0.677 |  validation accuracy: 57.58%

epoch number: 2 | time elapsed: 19.76072382926941s
train time = 17.982912063598633, validation time = 1.7778115272521973
training loss: 0.681 | training accuracy: 55.93%
validation loss: 0.641 |  validation accuracy: 63.58%

epoch number: 3 | time elapsed: 20.421406507492065s
train time = 17.444645166397095, validation time = 2.9767611026763916
training loss: 0.631 | training accuracy: 64.18%
validation loss: 0.644 |  validation accuracy: 69.56%

epoch number: 4 | time elapsed: 23.67748498916626s
train time = 20.439718008041382, validation time = 3.237766742706299
training loss: 0.588 | training accuracy: 69.00%
validation loss: 0.521 |  validation accuracy: 76.07%

epoch number: 5 | time elapsed: 20.615106344223022s
train time = 18.653815031051636, validation time = 1.9612908363342285
training loss: 0.523 | training accuracy: 74.42%
validation loss: 0.484 |  validation accuracy: 77.92%

epoch number: 6 | time elapsed: 19.7037615776062s
train time = 17.35865044593811, validation time = 2.3451108932495117
training loss: 0.471 | training accuracy: 77.90%
validation loss: 0.567 |  validation accuracy: 76.21%

epoch number: 7 | time elapsed: 20.183289527893066s
train time = 17.201263189315796, validation time = 2.9820258617401123
training loss: 0.410 | training accuracy: 81.78%
validation loss: 0.430 |  validation accuracy: 82.45%

epoch number: 8 | time elapsed: 20.243817806243896s
train time = 17.51808476448059, validation time = 2.7257325649261475
training loss: 0.365 | training accuracy: 84.34%
validation loss: 0.434 |  validation accuracy: 83.85%

epoch number: 9 | time elapsed: 20.38038659095764s
train time = 17.858877897262573, validation time = 2.52150821685791
training loss: 0.329 | training accuracy: 86.10%
validation loss: 0.359 |  validation accuracy: 87.13%

epoch number: 10 | time elapsed: 20.606106519699097s
train time = 18.336904287338257, validation time = 2.2692017555236816
training loss: 0.306 | training accuracy: 87.43%
validation loss: 0.425 |  validation accuracy: 85.91%

大体1エポック20秒くらいに落ち着き、最初だけ遅いということはありませんでした。
前回はLSTM→ GRUの順に学習を行ったのですが、今度はGRU → LSTMの順に行ってみました。
結果がこちら。

LSTM
epoch number: 1 | time elapsed: 10.142573118209839s
train time = 9.305720090866089, validation time = 0.8368189334869385
training loss: 0.693 | training accuracy: 52.23%
validation loss: 0.677 |  validation accuracy: 57.58%

epoch number: 2 | time elapsed: 19.76072382926941s
train time = 17.982912063598633, validation time = 1.7778115272521973
training loss: 0.681 | training accuracy: 55.93%
validation loss: 0.641 |  validation accuracy: 63.58%

epoch number: 3 | time elapsed: 20.421406507492065s
train time = 17.444645166397095, validation time = 2.9767611026763916
training loss: 0.631 | training accuracy: 64.18%
validation loss: 0.644 |  validation accuracy: 69.56%

epoch number: 4 | time elapsed: 23.67748498916626s
train time = 20.439718008041382, validation time = 3.237766742706299
training loss: 0.588 | training accuracy: 69.00%
validation loss: 0.521 |  validation accuracy: 76.07%

epoch number: 5 | time elapsed: 20.615106344223022s
train time = 18.653815031051636, validation time = 1.9612908363342285
training loss: 0.523 | training accuracy: 74.42%
validation loss: 0.484 |  validation accuracy: 77.92%

epoch number: 6 | time elapsed: 19.7037615776062s
train time = 17.35865044593811, validation time = 2.3451108932495117
training loss: 0.471 | training accuracy: 77.90%
validation loss: 0.567 |  validation accuracy: 76.21%

epoch number: 7 | time elapsed: 20.183289527893066s
train time = 17.201263189315796, validation time = 2.9820258617401123
training loss: 0.410 | training accuracy: 81.78%
validation loss: 0.430 |  validation accuracy: 82.45%

epoch number: 8 | time elapsed: 20.243817806243896s
train time = 17.51808476448059, validation time = 2.7257325649261475
training loss: 0.365 | training accuracy: 84.34%
validation loss: 0.434 |  validation accuracy: 83.85%

epoch number: 9 | time elapsed: 20.38038659095764s
train time = 17.858877897262573, validation time = 2.52150821685791
training loss: 0.329 | training accuracy: 86.10%
validation loss: 0.359 |  validation accuracy: 87.13%

epoch number: 10 | time elapsed: 20.606106519699097s
train time = 18.336904287338257, validation time = 2.2692017555236816
training loss: 0.306 | training accuracy: 87.43%
validation loss: 0.425 |  validation accuracy: 85.91%

なんと、LSTMでもGRU同様平均20秒くらい。
結局学習時間でもLSTMとGRUは大差ないようです。
最初にたくさん時間がかかっていた理由は、何かモデルをロードする部分に時間がかかっていたのでしょうか。
しかし、最初のLSTMはエポック10までずっと120secくらいかかっていたので謎です。
裏で何か計算していたわけでもないので、ちょっと理解できません。

テストに対する定量比較

テストに対する結果は、こちらのようになりました。

LSTM
test loss: 0.351 | test accuracy: 84.97%
GRU
test loss: 0.401 | test accuracy: 86.40%

LSTM、GRUで同程度ですね。

サンプル入力に対する定性比較

まずは教科書に乗っているテキストに対する入力と出力がこちらです。
これをテスト1と命名しましょう。

test1
print(sentiment_inference(lstm_model, "This film is horrible"))
print(sentiment_inference(lstm_model, "Director tried too hard but this film is bad"))
print(sentiment_inference(lstm_model, "Decent movie, although could be shorter"))
print(sentiment_inference(lstm_model, "This film will be houseful for weeks"))
print(sentiment_inference(lstm_model, "I loved the movie, every part of it"))
LSTM
0.4319622814655304
0.04400959238409996
0.266744464635849
0.7558539509773254
0.9776954054832458

数値は、sigmoidで[0,1]に飛ばした結果で、高評価度を表していると思ってください。
良いというコメントには、高い値が出ていますね。まずまずの結果です。

次に、テスト1のGRUにたいする結果がこちら。

GRU
0.07484859228134155
0.01621958799660206
0.2282780110836029
0.5675244331359863
0.9830831289291382

LSTM同様、割とできていますね。

次に、独自に、意味のない言葉や、コンピュータを騙すような入力を入れてみた結果がこちらです。
これをテスト2と命名しましょう。

test2
print(sentiment_inference(lstm_model, "hogehoge"))
print(sentiment_inference(lstm_model, "fugafuga"))
print(sentiment_inference(lstm_model, "I like it, but my friend hate it."))
print(sentiment_inference(lstm_model, "I hate it, but my friend likes it."))
print(sentiment_inference(lstm_model, "I love you."))
print(sentiment_inference(lstm_model, "I love this movie, by the way I hate the book."))
print(sentiment_inference(lstm_model, "I hate this movie, by the way I love the book."))
print(sentiment_inference(lstm_model, "With best New Year’s wishes!"))
LSTM
0.49046969413757324
0.49046969413757324
0.9504886865615845
0.942279577255249
0.9454897046089172
0.8072608113288879
0.8895605802536011
0.963241696357727

4番目のように、嫌いだ、といっているコメントにも高い数値がでており、完全に騙されていますね。
また、全然関係のない言葉を入れた場合も、ポジティブな言葉があれば高い数値が出ています。
このことから、文の構造は理解できておらず、単語だけで判定していそうなことが分かります。

次にGRUでのテスト2の結果がこちら

GRU
0.5029202103614807
0.5029202103614807
0.9278048276901245
0.9288939237594604
0.955331027507782
0.9214895963668823
0.9193702340126038
0.9812726974487305

LSTMと同じように騙されていますね。
こちらも構造は理解できておらず、単語だけで判断していそうです。

おわりに

結論にも書いたように、性能に大差はないので、小規模なデータにはシンプルなGRUを使うのが良さそうです。
但し、GRUのwikipedia

LSTMセルは「ニューラル機械翻訳のためのアーキテクチャ変法の初の大規模分析」においてGRUセルを一貫して上回った

と書かれているように、より大規模なデータセットに対しては、LSTMの方が性能がいいようなので、transformerを使うほどでもない場合は、コードはGRUとほぼ同様にかけるという利点があるので、LSTMを使うのが良さそうです。

以上、ちょっと長くなってしまいましたが、LSTMとGRUの比較でした。

最後にこの歌でお別れしましょう。
♫田舎のじっちゃんばっちゃん芋食ってへーこいてパンツが破れて飛んでった、ヘイ!♫
皆様、良いお年を。

8
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
7