1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ゼロから作るDeep Learning❷で素人がつまずいたことメモ:8章

Last updated at Posted at 2021-02-11

はじめに

ふと思い立って勉強を始めた「ゼロから作るDeep Learning❷ーー自然言語処理編」の8章で私がつまずいたことのメモです。

この章の実行環境は前章とは異なり、すべてGoogle Colaboratoryです。

(このメモの他の章へ:1章 / 2章 / 3章 / 4章 / 5章 / 6章 / 7章 / 8章 / まとめ

この記事は個人で作成したものであり、内容や意見は所属企業・部門見解を代表するものではありません。

8章 Attention

この章はAttentionの解説です。なお、今回も自分で実装する時間が取れなかったので、本の実装を試すレベルに留まっています。ご了承ください。

8.1 Attentionの仕組み

seq2seqの問題点を改善する流れでAttentionの仕組みが解説されていて、分かりやすかったです。実装についてもかなり細かく解説されています。

Attentionのような仕組みのアイデアが浮かんだ時に(私のような凡人では思い浮かびようもないのですが)、微分可能な演算を駆使してニューラルネットワークを組み上げることができれば、あとは誤差逆伝播法でなんとかなる、みたいな考え方は面白いです。

8.2 Attention付きseq2seqの実装

seq2seqへの組み込みは、pythonの継承の仕組みのおかげで簡単そうです。

8.3 Attentionの評価

本と同じものを試しても面白くないので、漢数字からアラビア数字への変換をやってみることにしました。例えば一億二千五百五十七万125570000に変換するようなタスクです。

ちょっと簡単すぎるような気もしますが、途中がゼロの部分は漢数字だと文字がなくなりますし、は省略されたりされなかったりするので(例えば1110一千一百一十ではなく千百十になりますが、1億は一億のままで省略されない)、意外に複雑そうです。

今回は、値の範囲を1〜9,999,999,999,999(九兆九千九百九十九億九千九百九十九万九千九百九十九)として、5万件のデータセットを生成することにしました。

データセットの準備

データセットは前章のようにプログラムで生成しようと思っていたのですが、なんとExcelのNUMBERSTRING()という関数がアラビア数字を漢数字へ変換してくれることがわかったので、今回はそれを使うことにしました。

まずExcelのセルA1の式です。

Excel[A1]
=RANDBETWEEN(1,POWER(10,RANDBETWEEN(5,13))-1)

ちょっと分かりにくいのですが、単純に乱数にすると全体の9割が1兆越えで大きな値ばかりになってしまうので、まず上限の桁数を乱数で決めて(RANDBETWEEN(5,13)の部分)、その桁数が上限になる最大値を求めて(POWER(10,RANDBETWEEN(5,13))-1の部分)、1からその最大値の範囲で乱数を作る形にしてみました。

次に右隣のセルB1の式です。

Excel[B1]
=LEFT(NUMBERSTRING(A1,1)&REPT(" ",25),25)&"_"&LEFT(A1&REPT(" ",13),13)

A列の乱数に対して、NUMBERSTRING(A1,1)で漢数字の表現に変換します。そして、漢数字は空白でパディングして25文字(LEFT(NUMBERSTRING(A1,1)&REPT(" ",25),25)の部分)に、アラビア数字は13文字(LEFT(A1&REPT(" ",13),13)の部分)に揃えて_で連結します。

このA1セルとB1セルの式を5万行に埋めます。埋め終わるとこんな感じになります。
excel.png
次にB列を縦に選んでコピーし、テキスト エディターにペーストして、dataset/number_str.txtとして保存すれば完成です。先頭部分はこんな感じになります。

dataset/number_str.txt
六億六千七百七十万二千八百四十六         _667702846    
二千六百五万五千八百三十六            _26055836     
三兆七千五百二十一億三千七百十一万五千百二十三  _3752137115123
四百六十五万千五百十五              _4651515      
八兆千九十五億七千八百六十四万八千九百二十三   _8109578648923
四百七十二億四千九万四千二百           _47240094200  
九兆三千三百七億六千五百五万八千二百五十     _9330765058250
八万三千六百六十三                _83663        
百八十二億七千四百二十三万六十一         _18274230061  
九億七千七百三十六万三千一            _977363001    
四十五万六千三百二十五              _456325       
四兆六千百二十億五千八百三十五万七千六百二十二  _4612058357622
三万千六百三十七                 _31637        
五兆五千四百十六億六千五十二万九千百三十四    _5541660529134
六千六十七億七千六百三十二万八千七百三十九    _606776328739 
九百七十六万七千七百四十六            _9767746      
七千五百六十三億五十万九千五百十九        _756300509519 
七万二千八百八十六                _72886        
五千五百二十一万八千九百六十三          _55218963     
百十八万五千二百四十七              _1185247      

なお、文字数が揃っていないように見えるのは半角の空白でパディングしていて空白の文字幅が小さいためです。

Attention付きseq2seqの学習

データセットをすり替えるので、ch08/train.pyを少し変更します。の部分が変更点です。

ch08/train.py
# coding: utf-8
import sys
sys.path.append('..')
sys.path.append('../ch07')  # ★ch07/peeky_seq2seq.py内でseq2seq.pyのインポートに失敗するので追加
import numpy as np
import matplotlib.pyplot as plt
from dataset import sequence
from common.optimizer import Adam
from common.trainer import Trainer
from common.util import eval_seq2seq
from attention_seq2seq import AttentionSeq2seq
from ch07.seq2seq import Seq2seq
from ch07.peeky_seq2seq import PeekySeq2seq


# データの読み込み
(x_train, t_train), (x_test, t_test) = \
    sequence.load_data('number_str.txt')  # ★データを変更
char_to_id, id_to_char = sequence.get_vocab()

# 入力文を反転
x_train, x_test = x_train[:, ::-1], x_test[:, ::-1]

# ハイパーパラメータの設定
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 256
batch_size = 128
max_epoch = 10
max_grad = 5.0

model = AttentionSeq2seq(vocab_size, wordvec_size, hidden_size)
# model = Seq2seq(vocab_size, wordvec_size, hidden_size)
# model = PeekySeq2seq(vocab_size, wordvec_size, hidden_size)

optimizer = Adam()
trainer = Trainer(model, optimizer)

acc_list = []
for epoch in range(max_epoch):
    trainer.fit(x_train, t_train, max_epoch=1,
                batch_size=batch_size, max_grad=max_grad)

    correct_num = 0
    for i in range(len(x_test)):
        question, correct = x_test[[i]], t_test[[i]]
        verbose = i < 10
        correct_num += eval_seq2seq(model, question, correct,
                                    id_to_char, verbose, is_reverse=True)

    acc = float(correct_num) / len(x_test)
    acc_list.append(acc)
    print('val acc %.3f%%' % (acc * 100))


model.save_params()

# グラフの描画
x = np.arange(len(acc_list))
plt.plot(x, acc_list, marker='o')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.ylim(-0.05, 1.05)
plt.show()

なお、4行目のsys.path.append('../ch07')は、普通に本のコードのままで実行すると次のようなModuleNotFoundErrorになってしまうので追加しました。

ModuleNotFoundError
Traceback (most recent call last):
  File "/Users/segavvy/Documents/deep-learning-from-scratch-2/ch08/train.py", line 12, in <module>
    from ch07.peeky_seq2seq import PeekySeq2seq
  File "../ch07/peeky_seq2seq.py", line 5, in <module>
    from seq2seq import Seq2seq, Encoder
ModuleNotFoundError: No module named 'seq2seq'

インポート対象のch07/peeky_seq2seq.pyは、中でさらに同じ場所にあるseq2seq.pyをインポートしようとしますが、ch07にはパスが通っていないため見つけられないことが原因のようです。

学習の実行は今回も1Google Colabを使いました。今回はGPUを使わないのですが、それでも手元の環境より早いので快適です。

まず、1エポック目の結果です。

epoch1
| epoch 1 |  iter 1 / 351 | time 0[s] | loss 3.30
| epoch 1 |  iter 21 / 351 | time 15[s] | loss 2.68
| epoch 1 |  iter 41 / 351 | time 33[s] | loss 1.93
| epoch 1 |  iter 61 / 351 | time 47[s] | loss 1.92
| epoch 1 |  iter 81 / 351 | time 61[s] | loss 1.85
| epoch 1 |  iter 101 / 351 | time 74[s] | loss 1.81
| epoch 1 |  iter 121 / 351 | time 88[s] | loss 1.78
| epoch 1 |  iter 141 / 351 | time 102[s] | loss 1.76
| epoch 1 |  iter 161 / 351 | time 115[s] | loss 1.76
| epoch 1 |  iter 181 / 351 | time 129[s] | loss 1.74
| epoch 1 |  iter 201 / 351 | time 143[s] | loss 1.73
| epoch 1 |  iter 221 / 351 | time 157[s] | loss 1.72
| epoch 1 |  iter 241 / 351 | time 170[s] | loss 1.68
| epoch 1 |  iter 261 / 351 | time 184[s] | loss 1.66
| epoch 1 |  iter 281 / 351 | time 198[s] | loss 1.66
| epoch 1 |  iter 301 / 351 | time 211[s] | loss 1.64
| epoch 1 |  iter 321 / 351 | time 225[s] | loss 1.62
| epoch 1 |  iter 341 / 351 | time 239[s] | loss 1.61
Q 三万八千六百五十七                
T 38657        
☒ 1130300      
---
Q 五百六十五億三千七百十七万二千三百十三      
T 56537172313  
☒ 3433444444   
---
Q 十七万二千百三十                 
T 172130       
☒ 1101000000   
---
Q 八万千七百八十四                 
T 81784        
☒ 1101000000   
---
Q 九千百十                     
T 9110         
☒ 1000000      
---
Q 七百七十五億二百七十二万七百四十一        
T 77502720741  
☒ 34344440000  
---
Q 一億二百六十五万三千二百三十四          
T 102653234    
☒ 34344440     
---
Q 七百八億七千七百四十三万五百九          
T 70877430509  
☒ 343334400    
---
Q 五億六千五百八万千二百三十一           
T 565081231    
☒ 343344000    
---
Q 七億三千九百六十万八千七百十八          
T 739608718    
☒ 343444400    
---
val acc 0.000%

なんと全滅ですが、続けてみます。

epoch10
| epoch 10 |  iter 1 / 351 | time 0[s] | loss 0.00
| epoch 10 |  iter 21 / 351 | time 14[s] | loss 0.00
| epoch 10 |  iter 41 / 351 | time 29[s] | loss 0.00
| epoch 10 |  iter 61 / 351 | time 43[s] | loss 0.00
| epoch 10 |  iter 81 / 351 | time 57[s] | loss 0.12
| epoch 10 |  iter 101 / 351 | time 72[s] | loss 0.88
| epoch 10 |  iter 121 / 351 | time 86[s] | loss 0.20
| epoch 10 |  iter 141 / 351 | time 101[s] | loss 0.02
| epoch 10 |  iter 161 / 351 | time 115[s] | loss 0.01
| epoch 10 |  iter 181 / 351 | time 130[s] | loss 0.01
| epoch 10 |  iter 201 / 351 | time 144[s] | loss 0.00
| epoch 10 |  iter 221 / 351 | time 158[s] | loss 0.00
| epoch 10 |  iter 241 / 351 | time 173[s] | loss 0.00
| epoch 10 |  iter 261 / 351 | time 187[s] | loss 0.00
| epoch 10 |  iter 281 / 351 | time 202[s] | loss 0.00
| epoch 10 |  iter 301 / 351 | time 216[s] | loss 0.00
| epoch 10 |  iter 321 / 351 | time 230[s] | loss 0.00
| epoch 10 |  iter 341 / 351 | time 245[s] | loss 0.00
Q 三万八千六百五十七                
T 38657        
☑ 38657        
---
Q 五百六十五億三千七百十七万二千三百十三      
T 56537172313  
☑ 56537172313  
---
Q 十七万二千百三十                 
T 172130       
☑ 172130       
---
Q 八万千七百八十四                 
T 81784        
☑ 81784        
---
Q 九千百十                     
T 9110         
☑ 9110         
---
Q 七百七十五億二百七十二万七百四十一        
T 77502720741  
☑ 77502720741  
---
Q 一億二百六十五万三千二百三十四          
T 102653234    
☑ 102653234    
---
Q 七百八億七千七百四十三万五百九          
T 70877430509  
☑ 70877430509  
---
Q 五億六千五百八万千二百三十一           
T 565081231    
☑ 565081231    
---
Q 七億三千九百六十万八千七百十八          
T 739608718    
☑ 739608718    
---
val acc 99.780%

特にハイパーパラメーターも調整しないまま、10エポックで十分学習できました(と、この時点では思っていました)。
image01.png

Attentionの可視化

Attentionの可視化に挑戦する前に、データセットに漢字を使ってしまったのでmatplotlibの日本語対応が必要です。これをやらないと、グラフで日本語が表示できません。

(脱線)Google Colabのmatplotlibで日本語を表示するための準備

ここではGoogle Colabでの流れをざっとまとめますので、詳細については、@siraasagiさんのGoogle Colabでまた日本語表示が豆腐不可避な方になどをご参照ください。

  1. まず日本語のIPAフォントをインストールします。
!apt-get -y install fonts-ipafont-gothic
  1. フォント キャッシュのクリアが必要なため、まず名前を調べます。ここで、fontlist-で始まるファイルを確認してください。
!ls -ll /root/.cache/matplotlib/
  1. そのファイルを削除します。
rm /root/.cache/matplotlib/fontlist-v310.json
  1. 最後にGoogle Colabのランタイムをリスタートしてください。これで、プログラムでIPAフォントを指定すれば日本語が使えるようになります。私が実行した時の画面を貼り付けておきます。
    googlecolab_font.png

続いてAttentionを可視化するためのコードch08/visualize_attention.pyの修正です。本のコードはデータセットの中からランダムに5つのデータを選んで表示するのですが、入力内容を自由に試したかったので、標準入力で任意の文字を入れて結果が見られるようにしてみました。の部分が変更点です。

ch08/visualize_attention.py
# coding: utf-8
import sys
sys.path.append('..')
import numpy as np
from dataset import sequence
import matplotlib.pyplot as plt
from attention_seq2seq import AttentionSeq2seq

# ★Google Colabのmatplotlibで日本語表示するためにフォント設定
import seaborn as sns
sns.set(font='IPAGothic')

(x_train, t_train), (x_test, t_test) = \
    sequence.load_data('number_str.txt')  # ★データを変更
char_to_id, id_to_char = sequence.get_vocab()

# Reverse input
x_train, x_test = x_train[:, ::-1], x_test[:, ::-1]

vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 256

model = AttentionSeq2seq(vocab_size, wordvec_size, hidden_size)
model.load_params()

_idx = 0
def visualize(attention_map, row_labels, column_labels):
    fig, ax = plt.subplots()
    ax.pcolor(attention_map, cmap=plt.cm.Greys_r, vmin=0.0, vmax=1.0)

    ax.patch.set_facecolor('black')
    ax.set_yticks(np.arange(attention_map.shape[0])+0.5, minor=False)
    ax.set_xticks(np.arange(attention_map.shape[1])+0.5, minor=False)
    ax.invert_yaxis()
    ax.set_xticklabels(row_labels, minor=False)
    ax.set_yticklabels(column_labels, minor=False)

    global _idx
    _idx += 1
    plt.show()


x_len = len(x_test[0])      # ★入力文字数
t_len = len(t_test[0]) - 1  # ★出力文字数、-1は区切り文字の分
start_id = t_test[0][0]     # ★区切り文字のID
while True:

    # ★入力文を標準入力する形に変更
    query = input('入力?(漢数字表記)')
    if not query:
        break

    # ★x_testの先頭を入力内容ですり替え
    chars = list(f'{query: <{x_len}}')[::-1]
    x_test[0] = np.array([char_to_id[i] for i in chars], dtype=np.int)
    x = x_test[[0]]

    # ★t_testの先頭は推測結果にすり替え
    guess = model.generate(x, start_id, t_len)
    t_test[0] = [start_id] + guess
    t = t_test[[0]]
    # ★以降は既存のソースそのまま
 
    model.forward(x, t)
    d = model.decoder.attention.attention_weights
    d = np.array(d)
    attention_map = d.reshape(d.shape[0], d.shape[2])

    # reverse for print
    attention_map = attention_map[:,::-1]
    x = x[:,::-1]

    row_labels = [id_to_char[i] for i in x[0]]
    column_labels = [id_to_char[i] for i in t[0]]
    column_labels = column_labels[1:]

    visualize(attention_map, row_labels, column_labels)

以下、いろいろ試した結果です。まず、上手く変換できている例から。
all.png
small.png
one.png
zero.png
桁が多いものも少ないものもいい感じです。桁が飛んで0になる部分は、その「単位」の部分の重みが大きくなっている(白くなっている)ことが分かります。ただ、本の例とは異なり、前方から順番に対応していくだけなのが少しシンプルでつまらないですが。

続いて失敗した例です。
err01.png
「一」を変換させたら「0009009000」という謎の数字になりました。
err02.png
「一億一」は「1 000 000 000 400」(1兆400)です。だいぶ大きさが変わってしまいました。
err03.png
「三兆九千億」は「3 900 000 003 339」(3兆9千億3,339)です。終わりの方に同じ数字が出てきてしまいました。
err04.png
「五兆五千五十五」は「55 055」とかなり小さくなってしまいました。

失敗の原因は、データセットを乱数で作ってしまったために、今回失敗したような特徴的なデータがほとんど含まれなかったことだと思います。極端に小さなものや、途中の桁が0で埋まるものや、同じ数字が何度も出てくるようなものが少なく、特徴を上手く学習できなかったのでしょう。当たり前のことですが、対象データに特徴的なものが含まれる場合は、適当にサンプリングしてはいけません:sweat:

また、Attentionの注目部分を見ると、失敗した箇所は入力の文字の末尾(右端)のパディング部分にも注目しているようです。「7.2.3 可変長の時系列データ」にも説明がありましたが、パディングに使っている空白は本当は損失に計上してはいけないので、その辺りも失敗に影響しているかも知れません。

8.4 Attentionに関する残りのテーマ

Attentionの改良方法として、双方向RNNや深層化が紹介されています。

8.5 Attentionの応用

最後にAttentionの応用例が3つ紹介されています。ここ数年でよく聞くTransformerも含まれています。

8.6 まとめ

失敗談のメモみたいになってしまいましたが、Attentionを学ぶことができました。
この章は以上です。誤りなどありましたら、ご指摘いただけますとうれしいです。

(このメモの他の章へ:1章 / 2章 / 3章 / 4章 / 5章 / 6章 / 7章 / 8章 / まとめ

  1. Google Colabについては、このメモの6章に使い方の流れをまとめています。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?