LoginSignup
0
1

OpenAI Whisper のニューラルネットワークパラメータを初期化して学習させる試み。

Last updated at Posted at 2023-04-21

実験を行った動機。

以前に、OpenAI の Whisper を Fine Tuning するプログラムを情報共有

させていただきました。以前から、TensorFlow の音声認識

のページなどを勉強して、なんとか自分でも Transformer を使った音声認識のニューラルネットワークに学習させられないかと考えていました。TensorFlowのモデルでは、ようやく、音声特徴量を mfcc にすることにより、少し学習するかなという手ごたえがありましたが、

メルスペクトルでは、訓練用データの学習は進んでも、評価用データの loss が上がってしまうという状態が続いていました。どちらも、モデルを自分なりに改修しながら学習させていました。Whisper では、Fine Tuning でうまくいきましたので、モデル自体に信頼性があります。そこで、Whisper のニューラルネットワークパラメータを初期化して学習させてみたらどうなるだろうと考えました。実際にこの実験をやってみました。

実験結果の概略

結果は、TensorFlow のメルスペクトルの場合と同様、訓練用データには学習の効果がみられるのですが、評価用データについては学習の効果が見られません。いわゆる過学習の状態です。以下、実験とプログラムについてご説明させていただきます。

実験の概略

実験データは、JSUT-ver1.1 の BASIC の 5000発話を使わせていただきました。Whisper は tiny モデルのニューラルネットワークパラメータを初期化したものを使いました。学習は 50 epochs の予定で行いましたが、Early Stopping のため、45 epochs の学習で終了しました。Early Stopping の条件は、「訓練データの損失が 4 epochs 進んでも改善されない」です。

実験結果

45 epochs 目の結果です。

epoch = 45
train target:布を斜めに裁ちなさい
train predic:前事斜めに裁ちなさいならかさいが

train target:豹はその斑点を変えることはできない
train predic:豹はその斑点を変えることはできないない

train target:筆者はそうした風潮を好まない
train predic:筆者はそうした風潮を好まないないがは�ないないは

train target:飛行機は瞬く間に見えなくなった
train predic:飛行機は瞬く間に見えなくなった

train target:彼女もう浮かれちゃってるよ
train predic:彼女もう浮かれちゃってるよン�か

train target:彼女は浮気な女で本当に誰でも相手にする
train predic:コ�女は浮気な女で本当に誰でも相手にする立

train target:彼女は夫の到着を待ち焦がれています
train predic:彼女は夫の到着を待ち焦がれています

train target:彼女は恥じらいの色を隠すために顔をそむけた
train predic:彬女は恥じらいの色を隠すために顔をそむけただだ


val target:彼女は私を見るや否やわっと泣き出した
val predic:コ�定同�の到�わらちの恘しにだ�

val target:彼女は三味線による新しいジャズの演奏法を始めた
val predic:三�は目�のじでをめてしいらしい�れてをこ変ってもででいた

val target:彼女はいつも床を綺麗に掃いています
val predic:よ�多はかの�台にれてのんできちにちなです�

val target:彼女の伯母は一日中彼の犬の世話をする
val predic:病��は恮を�もう線り気なくが私��先ってがんで定

val target:彼女のいうことは的外れである
val predic:誨夜目��をで�でのある�場

val target:彼女がこれほど自分勝手なのは嘆かわしい
val predic:��朝豆ク�長�に�にくしちし�た

val target:彼は緻密に立てた計画を実行した
val predic:彼を�僚着を�いしま々何してらしてでした

val target:彼は老けて見える
val predic:全が身もう�ら立よ�いの�にるる

val target:彼は大尉以上の者を全員招集した
val predic:彴は同頚��大く院か�こと要建したでしたしたした

Epoch [45/50],loss: 0.02704 acc: 26.77000, val-loss: 0.64102 val-acc: 0.71311  cer: 121.43283 wer: 156.98481 lr: 3.961549438922103e-05
Early Stopping.

train target と train predic は、訓練用教師データと訓練用データのモデルによる予測です。同様に、val target と varl predic は、評価用教師データと評価用データのモデルによる予測です。これらからわかるように、訓練データ train についてはかなり学習の効果が見られます。一方、評価用データ val については、ほとんど学習の効果が見られません。

