6
4

More than 1 year has passed since last update.

OpenAI の Whisper を、自前の音声データで Fine Tuning するプログラム

Last updated at Posted at 2023-02-14

プログラムの目的

OpenAI の Whisper には、30秒以上の音声ファイルを文字起こしする transcribe 関数があります。驚異的なのは、large モデルで 10 分以上の音声ファイルから字幕ファイルを作っても、メモリーで問題が起きないです。fine tuning したモデルに、この機能を使うために、OpenAI の Whisper を fine tuning するプログラムを作成しました。fine tuning させたところ、きちんと学習しているので、情報の共有をお願いします。

ライブラリーの読み込みなど

最初に、ライブラリーの読み込みと GPU の判別。わたくしの開発環境は、GPU がないので、一応、GPU が使えるようにプログラムは書いたつもりですが、動作確認はしていません。もし、ちゃんと動かなかったら、直して使ってください。CPU では動作確認しました。

import torch
import librosa
import whisper
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import japanize_matplotlib
import evaluate
import gc
import spacy
import ginza
from tqdm.notebook import tqdm

#GPUの利用
device = 0 if torch.cuda.is_available() else "cpu"
print(device)

設定

次に、notebook の設定。赤石雅典さんの「最短でわかる PyTorch & 深層プログラミング」を参考にさせていただきました。fit 関数、evaluate_history 関数なども参考にしています。

# warning表示off
import warnings
warnings.simplefilter('ignore')

# デフォルトフォントサイズ変更
plt.rcParams['font.size'] = 14

# デフォルトグラフサイズ変更
plt.rcParams['figure.figsize'] = (6,6)

# デフォルトで方眼表示ON
plt.rcParams['axes.grid'] = True

# numpyの表示桁数設定
np.set_printoptions(suppress=True, precision=5)

データ

次にデータの読み込み、データの配置は、

current dir------P
                 |
                 ----metadata.csv
                 |
                 ----data
                      | p0001.wav
                      | p0002.wav
                      | p0003.wav
            ・・・

です。metadata.csv の中身は、ファイル名の拡張子をのぞいた basename と教師データです。

p0001,教師データの文章1
p0002,教師データの文章2
p0003,教師データの文章3
・・・

のようです。データの読み込みプログラムは、

#データの読み込み

fr = open('./P/metadata.csv', "r", encoding='UTF-8')

datalist = fr.readlines()
sentence_train = []
filename_test = []
sentence_test = []
for i, line1 in enumerate(datalist):
    filename = line1.split( ',' )[0]
    #音声ファイルの読み込み
    filename = "./P/data/" + filename + ".wav"
    audio0, _ = librosa.load( filename, sr = 16000)
    line2 = line1.split( ',' )[1]
    if i < 280:                                        #train データの数に合わせる
        #音声データ
        audio_train.append( audio0 )
        #正解文章データ
        sentence_train.append( line2 )
    else:                                              # validation データになる
        #音声ファイル名データ
        filename_test.append( filename )
        #正解文章データ
        sentence_test.append( line2 )

print( len( audio_train ))
print( len( sentence_train))
print( len( filename_test))
print( len( sentence_test ))

です。ここで、train データの数を指定してください。

モデルの読み込み

次は、whisper モデルと tokenizer の読み込みおよび定数の設定です。

#whisper モデルの読み込み
model = whisper.load_model( 'small' )

#tokenizer の読み込みと定数の設定
whisper_tok = whisper.tokenizer.get_tokenizer(True, task="transcribe", language="ja")
tokenizer = whisper_tok.tokenizer
tokenizer.pad_token = tokenizer.eos_token
eos = tokenizer.eos_token_id
sot =  tokenizer("<|startoftranscript|>").input_ids[0]
lang_id = tokenizer("<|ja|>").input_ids[0]
task = tokenizer("<|transcribe|>").input_ids[0]
t_stamp = tokenizer("<|notimestamps|>").input_ids[0]

