10
9

More than 1 year has passed since last update.

【試行錯誤】OpenAI Whisperを活用した日本語歌詞のforced-alignment その1:下調べとワンパス

Last updated at Posted at 2022-10-12

概要

whisperwav2vecを使った音素レベルのaudio-lyric-alignmentの方法を試行錯誤します。とりあえず下調べして、なんとなくの方針を決めて、精度はともかく、それっぽいalignmentの結果が出てくることを目指します。

シリーズ一覧は以下
【試行錯誤】OpenAI Whisperを活用した日本語歌詞のforced-alignment リンクまとめ

背景

forced-alignment

forced-alignmentはある音声(発話)とその書き起こしテキストにおいて、テキストの各音素が、音声のどの位置(時刻)に対応するかを推定する技術です。
https://linguistics.berkeley.edu/plab/guestwiki/index.php?title=Forced_alignment

この技術の応用先として、動画への自動字幕付与などがあります。
forced-alignmentのうち、歌唱音声と歌詞の時刻対応付けはaudio-lyric alignmentなどと呼ばれ、歌唱でない一般的な対話音声のforced-alignmentよりも難易度が高いです。一方で、カラオケ動画の自動作成など、audio-lyric alignmentの応用先も多く、実用化が望まれます。

日本語のaudio-lyric alignmentを手軽に使えるアプリ等は現状では存在しない一方、後述のwhisperの登場により、あまり複雑なことをせずともかなりいい線まで行けるのではないかという思いが芽生えたため、試してみることにしました。

音声認識モデルwhisper

whisperはopenaiが開発した2022年10月時点で最高クラスの性能をもつ音声認識モデルであり、日本語を含む多言語に対応していること、オープンソース(MITライセンス)であることなどが特徴です。
whisperはBGMの存在などの理由により、単なる発話音声よりも難易度が高い歌唱音声の認識も可能であることが、デモで示唆されています。
そこでwhisperを活用することで、高い精度での日本語audio-lyric alignmentができる可能性があります。

軽い調査

まず、whisperでは、segment(公式の表現だとphrase?)レベルでのタイムスタンプ取得できているようです。segmentというのは文章とかのレベルの単位なので、対話動画の字幕付けなどには有用ですが、forced-alignmentの細かさには足りません。

「whisper forced-alignment」で検索してみると、同じことを考えている人がまあまあいるようです。
https://github.com/openai/whisper/discussions/52
https://github.com/openai/whisper/discussions/3
翻訳に自信がありませんが

みたいなことが(たぶん)言われていました。
ということでwhisper単体だとword-levelまでのようです。
word-levelで目的が達成できる場合には、上記のdisucssion等を探っていくと良いと思います。
今回は音素レベルのalignmentがしたいので、もう少し、掘ってみます。

音素レベルのforced-alignmentをしたい場合には、facebookのwav2vecと組み合わせるという方法があるようです。
wav2vec参考:論文日本語解説

wav2vectorchaudioから呼び出して、forced-alignmentする方法はpytorchのチュートリアルで解説されています。

このチュートリアルさえあれば、whisperが不要(歌詞は別途取得できている前提なので)に思えるのですが、実際に実行してみると10秒前後の長さの音声ファイルだと、(少し工夫すれば日本語でも)ある程度正確にできるものの、1曲まるごとの長さになると、処理がタイムアウトしてしまいます。
そこで、whisperを使ってphraseレベルの(ゆるい)alignmentを実行してから、phraseごとに音声ファイルを区切って、wav2vecとtorchaudioでforced-alignmentする、という方法が筋が良さそうです。
(phraseレベルのalignmentのためだけにwhisperは若干オーバーキルかもしれませんが、他の方法がすぐに思いつく方法が意外とないので、whisperに頼ります)

ということで更に探してみると、Lhosteというライブラリが数日前にwhisperに対応し、whisperとwav2vecを組み合わせたword-level alignmentの機能を提供してくれているようです。
https://stackoverflow.com/questions/73822353/how-can-i-get-word-level-timestamps-in-openais-whisper-asr