損失の学習曲線を見てみます。学習曲線も訓練データについては、0.8 から 0.1 以下へと改善していますが、評価データについては、5 epoch 程度までは改善しているものの、それ以降は改善していません。いわゆる過学習です。

fig1__initialize25-tiny-5000-45.png

次に、プログラムの情報共有をお願いいたします。

15000 データで訓練するプログラムgithub

2023年9月18日、初校からかなりたってから、15000データで訓練するプログラムを github にアップしました。
https://github.com/toshiouchi/whisper_initialize

ライブラリー

import torch
import librosa
import whisper
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import japanize_matplotlib
import evaluate
import gc
import spacy
import ginza
from tqdm.notebook import tqdm
import os
# ファイルをダウンロードするためのモジュールをインポート
from urllib.request import urlretrieve
# zipファイルを展開するためのモジュールをインポート
import zipfile
# yamlデータを読み込むためのモジュールをインポート
import yaml
import math
from glob import glob
from sklearn.model_selection import train_test_split
# サンプリング周波数を変換するためのモジュール(sox)をインポート
import sox
import os

#GPUの利用
device = 0 if torch.cuda.is_available() else "cpu"
print(device)

表示関連の設定

# warning表示off
import warnings
warnings.simplefilter('ignore')

# デフォルトフォントサイズ変更
plt.rcParams['font.size'] = 14

# デフォルトグラフサイズ変更
plt.rcParams['figure.figsize'] = (6,6)

# デフォルトで方眼表示ON
plt.rcParams['axes.grid'] = True

# numpyの表示桁数設定
np.set_printoptions(suppress=True, precision=5)

データ置き場の定義

# データの置き場を定義
data_dir = './data/original'
# ディレクトリdata_dirが存在しない場合は作成する
os.makedirs(data_dir, exist_ok=True)

JSUT コーパスのダウンロードなど

# JSUTコーパスのダウンロード、解凍
data_archive = os.path.join(data_dir, 'jsut-data.zip')
urlretrieve('http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip', data_archive)
with zipfile.ZipFile(data_archive) as data_zip:
        data_zip.extractall(data_dir)
# zipファイルを削除する
os.remove(data_archive)

ラベルデータのダウンロードなど

# ラベルデータのダウンロード、解凍
label_archive = os.path.join(data_dir, 'jsut-label.zip')
urlretrieve('https://github.com/sarulab-speech/jsut-label/archive/master.zip',label_archive)
with zipfile.ZipFile(label_archive) as label_zip:
        label_zip.extractall(data_dir)
# zipファイルを削除する
os.remove(label_archive)

WAVファイルの16000Hzダウンサンプリング処理の設定

# WAVファイルの16000Hzダウンサンプリング処理の設定
# wavファイルが展開されたディレクトリ
original_wav_dir = './data/original/jsut_ver1.1/basic5000/wav'
# フォーマット変換したwavファイルを出力するディレクトリ
out_wav_dir = './data/wav'
# 出力ディレクトリが存在しない場合は作成する
os.makedirs(out_wav_dir, exist_ok=True)
# soxによる音声変換クラスを呼び出す
tfm = sox.Transformer()
# サンプリング周波数を 16000Hz に変換するよう設定する
tfm.convert(samplerate=16000)

wav ファイルを、実際に 16000 Hz にダウンサンプリングして書き込み

#wav ファイルを、実際に 16000 Hz にダウンサンプリングして書き込み
# BASIC5000_0001.wav ~ BASIC5000_5000.wav に対して処理を繰り返し実行
for i in range(5000):
    filename = 'BASIC5000_%04d' % (i+1)
    # 変換元のオリジナルデータ (48000Hz)のファイル名
    wav_path_in = os.path.join(original_wav_dir, filename+'.wav')
    # 変換後のデータ(16000Hz)の保存ファイル名
    wav_path_out = os.path.join(out_wav_dir, filename+'.wav')

    # ファイルが存在しない場合はエラー
    if not os.path.exists(wav_path_in):
        print('Error: Not found %s' % (wav_path_in))
        exit()

    # サンプリング周波数の変換と保存を実行する
    tfm.build_file(input_filepath=wav_path_in, 
                   output_filepath=wav_path_out)