ここで、デコーダーの入力データの始まりが、sot, lang_id, task, t_stamp ・・・, eosで、教師データが、 lang_id, task, t_stamp, ・・・, eos であるというこは、

のページで認識しました。それまでは、単純に、<sos>・・・ <eos>だと思っていました。

 元とする whisper のモデルですが、small モデルでうまく学習できることが確認出来たら、large モデルで学習させてみることをお勧めします。修正は、読み込むモデルを large にすることと、メモリーを消費するため、batch_size を小さくすることくらいです。

バッチ

データをバッチに分割します。

# データをバッチに分割する。
batch_size = 2                                             #学習、バリデーション共通の batch_size メモリー使用量を見ながら調整してください。

input_ids_train = audio_train
#文章データを tokenizer で、token に変換
labels_train = tokenizer(sentence_train).input_ids

#バッチ分割
split = len(input_ids_train) // batch_size
#音声データ train 用
batch_input_ids_train = np.array_split( input_ids_train, split )
print( len(batch_input_ids_train))
print( len(batch_input_ids_train[0]))
print( len(batch_input_ids_train[0][0]))
# 正解 token データ、target, y train 用。
batch_labels_train = np.array_split( labels_train, split )
print( len(batch_labels_train))
print( len(batch_labels_train[0]))
print( len(batch_labels_train[0][0]))
split = len(filename_test) // batch_size
#音声ファイル名データ validation 用
batch_filename_test = np.array_split( filename_test, split )
print( len(batch_filename_test ))
#正解文章データ validation 用
batch_sentence_test = np.array_split( sentence_test, split )
print( len(batch_sentence_test ))

サンプルデータの確認

#サンプルデータの確認
print( len(batch_input_ids_train[10][0] ) )
print( len(batch_labels_train[10][0] ) )

モデルの確認

# ネットワークの概要表示

print(model)

summary の表示

#summary の表示

#!pip install torchsummaryX
#from torchinfo import summary
from torchsummaryX import summary

summary(model=model, x=torch.zeros((batch_size,80,3000)), tokens=torch.zeros((batch_size,448),dtype=torch.long))

学習にかかわる定数などの設定

# GPUの利用
model = model.to(device)

# 学習率
lr = 0.0001

# 損失関数定義
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# 最適化関数定義
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# historyファイルと torch.save ファイルの path を初期化する
history = np.zeros((0, 8))
paths = []

to_pad_to_mel 関数

音声データを 30秒に pad or trim し、log-mel フィルタバンク化する関数。学習用関数 fit で使います。

のページより引用。

#音声データ を 30秒に pad or trim して、mel フィルターバンクにする。
def to_pad_to_mel(array):
    """Static function which:
        1. Pads/trims a list of audio arrays to a max length of 30s
        2. Computes log-mel filter coefficients from padded/trimmed audio sequences
        Inputs:
            array: list of audio arrays
        Returns:
            input_ids: torch.tensor of log-mel filter bank coefficients
    """
    padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
    input_ids = whisper.log_mel_spectrogram(padded_input)
    return input_ids

学習用関数

中で、model( input_ids, tokens=dec_input_ids )としていますが、encoder と decoder が含まれています。