なのでこちらも最新版(v1.9.0)をインストールして手元の環境(Colab)で試してみたのですが、エラーが出てうまく行かなかったのと、phoneme-levelのalignmentをするための改良がどのくらいやりやすいか未知数だったので、一旦採用を見送りました。

もう少しカプセル化されていない書き方で上記のアイデア(whisperでphraseに分けてからtorcuaudioとwav2vecで処理)を実装してくれている人がいたので、こっちを参考に実装を試みます。
https://github.com/johnafish/whisperer

実装(試行錯誤)

whisperによる日本語歌詞認識

まずwhisperによる日本語歌詞の認識精度が悪いとすべての前提が崩れるので、試してみます。

準備

Google Colabで試します(ときどきローカルでも作業しています)。ライブラリのインストールと性能の異なる3種類のモデルをロードしておきます。

!pip install git+https://github.com/openai/whisper.git
import whisper
import json
basemodel = whisper.load_model("base")
mediummodel = whisper.load_model("medium")
largemodel = whisper.load_model("large")

音声認識結果を書き出すための関数を作っておきます。

def transcribe(audiofilepath, model, outputjsonpath):
  result = model.transcribe(audiofilepath)
  print(result["text"])
  with open(outputjsonpath, "w") as f:
    json.dump(result, f, ensure_ascii=False, indent=2)
  return result

Character Error Rate(CER)を計算するための関数を作っておきます。
CERは編集距離÷正解文字数として計算できます。
参考:音声認識の精度測定、単語誤り率(WER)と文字誤り率(CER)

# 編集距離計算用のライブラリ
!pip install editdistance
import editdistance
def character_error_rate(correct, pred):
  dist = editdistance.eval(correct, pred)
  return dist / max(len(correct),1)

音源

YOASOBI「夜に駆ける」のwavファイルを使用しました。

結果:認識文字列

音源ファイルをColabにアップロードしたうえで、以下のコードを実行します。

audiofilepath = "yorunikakeru.wav"
correct_text = "<<<正解歌詞>>>"

baseresult = transcribe(audiofilepath, basemodel, "baseresult_{}.json".format(audiofilepath.split(".")[0]))
print("baseresult: {}".format(character_error_rate(correct_text, baseresult["text"])))

mediumresult = transcribe(audiofilepath, mediummodel, "mediumresult_{}.json".format(audiofilepath.split(".")[0]))
print("mediumresult: {}".format(character_error_rate(correct_text, mediumresult["text"])))

transcribe(audiofilepath, largemodel, "largeresult_{}.json".format(audiofilepath.split(".")[0]))
print("largeresult: {}".format(character_error_rate(correct_text, largeresult["text"])))

4分半くらいの音声ファイルに対して、認識にかかった時間はbase, medium, largeでそれぞれ10秒、30秒、1分くらいでした。

CERは以下のような感じでした。

baseresult: 0.19690576652601968
mediumresult: 0.17862165963431786
largeresult: 0.06469760900140648

largeのスコアが圧倒的に良い結果となりました。認識結果を見比べてみると、CERの要因のほとんどはカギカッコの有無など発音と関係ない要因であって、発音に関係する部分はほとんど完璧でした。正直ここまで良いと思っておらず、驚きです。

ただ以下のような不安な点もありました。

  • 何回か試しているとlargeモデルでも、冒頭30秒くらいしか認識されないことがある
  • 末尾に「サブタイトルをご視聴頂きましてありがとうございました!」という謎の発話が認識される(ひょっとしてYouTubeによくある末尾の感謝や挨拶を過学習している?)

transcribe関数のオプション引数はデフォルトを使っているため、ここを細かくチューニングできれば一部解決する部分はあるかもしれません。今後の課題とし、今回はうまく行ったやつを使います。

結果:segmentのタイムスタンプ

認識結果に含まれるsegmentのタイムスタンプが正しそうか検証します。
数値だけ眺めてもわかりにくいので、音楽と同時に字幕として流して検証してみます。

まずffmpegで黒背景で音楽だけが流れる動画を作ります。
以下を参考にさせていただきました。
FFmpegで音声ファイルと画像1枚から動画を作成してみた

