0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

「Python で学ぶ画像認識」の画像キャプショニングのプログラムを Mask Predict による非自己回帰型に改修しました。

Last updated at Posted at 2024-03-27

改修を行った動機

音声合成、音声認識、機械翻訳、画像キャプショニングなどで、downsample や upsample を用いて Transformer を非自己回帰的に使う方法を提案させていただきました。画像キャプショニングについては精度があまり良くありませんでした。

downsample の非自己回帰型プログラムの精度と自己回帰型プログラムの精度の違いの原因は、Transformer Decoder の q 入力、すなわち target 入力に、キャプションに関連した入力を入れるかどうかのように推察されました。donwsample 非自己回帰型プログラムの target 入力は、エンコーダーの出力を downsample したものであり、キャプションとは関係ありません。一方、Mask Predict の target 入力は、キャプションに mask をかけたものなので、キャプションに関係あります。そこで、Mask Predict による非自己回帰型プログラムではどの程度の精度がでるか試すことが目的になりました。現在のところ、downsample による非自己回帰型プログラムより精度は良さそうです。

学習に用いたデータ

データは本にある cocodataset の val2014 から、精度を確保するために train2017 へと変更しました。train:90 %、validation:10 % としました。

結果

本で demo に用いている画像でキャプショニングを行った結果を掲載させていただきます。

adorable-1849992_1920.jpg

<start> a dog and a dog trying to catch a with <end>

africa-1170179_1920.jpg

<start> giraffes giraffes giraffes in a grassy area near and and <end>

airplane-3702676_1920.jpg

<start> a large air air jet is over airport airport <end>

automotive-1846910_1920.jpg

<start> a red and red red hydrant on a city street <end>

beach-1837030_1920.jpg

<start> a male boarder is a waves in the ocean <end>

caravan-339564_1920.jpg

<start> a man with a umbrella standing on a beach <end>

cat-4467818_1920.jpg

<start> an orange cat sitting on top of a red skateboard <end>

cherry-1468933_1920.jpg

<start> a bowl of a an a a a and and and <end>

couple-955926_1280.jpg

<start> a woman sitting on bike beach a lots of water <end>

dog-7367949_1920.jpg

<start> a man is on a surf board is playing water <end>

hit-1407826_1920.jpg

<start> a baseball player with a ready to hit a ball <end>

man-498473_1920.jpg

<start> a snow boarder down a covered on a ski slope <end>

musician-743973_1920.jpg

<start> a group of people with with boards boards on the beach <end>

port-5788261_1920.jpg

<start> a clock building with a tower on a city street <end>

profile-7579739_1920.jpg

<start> a man doing on with on skate board <end>

ural-owl-4808774_1920.jpg

<start> a small sitting on a tree branch in a tree <end>

wine-bar-2139973_1920.jpg

<start> a man in a black shirt holding a wine glass <end>

woman-3432069_1920.jpg

<start> a woman riding on brown horse through a grassy grassy <end>

zebras-1883654_1920.jpg

<start> a zebra in a grassy area eating and eating <end>

Mask Predict とは。

画像キャプショニングで、Mask Predict により非自己回帰型推論を行う試みは、数少ないようです。ですので、ここで、簡単に Mask Predict についてご説明させていただきます。Mask Predict による文章の推論では、学習時と推論時で異なったアルゴリズムを用います。学習時には、キャプションに乱数で発生させた数だけ、乱数で発生させた位置に <mask> をかけます。<mask> をかけるとは、もとの単語を隠して、word_to_id['<mask>'] = 0 とすることです。マスクをかけた文章をTransformer Decoder の target 入力として、Decoder の計算を行います。Decoder の出力に対応する教師データはキャプション自身です。損失は、マスクをかけた位置だけの損失を計算します。

Mask Predict では、キャプション( Transformer Decoder の target 入力)の一部に <mask> をかけ、その <mask> を推論するような形になるため、<mask> についての token と token_id が必要になります。今回は、本の付録の 6_2_dataset.ipynb で、id を i ではなく、i + 1 で採番し、0 を開けた word_to_id と id_to_word の辞書を作りました。加えて、この辞書をプログラムで使用する前、 word_to_id と id_to_word をファイルから読み込んだあとに、wordk_to_id['<mask>'] = 0 と id_to_word[0] = '<mask>' を定義しました。

Mask Predict で、もう一つ学習することは、文章の長さを予測することです。、今回のプログラムでは、エンコーダーの出力に<len> 位置を設けず、エンコーダーの出力を入力して文章の長さを予測するようなニューラルネットワークを encoder、decoder とは別に構成して、学習させました。

一方、推論時には、Transformer Decoder の target 入力には、予測した文章の長さ分の <start>,<mask>,・・・,<mask>,<end> の文章を用います。まず、最初にこの入力を target 入力として、エンコーダーの出力を memory 入力として、 Transformer Decoder の出力を得ます。この出力は完全ではないので、この出力の確率が低い n 個の単語を <mask> に置き換えて target 入力として、エンコーダーの出力を memory 入力として、もう一度 Transformer Decoder の出力を得ます。これを10回程度イタレーションします。<mask> に置き換える n 個は、10回のイタレーションにより、0になるように線形に減らしていきます。

学習と推論に用いたプログラムを github に置いておきます。

損失と学習曲線、および、Transformer のパラメーター

損失 loss は

loss = loss0 + loss 1

としました。loss0 は、キャプションの CrossEntropyLoss で、loos1 は、caption_lengths の MSELoss です。

loss0、loss1、loss、 WER の学習時の推移のグラフを掲載します。

fig1.png

loss0 が、キャプションの CrossEntropyLoss です。このグラフからすると、まだ学習できそうです。

fig2.png

fig3.png

fig4.png

Transformer Decoder のパラメーターは、隠れ変数の次元が 512、FeedForward の中間次元が 2048、ヘッド数が8、Transformer Decoder の層数が 12 としました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?