# 学習用関数
def fit(model, optimizer, criterion, num_epochs, 
        batch_input_ids_train, batch_labels_train,
        batch_filename_test, batch_sentence_test,
        device, lr, history, pahts):

    base_epochs = len(history)
    
    # cer, wer 計算ライブラリーの読み込み
    metrics_cer = evaluate.load("cer")
    metrics_wer = evaluate.load("wer")

    # wer 計算のために分かち書きをするので、分かち書き用。
    nlp = spacy.load("ja_ginza")
    ginza.set_split_mode(nlp, "C") # CはNEologdの意らしいです
    
    #メモリー確保
    gc.collect()
    
    # epoch のループ
    for epoch in range(base_epochs, num_epochs+base_epochs):
        train_loss = 0
        train_acc = 0
        val_loss = 0
        val_acc = 0
        #torch.save するファイルのファイル名。
        path = "whisper-torch-fine-tuning-" + (str(epoch+1)) + ".pth"
        #print( "path:{}".format(path))
        
        #訓練フェーズ
        model.train()
        count = 0
        
        #バッチのループ、プログレスバー対応 
        phar = tqdm( range( len(batch_input_ids_train) ), desc='train' )
        for i in phar:
            # model へ入力する x データ
            input_ids = batch_input_ids_train[i]
            # model に対して教師データとなる y, target データ
            labels = batch_labels_train[i]
            gc.collect()
            count += len(labels)
            # audio データを melフィルタバンクに変換し、padding する。 (batch_size, 80, 30000)
            input_ids = torch.concat([to_pad_to_mel(input_val)[None, :] for input_val in input_ids]).to(device)
            #教師データから、decoder の入力に使うデータを作るために label0 をとっておく。
            labels0 = labels
            #教師データ
            labels = [ [lang_id] + [task] + [t_stamp] + lab + [eos] for lab in labels]
            #デコーダーへ力
            labels2 = [ [sot] + [lang_id] + [task] + [t_stamp] + lab2 +[eos] for lab2 in labels0 ]
            # finally, pad the target labels to max_len
            label_lengths = [len(lab) for lab in labels]
            label2_lengths = [len(lab2) for lab2 in labels2]
            max_label_len = max(label_lengths)
            max_label2_len = max(label2_lengths)
            labels = [np.pad(lab, (0, max_label2_len - lab_len ), 'constant', 
                      constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
            #教師データ
            labels = torch.tensor( np.array(labels), requires_grad=False, dtype=torch.long).to(device)
            dec_input_ids = [np.pad(lab2, (0, max_label2_len - lab2_len ), 'constant', 
                     constant_values=eos) for lab2, lab2_len in zip(labels2, label2_lengths)]
            #デコーダーへ入力
            dec_input_ids = torch.tensor( np.array(dec_input_ids), requires_grad=False, dtype=torch.long).to(device)
     
            del labels0
            del labels2
            del label_lengths
            del label2_lengths
            del max_label_len
            del max_label2_len
            gc.collect()
            
            # 勾配の初期化
            optimizer.zero_grad()

            # 予測計算
            logits = model(input_ids, tokens=dec_input_ids )
            
            # 損失計算
            loss = criterion( logits.view(-1, logits.size(-1)), labels.view(-1) )
           
            train_loss += loss.item()

            # 勾配計算
            loss.backward()

            # パラメータ修正
            optimizer.step()

            # 予測値算出
            predicted = torch.max(logits, 2)[1]

            # 正解件数算出
            train_acc += (predicted == labels).sum()

            del logits
            del loss
            del predicted
            gc.collect()

            # 損失と精度の計算
            avg_train_loss = train_loss / count
            avg_train_acc = train_acc / count

            #プログレスバーに loss 表示
            phar.set_postfix( loss = avg_train_loss )            
        
        #予測フェーズ
        model.eval()
        count = 0
        step = 0
        cer_sum2 = 0
        wer_sum2 = 0

        #バッチのループ、プログレスバー対応
        phar = tqdm( range( len(batch_filename_test) ), desc='val' )        
        for i in phar:
            #print( "i:{}".format( i ))
            #音声データのファイル名,x
            filenames = batch_filename_test[i]
            #正解データ。target, y
            labels = batch_sentence_test[i]
            gc.collect()
            count += len(labels)
            step += 1 

            gc.collect()
            
            # 予測計算 本番用の model.transcribe 関数を用いて予測するため、val_loss と val_acc は計算しない。
            pred_str = []
            for filename in filenames:
                result = model.transcribe(filename, language="ja", task="transcribe")
                pred_str.append( result['text'] )
            label_str = labels
            # 最初のバッチだけ、例として正解文章と予測した文章を表示。
            if step == 1:
                for j, _ in enumerate( label_str ):
                    print( "target:{}".format( label_str[j] ))
                    print( "predec:{}\n".format( pred_str[j] ))
             
            # 損失計算はなし
 
            #予測値算出もなし
            
            #正解件数算出もなし
                
            # cer 算出
            cer = 100 * metrics_cer.compute(predictions=pred_str, references=label_str)
            cer_sum2 += cer

            # wer 算出
            # 分かち書きして空白区切りに変換
            wer_pred_str = [" ".join([ str(i) for i in nlp(j) ]) for j in pred_str]
            wer_label_str = [" ".join([ str(i) for i in nlp(j) ]) for j in label_str]            
            wer = 100 * metrics_wer.compute(predictions=wer_pred_str, references=wer_label_str)                
            wer_sum2 += wer                                 
                                 
            val_loss = 0.0
            val_acc = 0.0

            gc.collect()            
            
            # 損失と精度の計算ダミー
            avg_val_loss = val_loss / count
            avg_val_acc = val_acc / count
            # cer と wer の平均値計算。
            avg_cer = cer_sum2 / step
            avc_wer = wer_sum2 / step
            
            #プログレスバーに cer 表示
            phar.set_postfix( cer = avg_cer )   
        
        avg_cer = cer_sum2 / (step )
        #print("avg_cer:{}".format( avg_cer ) )
        avg_wer = wer_sum2 / (step )
        #print("avg_wer:{}".format( avg_wer ))
    
        print (f'Epoch [{(epoch+1)}/{num_epochs+base_epochs}],loss: {avg_train_loss:.5f} acc: {avg_train_acc:.5f} cer: {avg_cer:.5f} wer: {avg_wer:.5f} lr: {lr:}')
        item = np.array([epoch+1, avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc, avg_cer, avg_wer, lr])
        history = np.vstack((history, item))
        paths.append( path )
        
        # modelを torch.save する。
        #epoch が  1以上で前の epoch より cer が小さければ save する。
        if epoch >=1:
            if history[epoch,5] < history[epoch-1,5]:
                torch.save( model, path )
                print( "model {} is saved.".format(path)) 
        #eoich が 0 の時は無条件で save
        elif epoch == 0:
            torch.save( model, path )
            print( "model {} is saved.".format(path))
        # lr を 1/5 にする。 epoch が 2 以上で、二回続けて cer が大きくなったら 1/5 にする。
        # lr が 1e-8 より小さくなる場合は、1/5 にしない。
        if epoch >= 2:
            if history[epoch,5] > history[epoch-1,5] and history[epoch-1,5] > history[epoch-2,5]:
                if lr > 5e-8:
                    lr = lr / 5.0
                    print( "More than 2 times cer increases, lr = lr / 5.0 = {}".format( lr ))
                    #optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
                    optimizer.param_groups[0]['lr'] = lr
                else:
                    lr = lr 
                    print( "More than 2 times cer increases, but lr =< 5e-8, so lr not change")
        # Early Stopping epoch が 3以上で、三回連続前の cer より大きくなったら。
        if epoch >= 3:
            if history[epoch,5] >= history[epoch-1,5] and history[epoch-1,5] >= history[epoch-2,5] and history[epoch-2,5] >= history[epoch-3,5]:
                print( "Early Stopping. More than 3 times cer increases")
                break
        
    return history, pahts

学習用関数では、val-loss と val-acc は、計算しないので、0 です。validation は、transcribe 関数を用いて、文章を予測し、cer と wer を計算しています。

学習ログ解析

# 学習ログ解析
def evaluate_history(history):
    #損失と精度の確認
    print(f'初期状態: 損失: {history[0,3]:.5f} 精度: {history[0,4]:.5f} cer: {history[0,5]:.5f} wer: {history[0,6]:.5f}') 
    print(f'最終状態: 損失: {history[-1,3]:.5f} 精度: {history[-1,4]:.5f} cer: {history[-1,5]:.5f} wer: {history[-1,6]:.5f}' )

    num_epochs = len(history)
    if num_epochs < 10:
      unit = 1
    else:
      unit = num_epochs // 10

    # 学習曲線の表示 (損失)
    plt.figure(figsize=(9,8))
    plt.plot(history[:,0], history[:,1], 'b', label='訓練')
    #plt.plot(history[:,0], history[:,3], 'k', label='検証')
    plt.xticks(np.arange(0,num_epochs+1, unit))
    plt.xlabel('繰り返し回数')
    plt.ylabel('損失')
    plt.title('学習曲線(損失)')
    plt.legend()
    plt.show()

    # 学習曲線の表示 (精度)
    plt.figure(figsize=(9,8))
    plt.plot(history[:,0], history[:,2], 'b', label='訓練')
    #plt.plot(history[:,0], history[:,4], 'k', label='検証')
    plt.xticks(np.arange(0,num_epochs+1,unit))
    plt.xlabel('繰り返し回数')
    plt.ylabel('精度')
    plt.title('学習曲線(精度)')
    plt.legend()
    plt.show()
    
    # cer, wer
    plt.figure(figsize=(9,8))
    plt.plot(history[:,0], history[:,5], 'b', label='cer')
    plt.plot(history[:,0], history[:,6], 'k', label='wer')
    plt.xticks(np.arange(0,num_epochs+1,unit))
    plt.xlabel('繰り返し回数')
    plt.ylabel('cer,wer %')
    plt.title('cer,wer')
    plt.legend()
    plt.show()

メモリーの確保

#del audio_train
#del labels_train
#del filename_test
#del sentence_test

gc.collect()

学習

# 学習
num_epochs = 20
history, paths = fit(model, optimizer, criterion, num_epochs, 
            batch_input_ids_train, batch_labels_train,
            batch_filename_test, batch_sentence_test,
            device, lr, history, paths)

cer 最小のモデル

history と paths から、cer が最小の model を読みこむ。

#hisotry より最小の cer を探す。
min_cer = 1e10
for i in range( len(history)):
    if min_cer > history[i,5]:
        min_cer = history[i,5]
        
print( "最小の cer は:{}".format(min_cer))

#最小の cer の時に save されたファイル名を探す。
i_atari = 0
filename = ''
for i in range( len(history)):
    if min_cer == history[i,5]:
        i_atari = i
        filename = paths[i]
        beak

#記録する。
f = open('kiroku.txt', 'w', encoding='UTF-8')        
str1 = "その時のエポックは:{}".format( i_atari + 1 )
print(str1)
f.write( str1 + "\n" )
str_history = str( history )
f.write( str_history + "\n" )
str_paths = str( paths )
f.write( str_paths + "\n" )
f.close()

#最小の cer の model を読み込む。
model = torch.load(filename)
print("モデル{}を読み込みました".format( filename ) )

モデルの保存

torch.save(model, 'model_weight.pth')

model.state_dict() の保存

PATH = "whisper-torch-fine-tuning.pt"
torch.save({ 'model_state_dict': model.state_dict(),}, PATH)

結果サマリー

# 結果サマリー
evaluate_history(history)

30秒以上のファイルの文字起こし

result = model.transcribe("/path/to/wav_file", language="ja", task="transcribe", verbose=True)

#print( result["text"] )

です。

評価

合同朝礼で話をする特定の人物の音声データと字幕から作った教師データで学習させました。train 280件、validation 20 件で、validation data に対する cer=16.98, wer=21.21(epoch=1) だったのが cer = 14.51, wer = 18.13 ( epoch =6 ) になりました。10 % 以上の精度向上です。

ちなみに、fine tuning したモデルで、jsut-ver.1.1 の BASIC5000_0003.wav を音声認識すると、

[00:00.000 --> 00:03.740] 上院議員は、私がデータを歪めたと告発した。

のようになりました。

Fine Tuning ではなく、whisper を初期化して学習させるには。

model = whisper.load_model( 'tiny' ) したあとに、次のプログラムを実行すると、whisper のモデルパラメーターを初期化します。tiny model で動作確認済。

# モデルの重みとバイアス(パラメータ)の初期化

def kaiming_init(model):
    """ Kaimingの初期化

    Args:
        model (object): モデル
    """
    i = 1
    for name, param in model.named_parameters():
        print("i:{}".format( i ))
        i += 1
        if name.endswith(".bias"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "bias")
            param.data.fill_(0)
        elif name.startswith("encoder.conv"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "encoder.conv")
            param.data.normal_(0, 1/math.sqrt(param.shape[1]))
            #param.data.normal_(0, math.sqrt(2)/math.sqrt(param.shape[1]))
        elif name.endswith("ln.weight"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print("ln.weight")
            param.data.fill_(1)
        elif name.endswith("ln_post.weight"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "ln_post.weight")
            param.data.fill_(1)
        elif name.startswith("layers.0"):
             # The first layer does not have ReLU applied on its input
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print( "layers.0")
            param.data.normal_(0, 1/math.sqrt(param.shape[1]))
        elif name.endswith(".mlp.0.weight"):
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print(".mlp.0.weight" )
            param.data.normal_(0, math.sqrt(2)/math.sqrt(param.shape[1]))
        else:
            print( "name:{}".format( name ))
            print( "shape of param:{}".format( param.shape ))
            print("other")
            param.data.normal_(0, 1/math.sqrt(param.shape[1]))       
            
kaiming_init(model)

初期化して学習させてみた。

初期化して、JSUT ver 1.1 BASIC 5000 発話を使って、学習させてみました。19/50 epochs 目の予測データです。train target が訓練用教師データ、train predec が訓練用データの予測文章、val target が評価用教師データ、 val predec が評価用データの予測文章です。訓練用データについては学習の兆候がみられますが、評価用データの方はほとんど学習の効果がみられていません。

train: 100%
612/612 [3:24:28<00:00, 18.55s/it, loss=0.315]
train target:布を斜めに裁ちなさい
train predec:列�詜に裁なさいが

train target:豹はその斑点を変えることはできない
train predec:羹�その斑点を変えることはできないない

train target:筆者はそうした風潮を好まない
train predec:羹�はそうした風潮を好まないない

train target:飛行機は瞬く間に見えなくなった
train predec:列�はは瞬く間に見えくなった

train target:彼女もう浮かれちゃってるよ
train predec:�のはもう浮れちゃってるよ�ら

train target:彼女は浮気な女で本当に誰でも相手にする
train predec:硎のもう�のじなのの本当に誰でも相手にするれている

train target:彼女は夫の到着を待ち焦がれています
train predec:週週コもう夫の到着思待ち焦がれています

train target:彼女は恥じらいの色を隠すために顔をそむけた
train predec:薬のは夫�じらいの色本当隠すために顔を思むせてた

val: 100%
12/12 [01:28<00:00, 6.67s/it, val_loss=0.417]
val target:彼女は私を見るや否やわっと泣き出した
val predec:彼はは夫の到�してして�間ををってだ

val target:彼女は三味線による新しいジャズの演奏法を始めた
val predec:彼はは夫�じで��って明手手実�がをいたでた

val target:彼女はいつも床を綺麗に掃いています
val predec:彼女は夫�のまがをのえるにらですせにさ

val target:彼女の伯母は一日中彼の犬の世話をする
val predec:彼女は浮をのはようの思�味達��まま言をいる

val target:彼女のいうことは的外れである
val predec:彼女は浮をは�のちがを�

val target:彼女がこれほど自分勝手なのは嘆かわしい
val predec:彼女は浮化のに行に求ようをことき書

val target:彼は緻密に立てた計画を実行した
val predec:彼はは�のしてのしてたいれて宼していたしてでした

val target:彼は老けて見える
val predec:彼�は�の�ないにできるえる行

val target:彼は大尉以上の者を全員招集した
val predec:彼女は�のではののれて��してこと�付でしたした
Epoch [19/50],loss: 0.31485 acc: 12.09633, val-loss: 0.41671 val-acc: 0.96311  cer: 106.36080 wer: 131.98696 lr: 0.0007744940502801019
6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4