ffmpeg \
    -loop 1 \
    -r 30000/1001 \
    -i black_picture.png -i yorunikakeru.wav \
    -vcodec libx264 \
    -acodec aac -strict experimental -ab 320k -ac 2 -ar 48000 \
    -pix_fmt yuv420p \
    -shortest \
    yorunikakeru.mp4

(黒い画像はパワポで適当に作りました)

次にlargeresult_yorunikakeru.jsonからsrtファイルを作ります。
参考:字幕ファイル (.srt) の書き方

import json

# resultファイルの読み込み
PATH = "largeresult_yorunikakeru.json"
with open(PATH) as f:
  result = json.load(f)

# srtファイルの各要素を作成
srtrows = []
for i, seg in enumerate(result["segments"]):
  startmsec, endmsec = int(seg["start"]*1000), int(seg["end"]*1000)
  start_hour = startmsec//3600000
  start_minute = startmsec%3600000 // 60000
  start_second = startmsec%60000 // 1000
  start_msec = startmsec%1000
  end_hour = endmsec//3600000
  end_minute = endmsec%3600000//60000
  end_second = endmsec%60000//1000
  end_msec = endmsec%1000

  #以下みたいな形式にする
  """
  1
  00:00:00,000 --> 00:00:02,000
  ネコ1 ふむ
 
  2
  00:00:02,000 --> 00:00:04,000
  ネコ2 むー
  """
  row = "{}\n{}:{}:{},{} --> {}:{}:{},{}\n{}".format(i+1, str(start_hour).zfill(2), str(start_minute).zfill(2), str(start_second).zfill(2), str(start_msec).zfill(3), str(end_hour).zfill(2), str(end_minute).zfill(2), str(end_second).zfill(2), str(end_msec).zfill(3), seg["text"])
  srtrows.append(row)

with open("phrases.srt", "w") as f:
  f.write("\n\n".join(srtrows))
phrases.srt
1
00:00:00,000 --> 00:00:06,000
沈むように溶けてゆくように

2
00:00:06,000 --> 00:00:28,000
二人だけの空が広がる夜に

3
00:00:28,000 --> 00:00:38,000
さよならだけだったその一言で全てが分かった
...

srtファイルを読み込み可能な動画プレイヤー(VLCメディアプレイヤーなど)で字幕の表示タイミングを確認します。

あるいは、ffmpegでsrtファイルに基づく字幕付きの動画を作ってもよいです。ただ時間がかかります。
参考:FFmpegを使って動画に字幕を焼く

ffmpeg -i yorunikakeru.mp4 -vf subtitles=phrases.srt out.mp4

この方法で確認をしてみると、タイムスタンプが割と不安定なことに気づきます。

  • BGMも発話区間に含まれてしまう(逆に発話区間を非発話区間とすることはあまりなさそうでそこは救い)。
  • transcribeの出力のタイムスタンプの精度は1秒単位で荒い。歌詞が連続した部分を考えると0.1秒の精度はほしい。

ということでこの出力をそのまま使うには不安が残ります。
BGMが含まれてしまう問題は手動で修正するくらいしか対応が思いつきません。
タイムスタンプの精度の粗さについては、次段階でsegmentを前後に1秒ずつ長めに切り出すことで対策になるでしょうか? とりあえず先に進みます。

wav2vecによるforced-alignment

whisperで検出した各segmentに対してforced-alignmentをします。
コードはtorchaudioのチュートリアル先駆者の方を参考にしつつ、以下の変更を加えます。

  • チュートリアルでは、tokenレベル、音素レベル、単語レベルの順にmergeを行っているため、音素レベルのmerge部分(merge_repeats関数)までを使う。
  • 入力のサンプリングレートをwav2vecのフレームレートに合わせる(resampleする)
  • WAV2VEC2が日本語対応していないため、歌詞をローマ字に変換して入力する
  • segmentのタイムスタンプを前後に1秒長くとって入力する(これは試したらやらないほうがマシでした)

ライブラリをインストールします。

!pip install romkan
!pip install mecab-python3 unidic-lite
!pip install pydub num2words

歌詞をローマ字に変換します。

import MeCab
import romkan
import json