metadata.csv の作成

# metadata.csv の作成
original_wav_dir = './data/original/jsut_ver1.1/basic5000/wav/'
out_wav_dir = "./data/wav"
original_label = './data/original/jsut-label-master/text_kana/basic5000.yaml'
metadata_csv_dir = "./data/"

with open(original_label, mode='r', encoding="utf-8") as yamlfile:
        label_info = yaml.safe_load(yamlfile)

metadata_csv_path = metadata_csv_dir + f'metadata.csv'
#metadata_csv_path = os.path.join( original_wav_dir, 'metadata.csv' )
        
file = open( metadata_csv_path, mode="w", encoding="utf-8" )

for i in range(5000):
    # 発話ID
    filename = 'BASIC5000_%04d' % (i+1)
    #filename1 = original_wav_dir + f'{filename}.wav'
    filename1 = out_wav_dir + f'{filename}.wav'

    # 発話ID が label_info に含まれない場合はエラー
    if not filename in label_info:
        print('Error: %s is not in %s' % (filename, original_label))
        exit()

    # キャラクターラベル情報を取得
    chars = label_info[filename]['text_level2']
    # '、'と'。'を除去
    #chars = chars.replace('、', '')
    #chars = chars.replace('。', '')
    
    #str1 = filename1 + ',' + chars + "\n"
    str1 = filename + '|' + chars + "\n"

    #file.write( str1 )
    #file.write('%s %s\n' % (filename, ' '.join(chars)))
    file.write('%s|%s\n' % (filename, chars))

file.close()

データの読み込み

fr = open( metadata_csv_path, "r", encoding='UTF-8')
#fr = open('./data/metadata4.csv', "r", encoding='UTF-8')

datalist = fr.readlines()
audio_train = []
sentence_train = []
filename_train = []
audio_test = []
sentence_test = []
filename_test = []
for i, line1 in enumerate(datalist):
    filename = line1.split( '|' )[0]
    filename = filename.split('/')[-1]
    filename = "./data/wav/" + filename + ".wav"
    #filename = "./datasets/jsut_ver1.1/wavs/" + filename + ".wav"
    print( "i:{}, filename:{}".format( i, filename ))
    #音声ファイルの読み込み
    #filename = "./datasets/jsut_ver1.1/wavs/" + filename + ".wav"
    audio0, _ = librosa.load( filename, sr = 16000)
    #line2 = line1.split( '|' )[1]
    line2 = line1.split( '|' )[1].replace( "\n", "" )
    line2 = line2.replace( "", "" )
    line2 = line2.replace( "", "" )
    print( "line2:{}".format(line2 ))
    #if i < 5000:                                        #train データの数に合わせる
    if i < 4900:
    #if i < 90:
        #音声データ
        audio_train.append( audio0 )
        #正解文章データ
        sentence_train.append( line2 )
        #音声ファイル名データ
        filename_train.append( filename )
    else:                                               # validation データになる
    #elif i < 100:
        audio_test.append( audio0 )
        #音声ファイル名データ
        filename_test.append( filename )
        #正解文章データ
        sentence_test.append( line2 )
    #else:
    #    break
        
print( len( audio_train ))
print( len( sentence_train))
print( len( filename_train))
print( len( audio_test))
print( len( sentence_test ))
print( len( filename_test))

句読点は削除しました。Fine Tuning では、句読点があっても大丈夫でした。

whisper モデルの読み込みなど

#whisper モデルの読み込み
model = whisper.load_model( 'tiny' )

#tokenizer の読み込みと定数の設定
tokenizer = whisper.tokenizer.get_tokenizer(True, task="transcribe", language="ja")
eos = tokenizer.eot
special_tokens1 = list( tokenizer.sot_sequence_including_notimestamps )

モデルパラメーターの初期化

この部分で、ニューラルネットワークパラメーターの初期化を行っています。

# モデルの重みとバイアス(パラメータ)の初期化

