改修した点
オリジナルの wav2vec2.0
の Context Network は Transformer Encoder を用いている。今回は、この Context Network に Transformer Encoder + Downsampler + Transformer Decoder を用いて、事前学習と Fine Tuning を行い、学習に使ったデータ量に比較してそれなりに学習できたようなので、結果を報告する。
学習データ
データは、事前学習が JSUT 1.1(3,000発話) + Common Voice Japanese 11 (35,000発話)合計38,000の訓練データ、評価データ が JSUT 1.1 1000発話、テストデータが JSUT 1.1 1000発話である。
Fine Tuning については、訓練データが JSUT 1.1 4700 発話、評価データが JSUT 200発話、tテストデータが JSUT 100 発話である。
事前学習の結果報告
事前学習は、38,000 (正味およそ50時間)データをbatch_size = 8 で、63epochs 実行した。時間は RTX-A6000 一枚で 20 時間程度だった。学習の終了判定については、Train Loss の最小値が3回以上更新されないことを条件とした。学習率は、初期値を 1e-5 として、1epoch 経過したら、1e-5 * 0.01減らした。Train Loss の最小値が更新されない場合、学習率を 1/2 にした。optimizer には、AdamW を用い、eps = 1e-6、weight_decay = 0.1 他はデフォルトとした。
各計算量の推移。
positive_similarity_average, negative_similarity_average
Fine Tuning (データは 4700 発話)の結果報告
Fine Tuning は、4700 の訓練データを batch_size = 8 で、167 epochs 実行した。時間は、RTX-A6000 一枚で 8時間程度だった。学習の終了判定については、Train Loss の最小値が3回以上更新されないことを条件とした。学習率は、初期値を 1e-5 として、Train Loss の最小値が更新されない場合、学習率を 1/2 にした。optimizer には、AdamW を用い、初期学習率以外デフォルトとした。
各計算量の推移
推論結果
Fine Tuning Epoch 167 の訓練データ推論
train, reference :演目の主題である小言をこぼしながら読経する老人は以下のように演じられる
train, hypothesis:演目の主題である小言をこぼしながら読経する老人は以下のように演じられる
train, reference :それらの症状はすぐに治まるがその後生殖能力を失い子を残すことができなくなってしまう
train, hypothesis:それらの症状はすぐに治まるがその後生殖能力を失い子を残すことができなくなってしまう
train, reference :いわゆる英知は単に知識の断片ではないことを心にとめておくべきだ
train, hypothesis:いわゆる英知は単に知識の断片ではないことを心にとめておくべきだら
train, reference :規制緩和が進んでセルフサービスのガソリン給油は値段が少し安くなった
train, hypothesis:規制緩和が進んでセルフサービスのガソリン給油は値段が少し安くなった
Fine Tuning Epoch 167 の評価データ推論
validation, reference :布を斜めに裁ちなさい
validation, hypothesis:無をめに立ちまい
validation, reference :豹はその斑点を変えることはできない
validation, hypothesis:日はその判テを帰ることはでき来い
validation, reference :筆者はそうした風潮を好まない
validation, hypothesis:キー者はそうした封置をこのまないどっ
validation, reference :飛行機は瞬く間に見えなくなった
validation, hypothesis:聞行機は全立くまに言えなくなった
最終結果の表示
Final epoch model -> ./exp_train_large/char_model_wav2vec2.0_069_2/final_model_ft.pt
train loss: 0.053949
train token error rate: 4.412529 %
validation loss: 3.187039
validation token error rate: 41.545265 %
Best epoch model (164-th epoch) -> ./exp_train_large/char_model_wav2vec2.0_069_2/best_model_ft.pt
train loss: 0.053585
train token error rate: 4.403223 %
validation loss: 3.165154
validation token error rate: 41.024974 %
TER の計算結果
$ python3 05_scoring.py
TER: 41.85% (SUB: 28.61, DEL: 10.45, INS: 2.79)
TER は、テストデータに対して計算している。Token Error Rateは同じ計算方法で評価データについて計算したものである。
Fine Tuning (データは 38,000 発話)の結果報告
このセクションは追記である。4700 発話の Fine Tuning で validation データの Token Error Rate があまり良くなかったので、38,000 発話で Fine Tuning してみた結果を報告する。
Loss の推移。
Train データについての Token Error Rate
Validation データについての Token Error Rate
Learning Rate
推論結果
最終エポックの Train データ推論
train, reference :友達とヨーロッパを旅行しようと思ってます
train, hypothesis:友達とヨーロッパを旅行しようと思ってます
train, reference :営業に来た人たちはそそくさと帰っていった
train, hypothesis:営業に来た人たちはそそくさと帰っていった
train, reference :彼らの入り口は特定の天体に向けられていると一般に信じられています
train, hypothesis:彼らの入り口は特定の天体に向けられていると一般に信じられています
train, reference :彼はルンド大学とアイスランド大学で学びました
train, hypothesis:彼はルンド大学とアイスランド大学で学びました
最終エポックの Validation データ推論
validation, reference :盗みをするよりも飢え死にした方がましだ
validation, hypothesis:予をするよりも上にしたほうが増しだ
validation, reference :盗まれた宝石はどんなことがあっても取り戻さなければならない
validation, hypothesis:無まれた方石はどんなことがあっても取り戻さなければならない
validation, reference :洞窟から大きな猿が現れると彼はびっくりしてにげていった
validation, hypothesis:どくから大きな砂ルが現れると彼はビっくりして見えてい行った
validation, reference :奥様にどうかよろしくお伝えください
validation, hypothesis:ごさまにどうかよろしくお伝いください
最終表示
Final epoch model -> ./exp_train_large/char_model_wav2vec2.0_069_2/final_model_ft_data0.pt
train loss: 0.008019
train token error rate: 0.118008 %
validation loss: 1.546401
validation token error rate: 22.111785 %
Best epoch model (92-th epoch) -> ./exp_train_large/char_model_wav2vec2.0_069_2/best_model_ft_data0.pt
train loss: 0.007814
train token error rate: 0.109816 %
validation loss: 1.562380
validation token error rate: 22.423783 %
プログラムの要点
Context Network 詳細
今回の wav2vec2.0 では、Context Network を Transformer Encoder + Downsampler + Transformer Decorder とした。Transformer Encoder と Decoder の Position Wise Feed Forward Network の代わりに使用した2層のconvolution layer は、Convolution + Relu + Convolution + Dropout + LayerNorm とした。Convolution の channel 数は、一層目が入力 512、出力2048。二層目が入力 2048 の出力 512 である。また、 kernel size は 5 と 1、Stride は 1 と 1、padding は (kernel_size - 1) // 2 である。Dropout は 0.1 である。Transfomer Encoder の隠れ層の次元は、512 であり、head 数は 8、レイヤー数は 6 である。Transformer Decoder の隠れ層の次元も 512 であり、head 数は8、レイヤー数は6である。
Transformer Encoder の出力と後述の Downsampler の出力を、Transformer Decoder の入力とした。
Feature Extractor
Feature Extractor の入力は wav スペクトルである。Feature Extractor は、7層の( Convolution + LayerNorm 相当の GroupNorm + GELU ) に加えて、LayerNorm + Linear + Dropout である。Convolution の [入力channel数、出力channel数]は、1層目が[1,512]で、2~7層は [512,512]である。kernel_sizeは、[10,3,3,3,3,2,2]、stride は [5,2,2,2,2,2,2]、padding はデフォルト、bias はすべて False である。Linear 層は、入力が512 で出力が 512 である。Dropout は0.1 である。
Gumbel Quantizer
Gumbel Quantizer module は入力の次元数が 512、codebook の数 G が 2、codebookのエントリー数 V が 320 である。モジュールの入力は、Feature Extractor の出力に mask をかけない計算量と Gumbel Softmax の温度である。 Gumbel Softmax の温度は、初期値が 2.0 で、1 step ごとに、0.999995 倍され、最低温度が 0.5 である。
Gumbel Quantizer では、Feature Extractor の出力を、入力が 512次元 で出力が G * V 次元の線形層に入力する。この線形層の出力が、Gumbel Softmax 関数に入力される logits となる。ここで、バッチサイズを bsz、時間方向の sequence 数を tsz とする。Gumbel Softmax の出力は ( bsz * tsz ) × ( G * V ) × 1 に整形される。一方、 1 × V × (512 / G) のニューラルネットワーク学習パラメータ の2次元目を G 回 repeat してテンソル 1 × ( G * V ) * ( 512 / G ) を作る。整形されたテンソル ( bsz * tsz ) × ( G * V ) × 1 と 1 × ( G * V ) × ( 512 / G ) のテンソルを掛け合わせる。この計算結果は、( bsz * tsz ) * ( G * V ) * ( 512/ G ) である。このテンソルを (bsz * tsz) × G × V × ( 512 / G ) に整形し、V について和をとる。これを、bsz × tsz × 512 に整形し出力とする。このテンソルを quantized voctor と呼ぶことにする。Gumbel Quantizer のもう一つの出力は、 最初の線形層の出力 logits の bsz と tsz についての平均を softmax 関数で確率とした値である。この計算結果を G × V に整形し、 pgv_bar と呼ぶことにする。Gumbel Vector Quantizer は、
のページを参考にした。
Mask
Encoder の入力は、事前学習の場合、Feature Extractor の出力に論文の p = 0.065 consec = 10 でマスクをかけた計算量である。Fine Tuning の場合は、Feature Extractor の出力にマスクをかけない計算量である。time_masking 関数は、
のページを参考にした。
Encoder
Encoder の最初の部分は、Positional Convolution Embedding である。Positional Convolution Emgedding は、transpose と Convolution と GELU よりなる。Convolution は、入力 channel 数が 512、出力 channel 数が 512 で、kernel size が 128、padding が 64、groups が 16、 stride はデフォルトである。Convolution で、kernel = 128, padding =64, Groups = 16, stride = 1 だと出力の時間方向の sequence 数が 1 増えるので、gelu の入力で x[:,:-1,:] のように調整する。 positional Convolution Embedding の出力は、Positional Convolution Embedding の入力と和がとられ LayerNorm に入力される。LayerNorm の出力が Transformer Encoder に入力される。
Downsampler
Downsampler の入力は、エンコーダー出力、quantized vector, mask、および各uterance の時間方向の sequence の長さである。エンコーダー出力、quantized vector, mask を、transpose と torch.nn.functional.interpolate (オプションmode='nearest-exact')を使って時間方向にダウンサンプリングする。各 uterance をダウンサンプリングした際の時間方向の sequence の長さも出力する。dwonsample の割合は 0.25 とした。
Decoder
Transformer Decoder の入力は、source input が Encoder の出力で、target input が Encoder の出力に対応した Downsampler の出力を position embeding した計算量である。Transformer Decoder の出力を GELU と LayerNorm を経たのち、線形層(入力 512, 出力512)を通して model 出力とされる。最終の線形層は、Fine Tuning の時に、この層を(512, num_tokens ) の線形層に置き換えるためにある。ここで、num_tokens は、tokenの語彙数である。
Lm loss
Lm の計算は、最初に示した論文に基づき行った。しかし、ダウンサンプリングしているので、negative similarity については、論文にあるように K = 100 のサンプリングができず、バッチの中の uterance の時間方向の最小 sequence 数でサンプリングした。Lm 計算の時に用いる温度は 0.1 を用いた。
Ld loss
pgv_bar から論文の通り計算した。
Loss
事前学習の損失、Loss = Lm + α Ld ただし、α = 100 で学習を行った。Wav2vec2Loss クラスは、
のページを参考にした。
Fine Tuning の特記事項
Fine Tuning では、Encoder の入力にマスクをかけないのでマスクの計算は不要。Lm は計算しないので、quantized vector の計算も不要。Ld の計算も不要。CTCLoss を計算して、backward をかける。CTCLoss の計算の時には、事前学習の時と model の最終の線形層の出力次元を修正する。実際には、出力次元数を Transformer Decoder の隠れ層の次元数から token の語彙数に修正する。この修正した出力を Context Network の推論出力とする。CTCLoss の計算には、model の推論出力に加えて、推論出力の時間方向の sequence 数が必要である。このため、Feature Exatractor の出力の時間方向の sequence 数を計算し、これを Downsampler でダウンサンプリングした場合の時間方向の sequence 数を計算する。この長さが、model 出力の時間方向の sequence 数になる。
分かったこと
Fine Tuning の結果、訓練データについては、Loss も0.005 と十分下がり、百分率である Token Error Rate も5.5%となり十分な学習ができたと考える。実際に推論した結果も学習できていると捉えることができる。しかし、評価データについては、Loss が 3.2程度であり、Token Error Rate も41% であり、推論した結果も学習しているとは考えられるが、十分とは言えない。これは、事前学習のデータが正味 50時間だからではないだろうか。マシンリソースが取れれば、オリジナルの論文にあるように、Libri Speech の 960 時間の音声データを事前学習に用いたい。しかし、RTX A6000 一枚では、学習時間が 500時間程度と予想され、簡単にはできない。また、メモリーの問題もあり、データがメモリに乗らないかもしれない。RTX A6000 一枚で、実際に計算を行うのは、現実的ではないのでないだろうか。
2024年2月14日追記
今まで、通常の学習で Train データの正解率が良いのに validation データの正解率が上がらないのは、Train データの数が足りないという考えのもとに上記考察を考えました。しかし、Python で学ぶ画像認識という本を読んでいたら、ResNet18 の解説で、Train データと Test データの正解率の差が30 % 程度あり、まだ過学習が起きているときの対処法として ResNet152 などのより層数の多いモデルを使うと良いと書かれています。後述のように Libri clean 360hours データで事前学習を行い、10,000 データの Fine Tuning を行った場合、17.9 % という Token Error Rate なので、この Error Rate を下げるためには、Encoder と Decoder の layer 数をともに 12 とし、隠れ層の次元数を 1024 としてみるのが良いのではと思います。しかし、360 hours データで、それだけの大きなネットワークを RTX-A6000 一枚で事前学習させられるかは未知数です。
プログラム
参考のために使ったプログラムを github にアップしておきます。
Context Network を Transformer Encoder だけにしたら
Context Network を Transformer Encoder + Downsampler + Transformer Decoder ではなく、12層の Transformer Encoder のみにして、事前学習と Fine Tuning を行った結果を報告する。
事前学習
Loss の推移
Similarity の推移
Lm の推移
Ld の推移
Learning Rate の推移
Fine Tuning
Loss の推移
Token Error の推移
Learning Rate の推移
推論結果
Fine Tuning 最終エポックの訓練データ推論
train, reference :出現する兆しとして霊の周りを靄が覆い始める
train, hypothesis:ス人するがしとしてレのわりをもがおいじめるる
train, reference :両者の間にはあったとしても相違はごくわずかである
train, hypothesis:料者のあにはあったとしてもそはごくわかである
train, reference :簡易ベッドの組み立て方を教えてもらえますか
train, hypothesis:感ブットのみだてかたをおえてもらいますか
train, reference :垂直落下式は危険性が大きいが反り投げ式は見た目が派手でなおかつ安全であるなどの理由からこの方式が広く普及したといわれている
train, hypothesis:水つらきは見生がをきーがそのきは見ためがででなーかつ全であるなどのびうからこの方きがレく不計したといわれている
Fine Tuning 最終エポックの評価データの推論
validation, reference :布を斜めに裁ちなさい
validation, hypothesis:のめに立ちなさい
validation, reference :豹はその斑点を変えることはできない
validation, hypothesis:長はその本手をを会えることはできない
validation, reference :筆者はそうした風潮を好まない
validation, hypothesis:1者はそうしたフい長をこのまない
validation, reference :飛行機は瞬く間に見えなくなった
validation, hypothesis:こう気はたくにえなくなった
最終表示
Final epoch model -> ./exp_train_large/char_model_wav2vec2.0_071/final_model_ft.pt
train loss: 2.829249
train token error rate: 42.779040 %
validation loss: 3.342297
validation token error rate: 45.837669 %
Best epoch model (65-th epoch) -> ./exp_train_large/char_model_wav2vec2.0_071/best_model_ft.pt
train loss: 2.821068
train token error rate: 42.788346 %
validation loss: 3.317758
validation token error rate: 45.889698 %
ダウンサンプリングしなかったら
Context Network でダウンサンプリングをせずに、Transformer Encoder + Transformer Decoder としたらどうなるか計算してみました。Transformer Decoder の source input も target input も Transformer Encoder の出力です。
事前学習
Loss の推移
Similarity の推移
Lm の推移
Ld の推移
Learning Rate の推移
Fine Tuning
Loss の推移
Token Error の推移
Learning Rate の推移
推論結果
最終エポックの訓練データの推論結果
train, reference :また軍事費の拡大によって市場に資本を投入し経済成長を促すため軍拡競争が激化することも考えられるからである
train, hypothesis:またク時のか大によって要にをとにしい全在生長をがすため分長そがでかすることも関えられるからであ
train, reference :この覆しようのない事実がどんなに美しくて優しそうに見えてもベストラは所詮は闇の種族だという観念を彼に植え付けた
train, hypothesis:この前し要のな時がどなにくすてしとに見てもエトはそ性はみのだという間年を彼にきた
train, reference :突如パニスの命を狙って現れた全身紫の衣装を纏いアイマスクを付けた女性
train, hypothesis:と女何スののをらってらわれた全しらさきの生をまといいますをつてた女う
train, reference :今法律では女性が男性と同一条件で雇用されることが求められている
train, hypothesis:今ほりでは生が断生とと気上けでこ用されることがけめられてる
train loss: 3.253787
train token error rate: 50.979830 %
最終エポックの validation データの推論結果
validation, reference :布を斜めに裁ちなさい
validation, hypothesis:のをにたちまさ
validation, reference :豹はその斑点を変えることはできない
validation, hypothesis:教はその入てををかえることはできな
validation, reference :筆者はそうした風潮を好まない
validation, hypothesis:1者はそしたフいをのまない
validation, reference :飛行機は瞬く間に見えなくなった
validation, hypothesis:行はたくなに見えなくなった
validation loss: 3.619210
validation token error rate: 52.185224 %
Early stopping. (early_stop_threshold = 3)
最終表示
---------------Summary------------------
Final epoch model -> ./exp_train_large/char_model_wav2vec2.0_072/final_model_ft.pt
train loss: 3.253787
train token error rate: 50.979830 %
validation loss: 3.619210
validation token error rate: 52.185224 %
Best epoch model (67-th epoch) -> ./exp_train_large/char_model_wav2vec2.0_072/best_model_ft.pt
train loss: 3.084257
train token error rate: 48.374964 %
validation loss: 3.556453
validation token error rate: 51.925078 %
Libri で学習
Libri clean 360 で事前学習
Libri clean 360hours のデータで事前学習しました。日本語と違い downsampling rate を 0.70 にしました。その他は、JSUT + Common Voice Ja の場合と同じパラメータです。学習の終了は0.5秒程度の停電でマシンが止まりました。
Libri clean 100 の 1万データで Fine Tuning
Libri clean 100 のデータのうち、1万件のデータで Fine Tuning を行いました。Token は character です。
途中で Validation Loss が上昇していますが、Validation Token Error Rate が減少していたので学習を続行しました。
Validation Token Error Rate の推移
17.9%台まで減少しました。
推論結果
最終エポックの Train データの推論結果
train, reference :when he came to himself natasha that same living natasha whom of all people he most longed to love with this new pure divine love that had been revealed to him was kneeling before him
train, hypothesis:when he came to himself natasha that same living natasha whom of all people he most longed to love with this new pure divine love that had been revealed to him was kneeling before him
train, reference :and endeavoured to comfort them they treated them as well as their poverty would permit took off their horse hair shifts which were very uneasy to them and put on them others which they gave them with shoes and something to cover their heads and save their hair
train, hypothesis:and endeavoured to comfort them they treated them as well as their poverty would permit took off their horse hair shifts which were very uneasy to them and put on them others which they gave them with shoes and something to cover their heads and save their hair
train, reference :or without speaking to the procureur well i have business with the procureur is it pressing business you can imagine so since i have not even brought my carriage out yet but enough of this here is my card
train, hypothesis:or without speaking to the procureur well i have business with the procureur is it pressing business you can imagine so since i have not even brought my carriage out yet but enough of this here is my card
train, reference :the knave shall die the law hath said while it protects our own slave trade what earthly eye presumes to scan the wily proteus heart of man what potent hand will e'er unroll
train, hypothesis:the knave shall die the law hath said while it protects our own slave trade what earthly eye presumes to scan the wily proteus heart of man what potent hand will e'er unroll
train loss: 0.002890
train token error rate: 0.054161 %
最終エポックの validation data の推論結果
validation, reference :pop goes right on tuning his channel
validation, hypothesis:popgos right ontorning is canno
validation, reference :you're getting altogether too upset about these programs stop it and behave yourself
validation, hypothesis:you getning altogether to ops sait about the's probrams stopt it an the heniorself
validation, reference :it's your fault mop it up yourself
validation, hypothesis:bit youre fall mobed at yourself
validation, reference :i hear the t v going for a few minutes then pop turns it off and goes in the kitchen to talk to mom
validation, hypothesis:i heare the tee by going for af houw minutes then poped hurmed hat off and dose in the kitch en to topk the mon
validation loss: 1.420961
validation token error rate: 17.984905 %
最終表示
---------------Summary------------------
Final epoch model -> ./exp_train_large_360/char_model_wav2vec2.0_078/final_model_ft.pt
train loss: 0.002890
train token error rate: 0.054161 %
validation loss: 1.420961
validation token error rate: 17.984905 %
Best epoch model (100-th epoch) -> ./exp_train_large_360/char_model_wav2vec2.0_078/best_model_ft.pt
train loss: 0.002833
train token error rate: 0.054216 %
validation loss: 1.424611
validation token error rate: 18.020846 %
再学習
停電で事前学習が途中停止したので、事前学習と Fine Tuning をやり直した。その結果、Validation Token Error Rate が、16.7 % まで減少した。
最終エポックのTrain データの推論
train, reference :where no soul could either see them or know who they were they were to enter the lists four several times those who were so happy as to conquer four competitors were afterwards to engage each other in single combat
train, hypothesis:where no soul could either see them or know who they were they were to enter the lists four several times those who were so happy as to conquer four competitors were afterwards to engage each other in single combat
train, reference :and then to banish us out of syria for ever but how unworthy soever our usage has been
train, hypothesis:and then to banish us out of syria for ever but how unworthy soever our usage has been
train, reference :he was a huge fellow terribly scarred about the face and chest and with one broken tusk and a missing ear
train, hypothesis:he was a huge fellow terribly scarred about the face and chest and with one broken tusk and a missing ear
train, reference :had really not made in her own country the mark she had chalked so large in london
train, hypothesis:had really not made in her own country the mark she had chalked so large in london
train loss: 0.002554
train token error rate: 0.046643 %
最終エポックの Validation データ推論
validation, reference :pop goes right on tuning his channel
validation, hypothesis:pup gos ried on toning is chaw
validation, reference :you're getting altogether too upset about these programs stop it and behave yourself
validation, hypothesis:you geing altogether to f set about theis probrams stop in and the hed yorself
validation, reference :it's your fault mop it up yourself
validation, hypothesis:its your fallt mobbit uf yourself
validation, reference :i hear the t v going for a few minutes then pop turns it off and goes in the kitchen to talk to mom
validation, hypothesis:i hear the peav bye going for a hew minutes than popped uirned that of and ghose in the katchen to top the mom
validation loss: 1.327640
validation token error rate: 16.784473 %
最終表示
---------------Summary------------------
Final epoch model -> ./exp_train_large_360/char_model_wav2vec2.0_086/final_model_ft.pt
train loss: 0.002554
train token error rate: 0.046643 %
validation loss: 1.327640
validation token error rate: 16.784473 %
Best epoch model (91-th epoch) -> ./exp_train_large_360/char_model_wav2vec2.0_086/best_model_ft.pt
train loss: 0.002552
train token error rate: 0.047411 %
validation loss: 1.326912
validation token error rate: 16.714987 %