#whisperの認識結果情報を取得
with open("largeresult_yorunikakeru.json") as f:
  result = json.load(f)

#漢字かな交じりをカタカナに変換
mecab = MeCab.Tagger()
def get_yomi(text):
  tokens = mecab.parse(text).splitlines()[:-1]
  yomi = ""
  for token in tokens:
    yomi += token.split("\t")[1]
  return yomi
#get_yomi("吾輩は猫である")

# 入力テキストをローマ字に変換
segments = []
for seg in result["segments"]:
  text = seg["text"]
  yomi = get_yomi(seg["text"]) #カタカナに直したもの
  roma = romkan.to_roma(yomi) #ローマ字に直したもの
  romasep = romkan.to_roma("".join("|".join(list(yomi.replace("",""))))) #あとでカナ単位に復元しやすいようにローマ字の間にseparatorをいれたもの
  obj = {
      "start": seg["start"]
      , "end": seg["end"]
      #"start": max(seg["start"] - 1,0) #精度の粗さを考慮して区間を長めにとる。やらないほうがマシ
      #, "end": seg["end"]+1 #精度の粗さを考慮して区間を長めに取る。やらないほうがマシ
      , "text": text
      , "yomi": yomi
      , "roma": roma
      , "romasep": romasep
  }
  segments.append(obj)
