はじめに
音声合成と音声認識について transformer を基本とした非自己回帰型のプログラムを報告させていただきました。
共通の構成は、transformer-encoder + downsampling( upsampling ) + transformer-decoder ( + CTCLoss ) です。今回は、機械翻訳(日本語を英語に翻訳)にこの構成を応用して、非自己回帰型プログラムを試しました。validation data の WER が 37 % を得ましたので、それなりに学習しているのでは、ということでご報告させていただきます。
データ
データは、日本語英語コーパス
のデータを使わせていただきました。train 20万件、val 5000件、test 5000件としました。token の単位は単語です。私の環境では、英語文に multi-byte 文字が入っていると、idx_to_word.json ファイルの分かち書きの英語が文字化けするので、multi-byte 文字の入った英文は学習に使いませんでした。idx_to_word.json(index と英語分かち書きの対応ファイル)、 idx_to_wakati.json(index と日本語分かち書きの対応ファイル)も作成しました。
学習曲線
学習は、20万件を batch size = 16 で 25 エポック行いました。10エポック程度以降は過学習のようです。学習曲線を掲載します。一番目が損失と学習回数。二番目が誤り率(WER)と学習回数です。
学習について
9/25 エポック目のデータを掲載します。validation japanese が日本語文です。validation reference が英語の教師データです。validation hypothesis が英語の推論データです。 validation loss が損失、validation token eror rate が単語誤り率です。 <sos> が start of sentence で、<eos> が end of sentence です。
validation, japanese :<sos>あなたはこの記事のこれらの質問に対する回答を見つけ
るでしょう。<eos>
validation, reference :<sos> you will find answers to these questions in this article.<eos>
validation, hypothesis:<sos> you will find answers to these questions this article.<eos>
validation, japanese :<sos>あなたはハーモニカを得るでしょう。<eos>
validation, reference :<sos> you will get harmonica.<eos>
validation, hypothesis:<sos> you will get harmonica.<eos>
validation, japanese :<sos>あなたはコーンのようなものを手に入れるでしょう。<eos>
validation, reference :<sos> you will get something like a cone.<eos>
validation, hypothesis:<sos> you will get something like cones.<eos>
validation, japanese :<sos>ページから広告を削除したい場合にのみ設定を掘る必要
があります。<eos>
validation, reference :<sos> you will have to dig in the settings only if you want to remove advertising from pages.<eos>
validation, hypothesis:<sos> only only dig the only if you want to remove advertising from the page.<eos>
validation loss: 0.127414
validation token error rate: 36.783819 %
参考のため、訓練用データについても学習結果を掲載しておきます。
train, japanese :<sos>魅力の方法の1つは、価格表示でメニューの広告シール>ドに検討し、配置することができます。<eos>
train, reference :<sos> one of the methods of attraction can be considered and placement on the advertising shield of the menu with the price indication.<eos>
train, hypothesis:<sos> one of the methods of attraction can be considered and placed on the advertising shield and the shield of the menu with the price display.<eos>
train, japanese :<sos>きっと、少なくとも、、顧客に提供するのに十分な量を
作成するだけの時間はあることでしょう。<eos>
train, reference :<sos> they seem to have enough time to create enough content to get by.<eos>
train, hypothesis:<sos> they have to have enough time of create enough content offer offer by.<eos>
train, japanese :<sos>これらを通じて、改めて沖縄の魅力を掘り起こし、その
エネルギーを全国、そして世界へと発信してまいります。<eos>
train, reference :<sos> we hope to transmit the charm and the power of okinawa towards the rest of japan and the world.<eos>
train, hypothesis:<sos> we will introduce the charm and the okinawa of the energy and and the world.<eos>
train, japanese :<sos>登録行動(取引を行う場合、住宅パラメータなどを変更>する際に)についての各訴訟の事実、記録は単一のデータベースに行われ、将来的には要>求に応じて最も関連性の高い情報が得られます。<eos>
train, reference :<sos> upon the fact of each appeal for any registration actions ( when making a transaction , making changes to housing parameters , etc. ) , records are made to a single database , and in the future , on request , citizens can get the most relevant information.<eos>
train, hypothesis:<sos> for the action of fact any action of ( when when changing , housing parameters , etc. ) , records is made to a single database , and in the future future it will get the most relevant information on.<eos>
train loss: 0.064757
train token error rate: 24.438091 %
ここで、train データについては、model.forward を使って、逆伝播のために loss を計算するのに使う ouputs を計算し、その outputs で hypothesis と error_rate ( 文章における単語のlevenshtein 距離をtotal_token で割って 100 をかけたもの)を計算している。理想的には validation データについては学習に使わないので、推論用の関数 model.inference を使ってoutputs を計算し、hypothesis と error_rate を計算すべきであるが、train データと同じ関数で計算した。今回のアルゴリズムでは、学習用の model.forward と推論用の model.inference にアルゴリズムとしての差はない。
学習終了時点での情報は、
---------------Summary------------------
Final epoch model -> ./exp_train_large/char_model_conv_non_ar_007/final_model.pt
train loss: 0.011159
train token error rate: 4.222117 %
validation loss: 0.189154
validation token error rate: 37.396498 %
Best epoch model (6-th epoch) -> ./exp_train_large/char_model_conv_non_ar_007/best_model.pt
train loss: 0.121599
train token error rate: 40.910098 %
validation loss: 0.118504
validation token error rate: 37.615678 %
です。
vlaidation loss が一番良かった 6/20 のエポックのモデルパラメーターとテスト用のデータを用いて計算した WER が
(base) ...:~/translateCTC$ python3 04_scoring.py
WER: 36.99% (SUB: 18.16, DEL: 15.67, INS: 3.16)
と37%です。
プログラムの特徴
transformer の position wise feed forward network の代わりに convolution を使いました。convolution は 2層で、入力と出力の channel 数は (1024, 4096), (4096,1024) です。kernel数は、5 と 1 です。dropout は0.1 です。encoder は、12層で次元は1024で、head数は8です。。decoder も12層で次元は1024 で、head 数は8です。encoder の入力は、日本語文を数字に置き換えたものです。encoder と decoder の間に upsampling モジュールがあります。upsampling モジュールは、encoder の出力を時間方向について seuquence 2 倍の長さにしました。この時、encoder 入力の長さ(リストのinputs.size(1) ではなく、input_lengths[n] のこと)を二倍にして、CTCLoss の入力の output_lengths にしました。encoder の出力を decoder のk,v入力に、upsampling の出力を decoder のq入力にしました。損失には、CTCLoss を使いました。
loss = nn.CTCLoss(blank=0, reduction='mean',zero_infinity=False)( outputs.transpose(0,1), labels, output_lengths, labels_lens )
ここで、outputs と output_lengths は、学習モデルの出力。labels と labels_lens は、教師データです。lables は、英語文を数値化したものです。CTCLoss の入力は、outputs の時間方向の sequence 数がlabels の時間方向の sequence 数よりある程度大きくないと、および、output_lengths が labels_lens に比べてある程度大きくないと inf を出力するようです。このため、pad_id で埋めることにより outputs.size(1) を labels.size(1) の1.5倍以上に、outputs_lengths[n] も labels_lens [n] の1.5倍以上になるように調整しました。Loss に CTCLoss を用いたため、model.inference の出力である outputs を 「Python で学ぶ音声認識」という教科書の ctc_simple_decode(int_vector, token_list) 関数でデコードして翻訳英語文としました。
学習に用いたプログラムとデータを github に置いておきます。
現在のプログラム
005 系です。_old なしのモジュールになります。
以前のプログラム 014 系と _old.py
音声認識のプログラムを使いまわしたので、upsampling とすべきところが、downsampling になっています。ご容赦ください。
CTCLoss 計算時に outputs の time sequence 数が小さくて nanあるいは inf になる問題
上記計算では、outputs を 0 padding して、長さを伸ばして対応した。しかし、upsampling_rate を 2 から 3にすれば解決することが、後で分かった。
追記 upsampling_rate = 3 での計算結果
Loss の推移
訓練データについての Token Error Rate の推移
評価データについての Token Error Rate の推移
36.6%台まで減少しました。
最終エポックの train データ推論
train, japanese :<sos>ヘナは、その抗菌性のために、抗菌特性のために、子細を取り除くのを助け、頭皮の状態を改善します。<eos>
train, reference :<sos> henna helps to get rid of dandruff and improves the condition of the scalp , due to its antibacterial properties.<eos>
train, hypothesis:<sos> henna helps get rid of resources , improve the condition : the scalp , due to its properties.<eos>
train, japanese :<sos>ウサギのフィレットは、サワークリームで、またはローズマリーと障害で消えることができます。<eos>
train, reference :<sos> the rabbit fillet can be extinguished in sour cream or in fault with rosemary.<eos>
train, hypothesis:<sos> the rabbit fillet can be extinguished in sour cream or in rosemary with rosemary.<eos>
train, japanese :<sos>正しいプロポーションでb52を作る方法がわからない場合は、スクロール、カップ、さらにはジャグを使用できます(会社のカクテルを作る場合)。<eos>
train, reference :<sos> if you do not know how to make b 52 in the correct proportions , you can use the scroll , a cup or even a jug ( if you make a cocktail for the company ).<eos>
train, hypothesis:<sos> if you do not know how to make b 52 in the correct proportions , you can use the scroll , a cup or even jug ( if you make a cocktail for the company ).<eos>
train, japanese :<sos>それだけのので、あなたはあなたの最愛の人の忘れられない気持ちを与えることができます。<eos>
train, reference :<sos> only so you can give an unforgettable feeling of your beloved.<eos>
train, hypothesis:<sos> only so you can give an unforgettable feeling of your beloved.<eos>
train loss: 0.038199
train token error rate: 17.277813 %
最終エポックの validation データ推論
validation, japanese :<sos>あなたはこの記事のこれらの質問に対する回答を見つけるでしょう。<eos>
validation, reference :<sos> you will find answers to these questions in this article.<eos>
validation, hypothesis:<sos> you will find answers to these questions this article.<eos>
validation, japanese :<sos>あなたはハーモニカを得るでしょう。<eos>
validation, reference :<sos> you will get harmonica.<eos>
validation, hypothesis:<sos> you will get harmonica.<eos>
validation, japanese :<sos>あなたはコーンのようなものを手に入れるでしょう。<eos>
validation, reference :<sos> you will get something like a cone.<eos>
validation, hypothesis:<sos> you will get like cones.<eos>
validation, japanese :<sos>ページから広告を削除したい場合にのみ設定を掘る必要があります。<eos>
validation, reference :<sos> you will have to dig in the settings only if you want to remove advertising from pages.<eos>
validation, hypothesis:<sos> it is only to make the settings only if you want to remove advertising from the page.<eos>
validation loss: 0.144653
validation token error rate: 36.637699 %
計算の最終表示
---------------Summary------------------
Final epoch model -> ./exp_train_large/char_model_non_ar_conv_005/final_model.pt
train loss: 0.038199
train token error rate: 17.277813 %
validation loss: 0.144653
validation token error rate: 36.637699 %
Best epoch model (15-th epoch) -> ./exp_train_large/char_model_non_ar_conv_005/best_model.pt
train loss: 0.038199
train token error rate: 17.277813 %
validation loss: 0.144653
validation token error rate: 36.637699 %
学習に用いたプログラムを 02_train_transformer_conv_non_ar_005.py, my_model.py, attention.py, encoder.py, decoder.py として 上記 github に置いておく。 padding mask 対応しています。