def kaiming_init(model):
    """ Kaimingの初期化

    Args:
        model (object): モデル
    """
    i = 1
    for name, param in model.named_parameters():
        print("i:{}".format( i ))
        i += 1
        if name.endswith(".bias"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "bias")
            param.data.fill_(0)
        elif name.startswith("encoder.conv"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "encoder.conv")
            param.data.normal_(0, 1/math.sqrt(param.shape[1]))
            #param.data.normal_(0, math.sqrt(2)/math.sqrt(param.shape[1]))
        elif name.endswith("ln.weight"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print("ln.weight")
            param.data.fill_(1)
        elif name.endswith("ln_post.weight"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "ln_post.weight")
            param.data.fill_(1)
        elif name.startswith("layers.0"):
             # The first layer does not have ReLU applied on its input
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "layers.0")
            param.data.normal_(0, 1/math.sqrt(param.shape[1]))
        elif name.endswith(".mlp.0.weight"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print(".mlp.0.weight" )
            param.data.normal_(0, math.sqrt(2)/math.sqrt(param.shape[1]))
        else:
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print("other")
            param.data.normal_(0, 1/math.sqrt(param.shape[1]))       
            
kaiming_init(model)

この部分を行わなければ、Fine Tuning にも応用がきくと思います。

バッチに分割

# データをバッチに分割する。
batch_size = 8                                            #学習、バリデーション共通の batch_size メモリー使用量を見ながら調整してください。

input_ids_train = audio_train
#文章データを tokenizer で、token に変換
#labels_train = tokenizer(sentence_train).input_ids
labels_train = []
for sentence1 in sentence_train:
    labels_train.append( tokenizer.encode(sentence1) )
input_ids_test = audio_test
#labels_test = tokenizer(sentence_test).input_ids
labels_test = []
for sentence1 in sentence_test:
    labels_test.append( tokenizer.encode(sentence1) )

#バッチ分割
split = len(input_ids_train) // batch_size
#音声データ train 用
batch_input_ids_train = np.array_split( input_ids_train, split )
print( len(batch_input_ids_train))
print( len(batch_input_ids_train[0]))
print( len(batch_input_ids_train[0][0]))
# 正解 token データ、target, y train 用。
batch_labels_train = np.array_split( labels_train, split )
print( len(batch_labels_train))
print( len(batch_labels_train[0]))
print( len(batch_labels_train[0][0]))
split = len(filename_test) // batch_size
#音声ファイル名データ validation 用
batch_filename_test = np.array_split( filename_test, split )
print( len(batch_filename_test ))
batch_input_ids_test = np.array_split( input_ids_test, split )
print( len(batch_input_ids_test ))
#正解文章データ validation 用
batch_sentence_test = np.array_split( sentence_test, split )
print( len(batch_sentence_test ))
batch_labels_test = np.array_split( labels_test, split )
print( len(batch_labels_test ))

バッチデータの確認

# バッチデータの確認
for i, _ in enumerate( range( len( batch_sentence_test ) ) ):
    print( len(batch_filename_test[i] ))
    print( len(batch_sentence_test[i] ))
#サンプルデータの確認
print( len(batch_input_ids_train[10][0] ) )
print( len(batch_labels_train[10][0] ) )

ニューラルネットワークの概要表示

# ネットワークの概要表示
print(model)

学習の準備

# GPUの利用
model = model.to(device)

# 学習率
lr = 1e-2
#lr = 0.0004
#lr = 1e-1
#lr = 1.6e-5
#lr = 3.2e-6

# 損失関数定義
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# エポック数
num_epochs1 = 50

# 最適化関数定義
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr = lr )
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.0008, 
                                                steps_per_epoch=len(batch_input_ids_train), epochs=num_epochs1)
lrs = []

# historyファイルと torch.save ファイルの path を初期化する
history = np.zeros((0, 8))
paths = []

to_pad_to_mel

#音声データ を 30秒に pad or trim して、mel フィルターバンクにする。
def to_pad_to_mel(array):
    """Static function which:
        1. Pads/trims a list of audio arrays to a max length of 30s
        2. Computes log-mel filter coefficients from padded/trimmed audio sequences
        Inputs:
            array: list of audio arrays
        Returns:
            input_ids: torch.tensor of log-mel filter bank coefficients
    """
    padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
    input_ids = whisper.log_mel_spectrogram(padded_input)
    return input_ids

学習用関数

