プログラムの目的
OpenAI の Whisper には、30秒以上の音声ファイルを文字起こしする transcribe 関数があります。驚異的なのは、large モデルで 10 分以上の音声ファイルから字幕ファイルを作っても、メモリーで問題が起きないです。fine tuning したモデルに、この機能を使うために、OpenAI の Whisper を fine tuning するプログラムを作成しました。fine tuning させたところ、きちんと学習しているので、情報の共有をお願いします。
ライブラリーの読み込みなど
最初に、ライブラリーの読み込みと GPU の判別。わたくしの開発環境は、GPU がないので、一応、GPU が使えるようにプログラムは書いたつもりですが、動作確認はしていません。もし、ちゃんと動かなかったら、直して使ってください。CPU では動作確認しました。
import torch
import librosa
import whisper
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import japanize_matplotlib
import evaluate
import gc
import spacy
import ginza
from tqdm.notebook import tqdm
#GPUの利用
device = 0 if torch.cuda.is_available() else "cpu"
print(device)
設定
次に、notebook の設定。赤石雅典さんの「最短でわかる PyTorch & 深層プログラミング」を参考にさせていただきました。fit 関数、evaluate_history 関数なども参考にしています。
# warning表示off
import warnings
warnings.simplefilter('ignore')
# デフォルトフォントサイズ変更
plt.rcParams['font.size'] = 14
# デフォルトグラフサイズ変更
plt.rcParams['figure.figsize'] = (6,6)
# デフォルトで方眼表示ON
plt.rcParams['axes.grid'] = True
# numpyの表示桁数設定
np.set_printoptions(suppress=True, precision=5)
データ
次にデータの読み込み、データの配置は、
current dir------P
|
----metadata.csv
|
----data
| p0001.wav
| p0002.wav
| p0003.wav
・・・
です。metadata.csv の中身は、ファイル名の拡張子をのぞいた basename と教師データです。
p0001,教師データの文章1
p0002,教師データの文章2
p0003,教師データの文章3
・・・
のようです。データの読み込みプログラムは、
#データの読み込み
fr = open('./P/metadata.csv', "r", encoding='UTF-8')
datalist = fr.readlines()
sentence_train = []
filename_test = []
sentence_test = []
for i, line1 in enumerate(datalist):
filename = line1.split( ',' )[0]
#音声ファイルの読み込み
filename = "./P/data/" + filename + ".wav"
audio0, _ = librosa.load( filename, sr = 16000)
line2 = line1.split( ',' )[1]
if i < 280: #train データの数に合わせる
#音声データ
audio_train.append( audio0 )
#正解文章データ
sentence_train.append( line2 )
else: # validation データになる
#音声ファイル名データ
filename_test.append( filename )
#正解文章データ
sentence_test.append( line2 )
print( len( audio_train ))
print( len( sentence_train))
print( len( filename_test))
print( len( sentence_test ))
です。ここで、train データの数を指定してください。
モデルの読み込み
次は、whisper モデルと tokenizer の読み込みおよび定数の設定です。
#whisper モデルの読み込み
model = whisper.load_model( 'small' )
#tokenizer の読み込みと定数の設定
whisper_tok = whisper.tokenizer.get_tokenizer(True, task="transcribe", language="ja")
tokenizer = whisper_tok.tokenizer
tokenizer.pad_token = tokenizer.eos_token
eos = tokenizer.eos_token_id
sot = tokenizer("<|startoftranscript|>").input_ids[0]
lang_id = tokenizer("<|ja|>").input_ids[0]
task = tokenizer("<|transcribe|>").input_ids[0]
t_stamp = tokenizer("<|notimestamps|>").input_ids[0]
ここで、デコーダーの入力データの始まりが、sot, lang_id, task, t_stamp ・・・, eosで、教師データが、 lang_id, task, t_stamp, ・・・, eos であるというこは、
のページで認識しました。それまでは、単純に、<sos>
・・・ <eos>
だと思っていました。
元とする whisper のモデルですが、small モデルでうまく学習できることが確認出来たら、large モデルで学習させてみることをお勧めします。修正は、読み込むモデルを large にすることと、メモリーを消費するため、batch_size を小さくすることくらいです。
バッチ
データをバッチに分割します。
# データをバッチに分割する。
batch_size = 2 #学習、バリデーション共通の batch_size メモリー使用量を見ながら調整してください。
input_ids_train = audio_train
#文章データを tokenizer で、token に変換
labels_train = tokenizer(sentence_train).input_ids
#バッチ分割
split = len(input_ids_train) // batch_size
#音声データ train 用
batch_input_ids_train = np.array_split( input_ids_train, split )
print( len(batch_input_ids_train))
print( len(batch_input_ids_train[0]))
print( len(batch_input_ids_train[0][0]))
# 正解 token データ、target, y train 用。
batch_labels_train = np.array_split( labels_train, split )
print( len(batch_labels_train))
print( len(batch_labels_train[0]))
print( len(batch_labels_train[0][0]))
split = len(filename_test) // batch_size
#音声ファイル名データ validation 用
batch_filename_test = np.array_split( filename_test, split )
print( len(batch_filename_test ))
#正解文章データ validation 用
batch_sentence_test = np.array_split( sentence_test, split )
print( len(batch_sentence_test ))
サンプルデータの確認
#サンプルデータの確認
print( len(batch_input_ids_train[10][0] ) )
print( len(batch_labels_train[10][0] ) )
モデルの確認
# ネットワークの概要表示
print(model)
summary の表示
#summary の表示
#!pip install torchsummaryX
#from torchinfo import summary
from torchsummaryX import summary
summary(model=model, x=torch.zeros((batch_size,80,3000)), tokens=torch.zeros((batch_size,448),dtype=torch.long))
学習にかかわる定数などの設定
# GPUの利用
model = model.to(device)
# 学習率
lr = 0.0001
# 損失関数定義
criterion = nn.CrossEntropyLoss(ignore_index=-100)
# 最適化関数定義
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# historyファイルと torch.save ファイルの path を初期化する
history = np.zeros((0, 8))
paths = []
to_pad_to_mel 関数
音声データを 30秒に pad or trim し、log-mel フィルタバンク化する関数。学習用関数 fit で使います。
のページより引用。
#音声データ を 30秒に pad or trim して、mel フィルターバンクにする。
def to_pad_to_mel(array):
"""Static function which:
1. Pads/trims a list of audio arrays to a max length of 30s
2. Computes log-mel filter coefficients from padded/trimmed audio sequences
Inputs:
array: list of audio arrays
Returns:
input_ids: torch.tensor of log-mel filter bank coefficients
"""
padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
input_ids = whisper.log_mel_spectrogram(padded_input)
return input_ids
学習用関数
中で、model( input_ids, tokens=dec_input_ids )としていますが、encoder と decoder が含まれています。
# 学習用関数
def fit(model, optimizer, criterion, num_epochs,
batch_input_ids_train, batch_labels_train,
batch_filename_test, batch_sentence_test,
device, lr, history, pahts):
base_epochs = len(history)
# cer, wer 計算ライブラリーの読み込み
metrics_cer = evaluate.load("cer")
metrics_wer = evaluate.load("wer")
# wer 計算のために分かち書きをするので、分かち書き用。
nlp = spacy.load("ja_ginza")
ginza.set_split_mode(nlp, "C") # CはNEologdの意らしいです
#メモリー確保
gc.collect()
# epoch のループ
for epoch in range(base_epochs, num_epochs+base_epochs):
train_loss = 0
train_acc = 0
val_loss = 0
val_acc = 0
#torch.save するファイルのファイル名。
path = "whisper-torch-fine-tuning-" + (str(epoch+1)) + ".pth"
#print( "path:{}".format(path))
#訓練フェーズ
model.train()
count = 0
#バッチのループ、プログレスバー対応
phar = tqdm( range( len(batch_input_ids_train) ), desc='train' )
for i in phar:
# model へ入力する x データ
input_ids = batch_input_ids_train[i]
# model に対して教師データとなる y, target データ
labels = batch_labels_train[i]
gc.collect()
count += len(labels)
# audio データを melフィルタバンクに変換し、padding する。 (batch_size, 80, 30000)
input_ids = torch.concat([to_pad_to_mel(input_val)[None, :] for input_val in input_ids]).to(device)
#教師データから、decoder の入力に使うデータを作るために label0 をとっておく。
labels0 = labels
#教師データ
labels = [ [lang_id] + [task] + [t_stamp] + lab + [eos] for lab in labels]
#デコーダーへ力
labels2 = [ [sot] + [lang_id] + [task] + [t_stamp] + lab2 +[eos] for lab2 in labels0 ]
# finally, pad the target labels to max_len
label_lengths = [len(lab) for lab in labels]
label2_lengths = [len(lab2) for lab2 in labels2]
max_label_len = max(label_lengths)
max_label2_len = max(label2_lengths)
labels = [np.pad(lab, (0, max_label2_len - lab_len ), 'constant',
constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
#教師データ
labels = torch.tensor( np.array(labels), requires_grad=False, dtype=torch.long).to(device)
dec_input_ids = [np.pad(lab2, (0, max_label2_len - lab2_len ), 'constant',
constant_values=eos) for lab2, lab2_len in zip(labels2, label2_lengths)]
#デコーダーへ入力
dec_input_ids = torch.tensor( np.array(dec_input_ids), requires_grad=False, dtype=torch.long).to(device)
del labels0
del labels2
del label_lengths
del label2_lengths
del max_label_len
del max_label2_len
gc.collect()
# 勾配の初期化
optimizer.zero_grad()
# 予測計算
logits = model(input_ids, tokens=dec_input_ids )
# 損失計算
loss = criterion( logits.view(-1, logits.size(-1)), labels.view(-1) )
train_loss += loss.item()
# 勾配計算
loss.backward()
# パラメータ修正
optimizer.step()
# 予測値算出
predicted = torch.max(logits, 2)[1]
# 正解件数算出
train_acc += (predicted == labels).sum()
del logits
del loss
del predicted
gc.collect()
# 損失と精度の計算
avg_train_loss = train_loss / count
avg_train_acc = train_acc / count
#プログレスバーに loss 表示
phar.set_postfix( loss = avg_train_loss )
#予測フェーズ
model.eval()
count = 0
step = 0
cer_sum2 = 0
wer_sum2 = 0
#バッチのループ、プログレスバー対応
phar = tqdm( range( len(batch_filename_test) ), desc='val' )
for i in phar:
#print( "i:{}".format( i ))
#音声データのファイル名,x
filenames = batch_filename_test[i]
#正解データ。target, y
labels = batch_sentence_test[i]
gc.collect()
count += len(labels)
step += 1
gc.collect()
# 予測計算 本番用の model.transcribe 関数を用いて予測するため、val_loss と val_acc は計算しない。
pred_str = []
for filename in filenames:
result = model.transcribe(filename, language="ja", task="transcribe")
pred_str.append( result['text'] )
label_str = labels
# 最初のバッチだけ、例として正解文章と予測した文章を表示。
if step == 1:
for j, _ in enumerate( label_str ):
print( "target:{}".format( label_str[j] ))
print( "predec:{}\n".format( pred_str[j] ))
# 損失計算はなし
#予測値算出もなし
#正解件数算出もなし
# cer 算出
cer = 100 * metrics_cer.compute(predictions=pred_str, references=label_str)
cer_sum2 += cer
# wer 算出
# 分かち書きして空白区切りに変換
wer_pred_str = [" ".join([ str(i) for i in nlp(j) ]) for j in pred_str]
wer_label_str = [" ".join([ str(i) for i in nlp(j) ]) for j in label_str]
wer = 100 * metrics_wer.compute(predictions=wer_pred_str, references=wer_label_str)
wer_sum2 += wer
val_loss = 0.0
val_acc = 0.0
gc.collect()
# 損失と精度の計算ダミー
avg_val_loss = val_loss / count
avg_val_acc = val_acc / count
# cer と wer の平均値計算。
avg_cer = cer_sum2 / step
avc_wer = wer_sum2 / step
#プログレスバーに cer 表示
phar.set_postfix( cer = avg_cer )
avg_cer = cer_sum2 / (step )
#print("avg_cer:{}".format( avg_cer ) )
avg_wer = wer_sum2 / (step )
#print("avg_wer:{}".format( avg_wer ))
print (f'Epoch [{(epoch+1)}/{num_epochs+base_epochs}],loss: {avg_train_loss:.5f} acc: {avg_train_acc:.5f} cer: {avg_cer:.5f} wer: {avg_wer:.5f} lr: {lr:}')
item = np.array([epoch+1, avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc, avg_cer, avg_wer, lr])
history = np.vstack((history, item))
paths.append( path )
# modelを torch.save する。
#epoch が 1以上で前の epoch より cer が小さければ save する。
if epoch >=1:
if history[epoch,5] < history[epoch-1,5]:
torch.save( model, path )
print( "model {} is saved.".format(path))
#eoich が 0 の時は無条件で save
elif epoch == 0:
torch.save( model, path )
print( "model {} is saved.".format(path))
# lr を 1/5 にする。 epoch が 2 以上で、二回続けて cer が大きくなったら 1/5 にする。
# lr が 1e-8 より小さくなる場合は、1/5 にしない。
if epoch >= 2:
if history[epoch,5] > history[epoch-1,5] and history[epoch-1,5] > history[epoch-2,5]:
if lr > 5e-8:
lr = lr / 5.0
print( "More than 2 times cer increases, lr = lr / 5.0 = {}".format( lr ))
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optimizer.param_groups[0]['lr'] = lr
else:
lr = lr
print( "More than 2 times cer increases, but lr =< 5e-8, so lr not change")
# Early Stopping epoch が 3以上で、三回連続前の cer より大きくなったら。
if epoch >= 3:
if history[epoch,5] >= history[epoch-1,5] and history[epoch-1,5] >= history[epoch-2,5] and history[epoch-2,5] >= history[epoch-3,5]:
print( "Early Stopping. More than 3 times cer increases")
break
return history, pahts
学習用関数では、val-loss と val-acc は、計算しないので、0 です。validation は、transcribe 関数を用いて、文章を予測し、cer と wer を計算しています。
学習ログ解析
# 学習ログ解析
def evaluate_history(history):
#損失と精度の確認
print(f'初期状態: 損失: {history[0,3]:.5f} 精度: {history[0,4]:.5f} cer: {history[0,5]:.5f} wer: {history[0,6]:.5f}')
print(f'最終状態: 損失: {history[-1,3]:.5f} 精度: {history[-1,4]:.5f} cer: {history[-1,5]:.5f} wer: {history[-1,6]:.5f}' )
num_epochs = len(history)
if num_epochs < 10:
unit = 1
else:
unit = num_epochs // 10
# 学習曲線の表示 (損失)
plt.figure(figsize=(9,8))
plt.plot(history[:,0], history[:,1], 'b', label='訓練')
#plt.plot(history[:,0], history[:,3], 'k', label='検証')
plt.xticks(np.arange(0,num_epochs+1, unit))
plt.xlabel('繰り返し回数')
plt.ylabel('損失')
plt.title('学習曲線(損失)')
plt.legend()
plt.show()
# 学習曲線の表示 (精度)
plt.figure(figsize=(9,8))
plt.plot(history[:,0], history[:,2], 'b', label='訓練')
#plt.plot(history[:,0], history[:,4], 'k', label='検証')
plt.xticks(np.arange(0,num_epochs+1,unit))
plt.xlabel('繰り返し回数')
plt.ylabel('精度')
plt.title('学習曲線(精度)')
plt.legend()
plt.show()
# cer, wer
plt.figure(figsize=(9,8))
plt.plot(history[:,0], history[:,5], 'b', label='cer')
plt.plot(history[:,0], history[:,6], 'k', label='wer')
plt.xticks(np.arange(0,num_epochs+1,unit))
plt.xlabel('繰り返し回数')
plt.ylabel('cer,wer %')
plt.title('cer,wer')
plt.legend()
plt.show()
メモリーの確保
#del audio_train
#del labels_train
#del filename_test
#del sentence_test
gc.collect()
学習
# 学習
num_epochs = 20
history, paths = fit(model, optimizer, criterion, num_epochs,
batch_input_ids_train, batch_labels_train,
batch_filename_test, batch_sentence_test,
device, lr, history, paths)
cer 最小のモデル
history と paths から、cer が最小の model を読みこむ。
#hisotry より最小の cer を探す。
min_cer = 1e10
for i in range( len(history)):
if min_cer > history[i,5]:
min_cer = history[i,5]
print( "最小の cer は:{}".format(min_cer))
#最小の cer の時に save されたファイル名を探す。
i_atari = 0
filename = ''
for i in range( len(history)):
if min_cer == history[i,5]:
i_atari = i
filename = paths[i]
beak
#記録する。
f = open('kiroku.txt', 'w', encoding='UTF-8')
str1 = "その時のエポックは:{}".format( i_atari + 1 )
print(str1)
f.write( str1 + "\n" )
str_history = str( history )
f.write( str_history + "\n" )
str_paths = str( paths )
f.write( str_paths + "\n" )
f.close()
#最小の cer の model を読み込む。
model = torch.load(filename)
print("モデル{}を読み込みました".format( filename ) )
モデルの保存
torch.save(model, 'model_weight.pth')
model.state_dict() の保存
PATH = "whisper-torch-fine-tuning.pt"
torch.save({ 'model_state_dict': model.state_dict(),}, PATH)
結果サマリー
# 結果サマリー
evaluate_history(history)
30秒以上のファイルの文字起こし
result = model.transcribe("/path/to/wav_file", language="ja", task="transcribe", verbose=True)
#print( result["text"] )
です。
評価
合同朝礼で話をする特定の人物の音声データと字幕から作った教師データで学習させました。train 280件、validation 20 件で、validation data に対する cer=16.98, wer=21.21(epoch=1) だったのが cer = 14.51, wer = 18.13 ( epoch =6 ) になりました。10 % 以上の精度向上です。
ちなみに、fine tuning したモデルで、jsut-ver.1.1 の BASIC5000_0003.wav を音声認識すると、
[00:00.000 --> 00:03.740] 上院議員は、私がデータを歪めたと告発した。
のようになりました。
Fine Tuning ではなく、whisper を初期化して学習させるには。
model = whisper.load_model( 'tiny' ) したあとに、次のプログラムを実行すると、whisper のモデルパラメーターを初期化します。tiny model で動作確認済。
# モデルの重みとバイアス(パラメータ)の初期化
def kaiming_init(model):
""" Kaimingの初期化
Args:
model (object): モデル
"""
i = 1
for name, param in model.named_parameters():
print("i:{}".format( i ))
i += 1
if name.endswith(".bias"):
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print( "bias")
param.data.fill_(0)
elif name.startswith("encoder.conv"):
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print( "encoder.conv")
param.data.normal_(0, 1/math.sqrt(param.shape[1]))
#param.data.normal_(0, math.sqrt(2)/math.sqrt(param.shape[1]))
elif name.endswith("ln.weight"):
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print("ln.weight")
param.data.fill_(1)
elif name.endswith("ln_post.weight"):
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print( "ln_post.weight")
param.data.fill_(1)
elif name.startswith("layers.0"):
# The first layer does not have ReLU applied on its input
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print( "layers.0")
param.data.normal_(0, 1/math.sqrt(param.shape[1]))
elif name.endswith(".mlp.0.weight"):
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print(".mlp.0.weight" )
param.data.normal_(0, math.sqrt(2)/math.sqrt(param.shape[1]))
else:
print( "name:{}".format( name ))
print( "shape of param:{}".format( param.shape ))
print("other")
param.data.normal_(0, 1/math.sqrt(param.shape[1]))
kaiming_init(model)
初期化して学習させてみた。
初期化して、JSUT ver 1.1 BASIC 5000 発話を使って、学習させてみました。19/50 epochs 目の予測データです。train target が訓練用教師データ、train predec が訓練用データの予測文章、val target が評価用教師データ、 val predec が評価用データの予測文章です。訓練用データについては学習の兆候がみられますが、評価用データの方はほとんど学習の効果がみられていません。
train: 100%
612/612 [3:24:28<00:00, 18.55s/it, loss=0.315]
train target:布を斜めに裁ちなさい
train predec:列�詜に裁なさいが
train target:豹はその斑点を変えることはできない
train predec:羹�その斑点を変えることはできないない
train target:筆者はそうした風潮を好まない
train predec:羹�はそうした風潮を好まないない
train target:飛行機は瞬く間に見えなくなった
train predec:列�はは瞬く間に見えくなった
train target:彼女もう浮かれちゃってるよ
train predec:�のはもう浮れちゃってるよ�ら
train target:彼女は浮気な女で本当に誰でも相手にする
train predec:硎のもう�のじなのの本当に誰でも相手にするれている
train target:彼女は夫の到着を待ち焦がれています
train predec:週週コもう夫の到着思待ち焦がれています
train target:彼女は恥じらいの色を隠すために顔をそむけた
train predec:薬のは夫�じらいの色本当隠すために顔を思むせてた
val: 100%
12/12 [01:28<00:00, 6.67s/it, val_loss=0.417]
val target:彼女は私を見るや否やわっと泣き出した
val predec:彼はは夫の到�してして�間ををってだ
val target:彼女は三味線による新しいジャズの演奏法を始めた
val predec:彼はは夫�じで��って明手手実�がをいたでた
val target:彼女はいつも床を綺麗に掃いています
val predec:彼女は夫�のまがをのえるにらですせにさ
val target:彼女の伯母は一日中彼の犬の世話をする
val predec:彼女は浮をのはようの思�味達��まま言をいる
val target:彼女のいうことは的外れである
val predec:彼女は浮をは�のちがを�
val target:彼女がこれほど自分勝手なのは嘆かわしい
val predec:彼女は浮化のに行に求ようをことき書
val target:彼は緻密に立てた計画を実行した
val predec:彼はは�のしてのしてたいれて宼していたしてでした
val target:彼は老けて見える
val predec:彼�は�の�ないにできるえる行
val target:彼は大尉以上の者を全員招集した
val predec:彼女は�のではののれて��してこと�付でしたした
Epoch [19/50],loss: 0.31485 acc: 12.09633, val-loss: 0.41671 val-acc: 0.96311 cer: 106.36080 wer: 131.98696 lr: 0.0007744940502801019