print(segments)
[{'start': 0.0, 'end': 6.0, 'text': '沈むように溶けてゆくように', 'yomi': 'シズムヨーニトケテユクヨーニ', 'roma': 'shizumuyo-nitoketeyukuyo-ni', 'romasep': 'shi|zu|mu|yo|ni|to|ke|te|yu|ku|yo|ni'}, ...

forced-alignmentの関数を定義して実行します(参考)。

import tensorflow as tf
import torch
import torchaudio
from datetime import timedelta
from dataclasses import dataclass
from srt import Subtitle, compose
from pydub import AudioSegment
import re
import num2words

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.cuda.empty_cache()

torch.random.manual_seed(0)

def force_align(SPEECH_FILE, transcript, start_index, start_time):
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model().to(device)
    labels = bundle.get_labels()
    with torch.inference_mode():
        waveform, waveform_sample_rate = torchaudio.load(SPEECH_FILE)
        # waveformのサンプルレートをbundleに合わせる
        waveform = torchaudio.transforms.Resample(waveform_sample_rate, bundle.sample_rate)(waveform)
        waveform_sample_rate = bundle.sample_rate
        emissions, _ = model(waveform.to(device))
        emissions = torch.log_softmax(emissions, dim=-1)

    emission = emissions[0].cpu().detach()

    dictionary = {c: i for i, c in enumerate(labels)}

    tokens = [dictionary[c] for c in transcript]

    def get_trellis(emission, tokens, blank_id=0):
        num_frame = emission.size(0)
        num_tokens = len(tokens)

        # Trellis has extra diemsions for both time axis and tokens.
        # The extra dim for tokens represents <SoS> (start-of-sentence)
        # The extra dim for time axis is for simplification of the code.
        trellis = torch.empty((num_frame + 1, num_tokens + 1))
        trellis[0, 0] = 0
        trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
        trellis[0, -num_tokens:] = -float("inf")
        trellis[-num_tokens:, 0] = float("inf")

        for t in range(num_frame):
            trellis[t + 1, 1:] = torch.maximum(
                # Score for staying at the same token
                trellis[t, 1:] + emission[t, blank_id],
                # Score for changing to the next token
                trellis[t, :-1] + emission[t, tokens],
            )
        return trellis


    trellis = get_trellis(emission, tokens)

    @dataclass
    class Point:
        token_index: int
        time_index: int
        score: float


    def backtrack(trellis, emission, tokens, blank_id=0):
        # Note:
        # j and t are indices for trellis, which has extra dimensions
        # for time and tokens at the beginning.
        # When referring to time frame index `T` in trellis,
        # the corresponding index in emission is `T-1`.
        # Similarly, when referring to token index `J` in trellis,
        # the corresponding index in transcript is `J-1`.
        j = trellis.size(1) - 1
        t_start = torch.argmax(trellis[:, j]).item()

        path = []
        for t in range(t_start, 0, -1):
            # 1. Figure out if the current position was stay or change
            # Note (again):
            # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
            # Score for token staying the same from time frame J-1 to T.
            stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
            # Score for token changing from C-1 at T-1 to J at T.
            changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            # 2. Store the path with frame-wise probability.
            prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
            # Return token index and time index in non-trellis coordinate.
            path.append(Point(j - 1, t - 1, prob))

            # 3. Update the token
            if changed > stayed:
                j -= 1
                if j == 0:
                    break
        else:
            raise ValueError("Failed to align")
        return path[::-1]


    path = backtrack(trellis, emission, tokens)

    # Merge the labels
    @dataclass
    class Segment:
        label: str
        start: int
        end: int
        score: float

        def __repr__(self):
            return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

        @property
        def length(self):
            return self.end - self.start


    def merge_repeats(path):
        i1, i2 = 0, 0
        segments = []
        while i1 < len(path):
            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                i2 += 1
            score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
            segments.append(
                Segment(
                    transcript[path[i1].token_index],
                    path[i1].time_index,
                    path[i2 - 1].time_index + 1,
                    score,
                )
            )
            i1 = i2
        return segments

    # wordレベルのalignmentは不要なのでmerge_words関数は削除

    segments = merge_repeats(path)
    subs = []
    for i,word in enumerate(segments):
        ratio = waveform.size(1) / (trellis.size(0) - 1)
        x0 = int(ratio * word.start)
        x1 = int(ratio * word.end)
        # wav2vec(bundle)ではなく入力のsample_rateで割る
        #start = timedelta(seconds=start_time + x0 / bundle.sample_rate)
        #end = timedelta(seconds=start_time + x1 / bundle.sample_rate )
        start = timedelta(seconds=start_time + x0 / waveform_sample_rate) 
        end = timedelta(seconds=start_time + x1 / waveform_sample_rate )
        # subsの要素はstr形式ではなく、dict形式にする。この後関数外でmergeするので
        #subtitle = Subtitle(start_index+i, start, end, word.label)
        #subs.append(subtitle) 
        subs.append({
            "text": word.label
            , "start": start
            , "end":end
        })
    return subs

SPEECH_FILE = "yorunikakeru.wav"



print("Starting to force alignment...")
start_index = 0
total_subs = []
for i,segment in enumerate(segments):
    #text = segment["text"]
    text = segment["romasep"]
    audioSegment = AudioSegment.from_wav(SPEECH_FILE)[segment["start"]*1000:segment["end"]*1000]
    audioSegment.export(str(i)+'.wav', format="wav") #Exports to a wav file in the current path.
    transcript=text.strip().replace(" ", "|")
    transcript = re.sub(r'[^\w|\s]', '', transcript)
    transcript = re.sub(r"(\d+)", lambda x: num2words.num2words(int(x.group(0))), transcript)
    print(segment["start"])
    subs = force_align(str(i)+'.wav', transcript.upper(), start_index, segment["start"])
    # 末尾にseparatorの音素を追加する(あとでカナ単位にスプリットするため)
    subs.append({
        "text": "|", "start": subs[-1]["end"], "end":subs[-1]["end"]
    })
    start_index += len(segment["text"])
    total_subs.extend(subs)
print(total_subs)
# この時点ではファイル出力しないのでコメントアウト
#CAPTION_FILE = open("caption.srt", "w")
#CAPTION_FILE.write(compose(total_subs))
#CAPTION_FILE.close()

total_subsの中身を確認します。

print(total_subs[0])
{'text': 'S', 'start': datetime.timedelta(seconds=1, microseconds=585250), 'end': datetime.timedelta(seconds=1, microseconds=625375)}

タイムスタンプがdatetime.timedelta型になっているので、int(秒)に変更します。

phonemes = []
for sub in total_subs:
  start = sub["start"].total_seconds()
  end = sub["end"].total_seconds()
  phonemes.append({
      "text": sub["text"]
      ,"start":start
      , "end": end
  })
print(phonemes[0])
{'text': 'S', 'start': 1.58525, 'end': 1.625375}

確認のために、音素をカタカナの単位でまとめます。

chars = []
text, start, end = "", -1, -1
for i,p in enumerate(phonemes):
  if text == "" and p["text"] != "|":
    text, start, end = p["text"], p["start"], p["end"]
    continue
  if p["text"] == "|" or i == len(phonemes)-1:
    chars.append({
        "text":text
        , "kana": romkan.to_katakana(text)
        , "start": start
        , "end":end
    })
    text, start, end = "", -1, -1
    continue
  text += p["text"]
  end = p["end"]

print(chars[0])
{'text': 'SHI', 'kana': 'シ', 'start': 1.58525, 'end': 1.806}

カナの単位でsrtファイルを作成します。

# srtファイルの各要素を作成
srtrows = []
for i, seg in enumerate(chars):
  startmsec, endmsec = int(seg["start"]*1000), int(seg["end"]*1000)
  start_hour = startmsec//3600000
  start_minute = startmsec%3600000 // 60000
  start_second = startmsec%60000 // 1000
  start_msec = startmsec%1000
  end_hour = endmsec//3600000
  end_minute = endmsec%3600000//60000
  end_second = endmsec%60000//1000
  end_msec = endmsec%1000

  #以下みたいな形式にする
  """
  1
  00:00:00,000 --> 00:00:02,000
  ネコ1 ふむ
 
  2
  00:00:02,000 --> 00:00:04,000
  ネコ2 むー
  """
  row = "{}\n{}:{}:{},{} --> {}:{}:{},{}\n{}".format(i+1, str(start_hour).zfill(2), str(start_minute).zfill(2), str(start_second).zfill(2), str(start_msec).zfill(3), str(end_hour).zfill(2), str(end_minute).zfill(2), str(end_second).zfill(2), str(end_msec).zfill(3), seg["kana"])
  srtrows.append(row)

with open("chars.srt", "w") as f:
  f.write("\n\n".join(srtrows))

最後にchars.srtを動画プレイヤーで読み込んで、alignmentの精度を目視で確認します。
主観ですが以下のような結果でした。

  • うまくいくsegmentといかないsegmentがある。全体的な精度は5割くらい?
  • segmentのタイムスタンプ検出に失敗している(広く取りすぎている)区間は、音素alignmentも失敗している。BGMだけの区間にもまんべんなくalignmentされてしまっている。

おわりに

精度はまだまだですが、一応forced-alignmentらしきものができたので、ここで区切りとします。
今後やってみたいことは以下です。歌唱区間検出の難しさと、wav2vecによる音素アライメントの難しさがネックになっていそうなので、そのあたりを改善できればと思っています。

【歌唱区間検出の精度を上げる】

  • whisperをfine-tuningする(参考)
    • 教師データを集めるのが大変そうだけど音楽データでfine-tuningできると歌唱区間の精度が上る可能性はある? でもJPOPもwhisperの学習データに入っていそうだからそう考えるとfine-tuningの意味は薄い? よくわからない。
  • BGMの除去を試みる
    • おそらくlyric-to-audio alignment的にはこれが正しいアプローチ?
  • 歌唱区間検出用のライブラリを探す(whisperを諦める)

【音素アライメントの制度を上げる】

  • wav2vecを日本語にfine-tuningする(参考
    • データは日本語の非歌唱発話音声でとりあえずやってみる。だめならwhisperを駆使して歌唱音声の教師データを作ってやってみる? 意味があるかは不明。
  • whisperでwordレベルのalignmentまでやってからwav2vecに入力する
    • 参考1参考2:whisperによるwordレベルalignmentの実装例(不安定らしい)
    • 参考3:タイムスタンプを取り出せるとは書いていないがこっちも参考になりそう?
  • whisperの推論出力をforced-alignmentの入力にする
    • できるかわからないけど、force_align関数のemissionに相当する行列をwhisperの推論出力から作れたらよさそう。

【その他】

  • 精度検証のため手動alignmentしたjpopを用意したい。
  • transcribeのオプションについて調べる(オプションの指定で一部改善することがひょっとしたらあるかも?)

前途多難ですが、時間を見つけて引き続きトライしてみたいと思います。

10
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
9