# 学習用関数
def fit(model, optimizer, criterion, num_epochs, 
        batch_input_ids_train, batch_labels_train,
        batch_filename_test, batch_sentence_test,
        batch_input_ids_test, batch_labels_test,
        device, lr, history, pahts):

    base_epochs = len(history)
    
    # cer, wer 計算ライブラリーの読み込み
    metrics_cer = evaluate.load("cer")
    metrics_wer = evaluate.load("wer")

    # wer 計算のために分かち書きをするので、分かち書き用。
    nlp = spacy.load("ja_ginza")
    ginza.set_split_mode(nlp, "C") # CはNEologdの意らしいです
    
    #メモリー確保
    gc.collect()
    
    # epoch のループ
    for epoch in range(base_epochs, num_epochs+base_epochs):
        train_loss = 0
        train_acc = 0
        val_loss = 0
        val_acc = 0
        #torch.save するファイルのファイル名。
        path = "whisper-torch-initialize16-tiny-" + (str(epoch+1)) + ".pth"
        #print( "path:{}".format(path))
        
        #訓練フェーズ
        model.train()
        count = 0
        step = 0
        
        #バッチのループ、プログレスバー対応 
        phar = tqdm( range( len(batch_input_ids_train) ), desc='train' )
        for i in phar:
            step += 1 
            
            # model へ入力する x データ
            input_ids = batch_input_ids_train[i]
            # model に対して教師データとなる y, target データ
            labels = batch_labels_train[i]
            
            gc.collect()
            count += len(labels)
            # audio データを melフィルタバンクに変換し、padding する。 (batch_size, 80, 30000)
            input_ids = torch.concat([to_pad_to_mel(input_val)[None, :] for input_val in input_ids]).to(device)
            #教師データから、decoder の入力に使うデータを作るために label0 をとっておく。
            labels0 = labels
            #教師データ
            labels = [ special_tokens1[:1] + lab + [eos] for lab in labels]
            #デコーダーへ力
            labels2 = [ special_tokens1 + lab2 +[eos] for lab2 in labels0 ]
            # finally, pad the target labels to max_len
            label_lengths = [len(lab)  for lab in labels]
            label2_lengths = [len(lab2) for lab2 in labels2]
            max_label_len = max(label_lengths)
            max_label2_len = max(label2_lengths)
            labels = [np.pad(lab, (0, max_label2_len - lab_len ), 'constant', 
                      constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
            #教師データ
            labels = torch.tensor( np.array(labels), requires_grad=False, dtype=torch.long).to(device)
            dec_input_ids = [np.pad(lab2, (0, max_label2_len - lab2_len ), 'constant', 
                     constant_values=eos) for lab2, lab2_len in zip(labels2, label2_lengths)]
            #デコーダーへ入力
            dec_input_ids = torch.tensor( np.array(dec_input_ids), requires_grad=False, dtype=torch.long).to(device)
     
            del labels0
            del labels2
            del label_lengths
            del label2_lengths
            del max_label_len
            del max_label2_len
            gc.collect()
            
            # 勾配の初期化
            optimizer.zero_grad()

            
            # 予測計算
            logits = model(input_ids, tokens=dec_input_ids )
            
            # 損失計算
            loss = criterion( logits.view(-1, logits.size(-1)), labels.view(-1) )
           
            train_loss += loss.item()

            # 勾配計算
            loss.backward()

            # パラメータ修正
            optimizer.step()
            lrs.append(optimizer.param_groups[0]["lr"])
            current_lr = optimizer.param_groups[0]["lr"]
            #     print("Factor = ",i," , Learning Rate = ",optimizer.param_groups[0]["lr"])
            scheduler.step()

            # 予測値算出
            predicted = torch.max(logits, 2)[1]
    
            # 毎 epoch において、 len(batch_input_ids_train)番目のバッチの結果を表示。
            if step == len(batch_input_ids_train):
                labels[labels == -100] = eos
                predicted[predicted == -100] = eos
                
                #label_str = tokenizer.batch_decode( labels, skip_special_tokens=True )
                #pred_str = tokenizer.batch_decode( predicted, skip_special_tokens=True )

                label_str = []
                pred_str = []
            
                for labels1, predicted1 in zip( labels, predicted ):
                    #label_str.append( tokenizer.decode( labels1, skip_special_tokens=True ) )
                    #pred_str.append( tokenizer.decode( predicted1, skip_special_tokens=True ) )
                    label_str.append( tokenizer.decode( labels1 ).replace( "<|endoftext|>", "" ).replace( "<|ja|>", "" )
                                    .replace( "<|transcribe|>","").replace("<|startoftranscript|>","").replace("<|notimestamps|>",""))
                    pred_str.append( tokenizer.decode( predicted1 ).replace( "<|endoftext|>", "" ).replace( "<|ja|>", "" )
                                    .replace( "<|transcribe|>","").replace("<|startoftranscript|>","").replace("<|notimestamps|>",""))                
                
                
                for j, _ in enumerate( label_str ):
                    print( "train target:{}".format( label_str[j] ))
                    print( "train predec:{}\n".format( pred_str[j] ))            
            
            
            # 正解件数算出
            train_acc += (predicted == labels).sum()

            del logits
            del loss
            del predicted
            gc.collect()

            # 損失と精度の計算
            avg_train_loss = train_loss / count
            avg_train_acc = train_acc / count

            #プログラスバーに loss 表示
            phar.set_postfix( loss = avg_train_loss )            
        
        #予測フェーズ
        model.eval()
        count = 0
        step = 0
        cer_sum2 = 0
        wer_sum2 = 0

        #バッチのループ、プログレスバー対応
        phar = tqdm( range( len(batch_filename_test) ), desc='val' )        
        for i in phar:
            #print( "i:{}".format( i ))
            #音声データのファイル名,x
            filenames = batch_filename_test[i]
            #正解データ。target, y
            labels00 = batch_sentence_test[i]
            gc.collect()
            count += len(batch_filename_test)
            step += 1 

            gc.collect()
            
            # 予測計算 本番用の model.transcribe 関数を用いて予測するため、val_loss と val_acc は計算しない。
            #pred_str = []
            #for filename in filenames:
            #    result, _ = model.transcribe(filename, language="ja", task="transcribe" )
            #    pred_str.append( result['text'] )
            #label_str = labels
            # 最初のバッチだけ、例として正解文章と予測した文章を表示。
            #if step == 1:
            #    for j, _ in enumerate( label_str ):
            #        print( "target:{}".format( label_str[j] ))
            #        print( "predec:{}\n".format( pred_str[j] ))

            # model へ入力する x データ
            input_ids = batch_input_ids_test[i]
            # model に対して教師データとなる y, target データ
            labels = batch_labels_test[i]
            gc.collect()
            count += len(labels)
            # audio データを melフィルタバンクに変換し、padding する。 (batch_size, 80, 30000)
            input_ids = torch.concat([to_pad_to_mel(input_val)[None, :] for input_val in input_ids]).to(device)
            #教師データから、decoder の入力に使うデータを作るために label0 をとっておく。
            labels0 = labels
            #教師データ
            labels = [ special_tokens1[:1] + lab + [eos] for lab in labels]
            #デコーダーへ力
            labels2 = [ special_tokens1 + lab2 + [eos]  for lab2 in labels0 ]
            # finally, pad the target labels to max_len
            label_lengths = [len(lab) for lab in labels]
            label2_lengths = [len(lab2) for lab2 in labels2]
            max_label_len = max(label_lengths)
            max_label2_len = max(label2_lengths)
            labels = [np.pad(lab, (0, max_label2_len - lab_len ), 'constant', 
                      constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
            #教師データ
            labels = torch.tensor( np.array(labels), requires_grad=False, dtype=torch.long).to(device)
            dec_input_ids = [np.pad(lab2, (0, max_label2_len - lab2_len ), 'constant', 
                     constant_values=eos) for lab2, lab2_len in zip(labels2, label2_lengths)]
            #デコーダーへ入力
            dec_input_ids = torch.tensor( np.array(dec_input_ids), requires_grad=False, dtype=torch.long).to(device)                    
                    
            # 予測計算
            logits = model(input_ids, tokens=dec_input_ids )
                
            # 損失計算
            loss = criterion( logits.view(-1, logits.size(-1)), labels.view(-1) )
            val_loss += loss.item()
 
            # 予測値算出
            predicted = torch.max(logits, 2)[1]

            # 正解件数算出
            val_acc += (predicted == labels).sum()

            #if step == 1:
            #    for j, _ in enumerate( label_str ):
            #        print( "target:{}".format( label_str[j] ))
            #        print( "predec:{}\n".format( pred_str[j] ))
            
            label_str = labels00
            #pred_str = tokenizer.batch_decode( predicted, skip_special_tokens=True )
            pred_str = []
            
            for labels1, predicted1 in zip( labels, predicted ):
                #label_str.append( tokenizer.decode( labels1, skip_special_tokens=True ) )
                #pred_str.append( tokenizer.decode( predicted1, skip_special_tokens=True ) )
                pred_str.append( tokenizer.decode( predicted1 ).replace( "<|endoftext|>", "" ).replace( "<|ja|>", "" )
                                .replace( "<|transcribe|>","").replace("<|startoftranscript|>","").replace("<|notimestamps|>",""))    
            
            # cer 算出
            #print( "label_str:{}".format( label_str ))
            #print( "pred_str:{}".format( pred_str ))
            cer = 100 * metrics_cer.compute(predictions=pred_str, references=label_str)
            cer_sum2 += cer

            # wer 算出
            # 分かち書きして空白区切りに変換
            wer_pred_str = [" ".join([ str(i) for i in nlp(j) ]) for j in pred_str]
            wer_label_str = [" ".join([ str(i) for i in nlp(j) ]) for j in label_str]            
            wer = 100 * metrics_wer.compute(predictions=wer_pred_str, references=wer_label_str)                
            wer_sum2 += wer                                 
            
            if step == 1:
                for j, _ in enumerate( label_str ):
                    print( "val target:{}".format( label_str[j] ))
                    print( "val predec:{}\n".format( pred_str[j] ))
            
            
            gc.collect()            
            
            # 損失と精度の計算ダミー
            avg_val_loss = val_loss / count
            avg_val_acc = val_acc / count
            # cer と wer の平均値計算。
            avg_cer = cer_sum2 / step
            avc_wer = wer_sum2 / step
            
            #プログラスバーに cer 表示
            phar.set_postfix( val_loss = avg_val_loss )   
        
        avg_cer = cer_sum2 / (step )
        #print("avg_cer:{}".format( avg_cer ) )
        avg_wer = wer_sum2 / (step )
        #print("avg_wer:{}".format( avg_wer ))
    
        print (f'Epoch [{(epoch+1)}/{num_epochs+base_epochs}],loss: {avg_train_loss:.5f} acc: {avg_train_acc:.5f}, val-loss: {avg_val_loss:.5f} val-acc: {avg_val_acc:.5f}  cer: {avg_cer:.5f} wer: {avg_wer:.5f} lr: {current_lr:}')
        item = np.array([epoch+1, avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc, avg_cer, avg_wer, current_lr])
        history = np.vstack((history, item))
        paths.append( path )
        
        # modelを torch.save する。
        #epoch が  1以上で前の epoch より val-los が小さければ save する。
        if epoch >=1:
            if history[epoch,3] <= history[epoch-1,3]:
                torch.save( model, path )
                print( "model {} is saved.".format(path)) 
        #eoich が 0 の時は無条件で save
        elif epoch == 0:
            torch.save( model, path )
            print( "model {} is saved.".format(path))

        # Early Stopping
        # 4 epoch 進んでも loss が改善しなかったら終了。
        if epoch >= 3:
            if history[epoch,1] >= history[epoch-3,1]:
                print( "Early Stopping.")
                break
        
    return history, pahts

メモリーの確保(おまじない)

gc.collect()

パラメータの学習を有効にする。念のため。

for param in model.parameters():
    param.requires_grad = True

学習

# 学習
num_epochs = num_epochs1
history, paths = fit(model, optimizer, criterion, num_epochs, 
            batch_input_ids_train, batch_labels_train,
            batch_filename_test, batch_sentence_test,
            batch_input_ids_test, batch_labels_test,
            device, lr, history, paths)

考察

結果は過学習です。5000件のデータでは足りないのでしょうか。汎化性能を上げるためにはデータ量を増やす必要があるのではないでしょうか。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1