新しい仕組みを試してみた理由と工夫点。
画像キャプショニングに Mask Predict を導入して学習、推論を行って、一定の精度を得たことを、記事
で報告しました。ここから、非自己回帰型で精度の改善ができないかということを模索していました。downsample による学習、推論
の改良を試みましたが、うまくいかなかったため、Mask Predict を基本に改善できないか考えました。Mask Predict ですが、損失の計算に CTCLoss を使う方法を考えました。二つの工夫が必要でした。一つ目は、デコーダーの q 入力(target 入力)である caption を、埋め込み(embedding)した後に、系列長方向(時間方向)に2倍程度の長さに upsample すること。二つ目は、推論時に <mask> の数を徐々に減らした q 入力を用いて、デコーダーを10回程度イタレーションしますが、l + 1 回目のデコーダーへの q 入力を作成するために、l 回目のイタレーションの出力をそのままでは使えません。なぜなら、q 入力はセンテンスを数値に置き換えた形式であり、この入力をデーコーダーの中で、CTC 用に uspample しているからです。ですから、l回目のイタレーションの出力をセンテンスを埋め込んだ数字列までデコードして、確率を考慮して <mask> をかけ、l + 1 回目のイタレーションの q 入力を作成する必要があります。このため、CTC_simple_inference という関数を新たに作りました。CTC_simple_inference 関数の入力は、デコーダーの出力であり、出力は、センテンス数字列とその確率です。
元のプログラムとのつながり
「Python で学ぶ画像認識」の画像キャプショニングのプログラムをもとに改修しました。
プログラムの土台は、「Ptyhon で学画像認識」の画像キャプショニングのプログラムです。このプログラムに2段階で改修をさせていただきました。1段目は、自己回帰型から Mask Predict の非自己回帰型への改修です。この改修については、この記事の一番上の URL で説明させていただきました。二段目が今回の改修で、Mask Predict でありながら、キャプションの損失を CTCLoss で計算するように改修しました。
一段目の改修で、推論するキャプションの長さを学習、推論するようにプログラムを改修しました。長さを推論するニューラルネットワークを length_predictor とします。
データは、cocodataset の train2017 です。
データは、cocodataset の train2017 を、学習データ:90%、validation データ:10% としました。
辞書の修正
CTCLoss を扱うときは、通常、token_id の 0 を <blank> に割り当てます。また、Mask Predict の場合、token に <mask> を準備する必要があります。このため、本の 6_2_dataset.ipynb を改修しました。word_to_id と id_to_word の token_id = 0 を開けるために、同プログラムの id の採番を i から i + 1 に修正し、学習・推論プログラムで、word_to_id、id_to_word をファイルから読み込んだら、これらの辞書を使用する前に、word_to_id['<blank>'] = 0 と id_to_word[0] = '<blank>' を入れました。また、tokenに <mask> を入れるために、6_2_dataset.ipynb の vocab.append の最後にvocab.append( '<mask>' ) を入れました。
学習について
損失の定義と損失の推移
loss = loss0 + loss1
としました。デコーダーの出力を outputs、その長さを output_lens、キャプションを captions、キャプションの長さを caption_lens とすると、
loss0 = nn.CTCLoss()( F.log_softmax( outputs ).transpose(0,1), captions, output_lens, caption_lens )
です。通常の Mask Predict のキャプション損失は、Mask をかけた時間的位置だけですが、CTCLoss では、outputs の時間的推移も大事なようなので、Mask を考慮せず、全部にしました。loss1 は、length_predictor の出力を lengths として
loss1 = nn.MSELoss()( lengths, caption_lens.to(torch.float32) )
です。loss0 が重要です。 loss1 と loss は補助的に見てください。
WER の推移
Word Error Rate WER の推移を掲載させていただきます。
推論結果
<start> a dog holds a mouth up to catch a frisbee in mouth <end>
<start> a herd of giraffe standing in a field together in the background <end>
<start> a large jetliner flying over the side of a lush green field <end>
<start> a car driving down a street with a car on the it <end>
<start> a man riding a surfboard in the middle of a large wave <end>
<start> a group of people riding on a horse next to a beach <end>
<start> a cat sitting on the side of a street near a building <end>
<start> a bowl of fruit and a banana sitting on a table top <end>
<start> a woman in a bikini riding her bike on the beach shore <end>
<start> a dog in the air with a frisbee in its mouth open <end>
<start> a baseball player holding a bat while standing on the baseball field <end>
<start> a man in a suit on a snow board in the snow <end>
<start> a group of people and bathing suits standing on the beach with surfboards <end>
<start> a large building with a clock and a clock on it tower <end>
<start> a man riding a skateboard on a ramp on a sunny day <end>
<start> a bird perched on a branch with a tree in the background <end>
<start> a man in a suit holding a doughnut in a right hand <end>
<start> a brown horse standing in a field next to a forest <end>
<start> a zebra standing on top of a green field in a grass <end>