4
6

More than 1 year has passed since last update.

【試行錯誤】OpenAI Whisperを活用した日本語歌詞のforced-alignment その7:最初から最後まで処理をつなげる

Last updated at Posted at 2022-10-31

概要

「その6」までの知見を結集して、音源分離からalignmentまでの処理をワンパス通してみました。

シリーズ一覧は以下
【試行錯誤】OpenAI Whisperを活用した日本語歌詞のforced-alignment リンクまとめ

背景

「その6」までで以下の要素実装に取り組んできました。

  • meducsによるボーカル抽出
  • Whisperによる歌詞認識
  • inaSpeechSegmenterによるWhisperタイムスタンプの補正
  • 認識歌詞と正解歌詞の対応付け
  • forced-alignment

今回はこれらの処理をすべてつなげてforced-alignmentのワンパスを通してみます。

方針

各処理には時間がかかるものがあったり、人による微修正を加えたくなる場合があるので、それぞれの段階での出力をファイル出力しつつ、再利用できるような構成にします。

入力はオーディオファイルのパスを最低限必要とするものとし、正解歌詞が別途ある場合はそれも使えるようにします。正解歌詞がない場合はWhisperの認識結果を用います。

出力は、モウラ単位の歌詞の始点・終点のタイムスタンプとします。また字幕化するときに便利なように各モウラと対応する表層系の単語もわかるようにしておきます。

meducsによるボーカル抽出

「その2」ではコマンドラインから実行しましたが、ワンパスに組み込むためにコード内で実行できるようにします。

プロセス起動により実行するコードを書いてくださっているかたがいたので、ほぼそのままお借りしています。以下の点を少し変えています。

  • 他の処理特別しやすいようにクラスメソッド化
  • オプションをseparateの引数で指定できるように
  • separateの戻り値として出力先のフォルダパスを返す
#@title Useful functions, don't forget to execute
from genericpath import isfile
import io
from pathlib import Path
import os
import select
from shutil import rmtree
import subprocess as sp
import sys
from typing import Dict, Tuple, Optional, IO

class Meducs:
    # Customize the following options!
    model = "mdx_extra_q"
    extensions = ["mp3", "wav", "ogg", "flac"]  # we will look for all those file types.
    #two_stems = None   # only separate one stems from the rest, for instance
    two_stems = "vocals"

    # Options for the output audio.
    mp3 = False
    mp3_rate = 320
    float32 = False  # output as float 32 wavs, unsused if 'mp3' is True.
    int24 = False    # output as int24 wavs, unused if 'mp3' is True.
    # You cannot set both `float32 = True` and `int24 = True` !!

    in_path = './demucs/'
    out_path = './demucs_separated/'

    def __init__(self, *
                , model=None
                , extentions=None
                , two_stems=None 
                , mp3=None, mp3_rate=None 
                , float32=None, int24=None 
                , in_path=None, out_path=None
                ):
        self.model = model or self.model
        self.extensions = extentions or self.extensions
        self.two_stems = two_stems or self.two_stems
        self.mp3 = mp3 or self.mp3
        self.mp3_rate = mp3_rate or self.mp3_rate
        self.float32 = float32 or self.float32
        self.in_path = in_path or self.in_path
        self.out_path = out_path or self.out_path

    def find_files(self, in_path):
        out = []
        for file in Path(in_path).iterdir():
            if file.suffix.lower().lstrip(".") in self.extensions:
                out.append(file)
        return out

    @staticmethod
    def copy_process_streams(process: sp.Popen):
        def raw(stream: Optional[IO[bytes]]) -> IO[bytes]:
            assert stream is not None
            if isinstance(stream, io.BufferedIOBase):
                stream = stream.raw
            return stream

        p_stdout, p_stderr = raw(process.stdout), raw(process.stderr)
        stream_by_fd: Dict[int, Tuple[IO[bytes], io.StringIO, IO[str]]] = {
            p_stdout.fileno(): (p_stdout, sys.stdout),
            p_stderr.fileno(): (p_stderr, sys.stderr),
        }
        fds = list(stream_by_fd.keys())

        while fds:
            # `select` syscall will wait until one of the file descriptors has content.
            ready, _, _ = select.select(fds, [], [])
            for fd in ready:
                p_stream, std = stream_by_fd[fd]
                raw_buf = p_stream.read(2 ** 16)
                if not raw_buf:
                    fds.remove(fd)
                    continue
                buf = raw_buf.decode()
                std.write(buf)
                std.flush()

    def separate(self, inp=None, outp=None, *
                , model=None
                , two_stems=None 
                , mp3=None, mp3_rate=None 
                , float32=None, int24=None):
        inp = inp or self.in_path
        outp = outp or self.out_path
        model = model or self.model
        two_stems = two_stems or self.two_stems
        mp3 = mp3 or self.mp3
        mp3_rate = mp3_rate or self.mp3_rate
        float32 = float32 or self.float32
        int24 = int24 or self.int24

        cmd = ["python3", "-m", "demucs.separate", "-o", str(outp), "-n", model]
        if mp3:
            cmd += ["--mp3", f"--mp3-bitrate={mp3_rate}"]
        if float32:
            cmd += ["--float32"]
        if int24:
            cmd += ["--int24"]
        if two_stems is not None:
            cmd += [f"--two-stems={two_stems}"]
        p_inp = Path(inp)
        if p_inp.is_file() and p_inp.exist() and p_inp.suffix.lstrip(".").lower() in self.extensions:
            files = [inp]
        else:
            print(f"No valid audio files in {inp}")
            return
        print("Going to separate the files:")
        print('\n'.join(files))
        print("With command: ", " ".join(cmd))
        #p = sp.Popen(cmd + files, stdout=sp.PIPE, stderr=sp.PIPE)
        #self.copy_process_streams(p)
        #p.wait()
        #if p.returncode != 0:
        #    print("Command failed, something went wrong.")

        return str(Path(outp).joinpath(model, p_inp.stem))

単体で使うときは以下のような感じです。

meducs = Meducs(model="mdx_q")
meducs.separate("audio/yorunikakeru.wav", "output/separated/")

音声認識(Whisper, inaSpeechSegmenter)

「その3」でinaSpeechSegmenterによりWhisperのタイムスタンプを補正するアイデアを実装しました。この処理をWhisperによる音声認識とまとめて実装します。
ポイントは以下です。

  • Whisperのモデルを指定してinit
  • Whisperによる認識とタイムスタンプの補正処理を実装
  • Whisperの認識結果がすでに得られているときはWhisperモデルを使わないことで、ロード時間を短縮できる
import whisper
import inaSpeechSegmenter
import json

class SoundRecognition:
  def __init__(self, model = None
              , *
              , use_whisper = True
              ):

    if use_whisper:
      self.model = model or whisper.model.load()
    else:
      self.model = None

    self.segmenter = inaSpeechSegmenter.Segmenter(vad_engine='sm', detect_gender=False)

  def load_whisper_result(self, whisper_result_path):
    with open(whisper_result_path) as f:
      whisper_result = json.load(f)
    segments = [{k:v[k] for k in ["text", "start", "end"]} for v in whisper_result["segments"]]
    return segments    
    
  def recognize_by_whisper(self, audio_file_path):
    result = self.model.transcribe(audio_file_path)
    # 必要な情報(text, start, end)だけを抽出
    segments = [{k:v[k] for k in ["text", "start", "end"]} for v in result["segments"]]
    return segments

  def recognize_no_energy(self, audio_file_path):
    result = self.segmenter(audio_file_path)
    no_energy = [(start, end) for label, start, end in result if label == "noEnergy"]
    return no_energy
  
  def update_whisper_timestamp(self, whisper_segments, no_energy):
    updated = []

    for whisper_segment in whisper_segments:
      whisper_start, whisper_end = whisper_segment["start"], whisper_segment["end"]
      
      # no_energyがwhisperを包含するとき、updatedに含めない
      no_energy_includes_whisper = False
      for no_energy_start, no_energy_end in no_energy:
        # no_energyがwhisperを包含
        if no_energy_start <= whisper_start and whisper_end <= no_energy_end:
          no_energy_includes_whisper = True
          break
        # no_energyがwhisperより完全に未来
        if whisper_end <= no_energy_start:
          break
      if no_energy_includes_whisper:
        # updatedへの追加をスキップ
        continue
      
      # whisperのstart, endをinaで更新
      updated_seg = whisper_segment.copy()
      for no_energy_start, no_energy_end in no_energy:
        # inaがwhisperより完全に未来
        if whisper_end <= no_energy_start:
          break
        # whisperの区間が完全に前のとき(区間のかぶりがないとき)、continue
        elif whisper_end <= no_energy_start:
          pass
        # ina_endがwhisper_startよりも遅いときwhisper_startをina_endに合わせる
        elif whisper_start <= no_energy_end and no_energy_end <= whisper_end :
          updated_seg["start"] = no_energy_end
        # ina_startがwhisper_endよりも遅いときwhisper_endをina_startに合わせる
        elif no_energy_start <= whisper_end and whisper_start <= no_energy_start:
          updated_seg["end"] = no_energy_start
        # それ以外の関係性は想定しない(スキップ)
        else:
          pass
      updated.append(updated_seg)
    return updated

  def exec(self, audio_path):
    whisper_segments = self.recognize_by_whisper(audio_path)
    no_energy = self.recognize_no_energy(audio_path)
    updated_segments = self.update_whisper_timestamp(whisper_segments, no_energy)
    return updated_segments

単体で実行する場合は以下のような感じ

whisper_model = whisper.load_model("base")
soundrecognition = SoundRecognition(whisper_model)
soundrecognition_result = SoundRecognition.exec("audio/yorunikakeru_vocals.wav")
print(soundrecognition_result)
[
  {
    "text": "沈むように溶けてゆくように",
    "start": 1.1,
    "end": 6.0
  },
  {
    "text": "二人だけの空が広がる夜に",
    "start": 8.46,
    "end": 15.120000000000001
  },
  ...
]

正解歌詞との対応付け

「その6」で実装した正解歌詞と認識結果の対応付けを一連の処理に組み込むため、まずは漢字仮名交じり文を分かち書きしたり、ローマ字に変化するための関数を定義しておきます。

import re
import romkan
import MeCab
import jaconv
import unicodedata
import string

class Phoneme:

  mecab = MeCab.Tagger()

    
  @staticmethod
  def mora_wakachi(kana_text):   
    #各条件を正規表現で表す
    c1 = '[ウクスツヌフムユルグズヅブプヴ][ァィェォ]' #ウ段+「ァ/ィ/ェ/ォ」
    c2 = '[イキシチニヒミリギジヂビピ][ャュェョ]' #イ段(「イ」を除く)+「ャ/ュ/ェ/ョ」
    c3 = '[テデ][ィュ]' #「テ/デ」+「ャ/ィ/ュ/ョ」
    c4 = '[ァ-ヴー]' #カタカナ1文字(長音含む)
    c5 = '[a-zA-Z]+' #念の為アルファベットも抽出できるように

    condition = '('+c1+'|'+c2+'|'+c3+'|'+c4+'|'+c5+')'
    return re.findall(condition, kana_text)

  @staticmethod
  def kana_to_romaji(kana_list):
    romaji_org = [romkan.to_roma(kana) for kana in kana_list]
    romaji_fixed = []
    for i, roma in enumerate(romaji_org):
      # 「ッ」のとき
      if roma == "xtsu":
        # 末尾ならtにする
        if i == len(romaji_org)-1:
          roma = "t"
        # 末尾以外なら次の要素の1文字目にする
        else:
          roma = romaji_org[i+1][0]
      # 「ー」のとき
      elif roma == "-":
        # 先頭ならnにする(なんとなく)
        if i == 0:
          roma = "n"
        # 先頭以外なら直前の要素の母音にする
        else:
          roma = romaji_org[i-1][-1]
      
      romaji_fixed.append(roma)
    return romaji_fixed

  @staticmethod
  def format_text(text):
    text = unicodedata.normalize("NFKC", text)  # 全角記号をざっくり半角へ置換(でも不完全)

    # 記号を消し去るための魔法のテーブル作成
    table = str.maketrans("", "", string.punctuation  + "「」、。・")
    text = text.translate(table)

    return text

  @classmethod
  def get_surface_and_pronunciation(cls, text):
    m_result = cls.mecab.parse(text).splitlines() #mecabの解析結果の取得
    m_result = m_result[:-1] #最後の1行は不要な行なので除く
    
    pronunciations = [] #発音文字列全体を格納する変数
    surfaces = []
    for v in m_result:
      if '\t' not in v: continue
      surface = v.split('\t')[0] #表層形
      pronunciation = v.split('\t')[1].split(',')[-1] #発音を取得したいとき
      #p = v.split('\t')[1].split(',')[-2] #ルビを取得したいとき
      #発音が取得できていないときsurfaceで代用
      if pronunciation == '*': pronunciation = surface.upper()

      pronunciation = jaconv.hira2kata(pronunciation) #ひらがなをカタカナに変換
      pronunciation = cls.format_text(pronunciation) #余計な記号を削除

      surfaces.append(surface)
      pronunciations.append(pronunciation)
    
    return surfaces, pronunciations

  @classmethod
  def get_pronunciation(cls, text):
    _, pronunciations = cls.get_surface_and_pronunciation(text)
    return pronunciations

単体で使ってみると以下のような感じです。

print(Phoneme.kana_to_romaji(Phoneme.mora_wakachi("ガッキューabcホーカイ")))
['ga', 'k', 'kyu', 'u', 'abc', 'ho', 'o', 'ka', 'i']

正解文字列と認識結果の対応付けをするクラスを作ります。
また「その6」では正解文字列から形態素解析ライブラリで推測した発音をalignmentに使用していますが、推測した読みは間違いが含まれることも多いです。そのため、発音(カナ)の正解データもあれば、それを使えたほうが良いです。この場合、正解文字列の表層形と発音(カナ)の対応付けも必要になります。対応づいたcsv的なファイルを人間が用意してもよいのですが、利便性を考えて、漢字かな交じりのテキスト(表層形)とカナのみのテキスト(発音)から単語レベルの対応付けも自動で行えるようにします。

つまり以下の2種類の対応付けをおこないます。

  • 正解文字列(表層形)と推測フレーズ(表層形)の対応付け
  • 正解文字列(表層形)と正解文字列(カナ発音)の対応付け

実装は以下のような感じです。

# 編集距離と対応のリストを返す
import editdistance as ed

class Allocater:

  def __init__(self):
    pass
  def exec(self, estimated_phrases, correct_surface_text, correct_kana_text):
    # correct_surface_text: strからcorrect_words: list[str]を得る
    # estimated_phrases: list[str]からestimated_phrase_words: list[list[str]]を得る
    # correct_wordsとestimated_phrase_wordsからcorrect_phrase_words: list[list[str]]を得る
    # correct_phrase_words: list[list[str]]からestimated_phrase_word_moras: list[list[list[str]]]を得る
    # correct_kana_text: strからcorrect_moras: list[str]を得る
    # correct_morasとestimated_phrase_word_morasからcorrect_phrase_word_morasを得る

    # correct_surface_text: strからcorrect_words: list[str]を得る
    correct_words, estimated_word_kanas = Phoneme.get_surface_and_pronunciation(correct_surface_text)
    correct_words = tuple(correct_words)
    # estimated_phrases: list[str]からestimated_phrase_words: list[list[str]]を得る
    estimated_phrase_words = [tuple(Phoneme.get_surface_and_pronunciation(v)[0]) for v in estimated_phrases]
    # correct_wordsとestimated_phrase_wordsからcorrect_phrase_words: list[list[str]]を得る
    dist, phrase_correspondance = self.find_correspondance(correct_words, estimated_phrase_words)
    correct_phrase_words = [correct_words[start:end] for start, end in phrase_correspondance]
    # correct_phrase_words: list[list[str]]からestimated_phrase_word_moras: list[list[list[str]]]を得る
    estimated_word_moras = [tuple(Phoneme.mora_wakachi(v)) for v in estimated_word_kanas]
    # correct_kana_text: strからcorrect_moras: list[str]を得る
    correct_moras = tuple(Phoneme.mora_wakachi(correct_kana_text))
    # correct_morasとestimated_phrase_word_morasからcorrect_phrase_word_morasを得る
    dist, word_correspondance = self.find_correspondance(correct_moras, estimated_word_moras)
    correct_word_moras = [correct_moras[start:end] for start, end in word_correspondance]
    correct_phrase_word_moras = [correct_word_moras[start:end] for start, end in phrase_correspondance]
    #print(correct_phrase_word_moras)
    return correct_phrase_words, correct_phrase_word_moras
    
  # 入力: correct_textはタプル、test_segmentsはcorrect_textより1つ次元の多いタプル。correct_textはstrでも可能
  # 出力は分割のindexとその分割をした場合の編集距離
  @staticmethod
  def find_correspondance(correct_text, test_segments):
    memo = {}
    def inner_func(correct_text, test_segments):
      memo_key = (correct_text, tuple(test_segments))
      if memo_key in memo:
        return memo[memo_key]

      # 特殊ケースの対応
      if correct_text and not test_segments:
        return len(correct_text), []
      elif not correct_text and test_segments:
        flatten_test_segments = [x for row in test_segments for x in row]
        result = (len(flatten_test_segments), [(0,0) for i in range(len(test_segments))])
        memo[memo_key] = result
        return result
      elif not correct_text and not test_segments:
        return 0, []
      # test_segmentが最後一つのとき、全部を対応させる
      elif correct_text and len(test_segments) == 1:
        dist = ed.eval(correct_text, test_segments[0])
        memo[memo_key] = (dist, [(0, len(correct_text))])
        return dist, [(0, len(correct_text))]
      
      # 全体の編集距離がゼロなら先頭から順番に対応付けすれば良い
      flatten_test_segments = tuple([x for row in test_segments for x in row])
      if correct_text  == flatten_test_segments:
        correspondance = []
        cnt = 0
        for seg in test_segments:
          correspondance.append((cnt, cnt+len(seg)))
          cnt += len(seg)
        memo[memo_key] = (0, correspondance)
        return 0, correspondance

      # プラスマイナスwindow_sizeの幅で最適な対応をみつける
      text = test_segments[0]
        
      results = []
      #window_size = ed.eval(correct_text, "".join(test_segments))
      window_size = 5
      for i in range(2*window_size+1):
        diff = i-window_size
        if len(text) + diff < 0: continue
        head_dist = ed.eval(correct_text[0:len(text)+diff], text)
        head_correspondance = [(0, len(text)+diff)]
        tail_dist, tail_correspondance = inner_func(correct_text[len(text)+diff:], test_segments[1:])
        # indexを最初の対応の長さで補正
        tail_correspondance = [(s+len(text)+diff, e+len(text)+diff) for s,e in tail_correspondance]

        dist = head_dist+tail_dist
        correspondance = head_correspondance + tail_correspondance
        results.append((dist, correspondance))
      #print(min(results, key=lambda x: x[0]))
      min_result = min(results, key=lambda x: x[0])
      memo[memo_key] = min_result
      return min_result
    # correct_textはtupleとして扱う
    if type(correct_text) is str:
      correct_text = tuple(correct_text)
    return inner_func(correct_text, test_segments)

  # デバッグ・確認用。correspondance(始点終点のindex)を文字列のペアになおして見やすくする
  def display_correspondance(correct_text, test_segments, correspondance):
    for test_seg, (start, end) in zip(test_segments, correspondance):
      print("test:", test_seg)
      print("correct:", correct_text[start:end])
      print("")

ポイントは以下です。

  • 計算量削減のため、単語単位で処理
  • 2種類の対応付けを1つの関数内で実行した際の戻り値として何を返すかが意外と悩ましいのですが、phraseの切れ目、wordの切れ目、各wordに対応する

単体で使ってみると以下のような感じです。

correct_surface_text = "静むように溶けてゆくように二人だけの空が広がる夜に「さよなら」だけだったその一言で全てが分かった日が沈み出した空と君の姿フェンス越しに重なっていた初めて会った日から僕の心の全てを奪ったどこか儚い空気を纏う君は寂しい目をしてたんだ"
correct_kana_text = "シズムヨーニトケテユクヨーニフタリダケノソラガヒロガルヨルニサヨナラダケダッタソノヒトコトデスベテガワカッタヒガシズミダシタソラトキミノスガタフェンスゴシニカサナッテイタハジメテアッタヒカラボクノココロノスベテオウバッタドコカハカナイクウキオマトウキミハサミシイメオシテタンダ"
estimated_phrases = [
 "静むように溶けてゆくように"
, "二人だけの空が白がる夜に"
, "さよなら駆け合ったその一言で全てが分かった"
, "東姫出した空と君の姿 ケウスをしに重なってた"
, "初めてあったしから 僕の心の全てを奪った"
, "どこかはかない空気をなとう君は寂しい目をしてたんだ"
]

allocater = Allocater()
surfaces, moras = allocater.exec(estimated_phrases, correct_surface_text, correct_kana_text)
for p, s, m in zip(estimated_phrases,surfaces,moras):
  print(p)
  print(s)
  print(m)
  print("")

静むように溶けてゆくように
('静', 'むよう', 'に', '溶け', 'て', 'ゆく', 'よう', 'に')
[('シ', 'ズ'), ('ム', 'ヨ', 'ー'), ('ニ',), ('ト', 'ケ'), ('テ',), ('ユ', 'ク'), ('ヨ', 'ー'), ('ニ',)]

二人だけの空が白がる夜に
('二人', 'だけ', 'の', '空', 'が', '広がる', '夜', 'に')
[('フ', 'タ', 'リ'), ('ダ', 'ケ'), ('ノ',), ('ソ', 'ラ'), ('ガ',), ('ヒ', 'ロ', 'ガ', 'ル'), ('ヨ', 'ル'), ('ニ',)]

さよなら駆け合ったその一言で全てが分かった
('「', 'さよなら', '」', 'だけ', 'だっ', 'た', 'その', '一言', 'で', '全て', 'が', '分かっ', 'た')
[(), ('サ', 'ヨ', 'ナ', 'ラ'), (), ('ダ', 'ケ'), ('ダ', 'ッ'), ('タ',), ('ソ', 'ノ'), ('ヒ', 'ト', 'コ', 'ト'), ('デ',), ('ス', 'ベ', 'テ'), ('ガ',), ('ワ', 'カ', 'ッ'), ('タ',)]

東姫出した空と君の姿 ケウスをしに重なってた
('日', 'が', '沈み', '出し', 'た', '空', 'と', '君', 'の', '姿', 'フェンス', '越し', 'に', '重なっ', 'て', 'い', 'た')
[('ヒ',), ('ガ',), ('シ', 'ズ', 'ミ'), ('ダ', 'シ'), ('タ',), ('ソ', 'ラ'), ('ト',), ('キ', 'ミ'), ('ノ',), ('ス', 'ガ', 'タ'), ('フェ', 'ン', 'ス'), ('ゴ', 'シ'), ('ニ',), ('カ', 'サ', 'ナ', 'ッ'), ('テ',), ('イ',), ('タ',)]

初めてあったしから 僕の心の全てを奪った
('初めて', '会っ', 'た', '日', 'から', '僕', 'の', '心', 'の', '全て', 'を', '奪っ', 'た')
[('ハ', 'ジ', 'メ', 'テ'), ('ア', 'ッ'), ('タ',), ('ヒ',), ('カ', 'ラ'), ('ボ', 'ク'), ('ノ',), ('コ', 'コ', 'ロ'), ('ノ',), ('ス', 'ベ', 'テ'), ('オ',), ('ウ', 'バ', 'ッ'), ('タ',)]

どこかはかない空気をなとう君は寂しい目をしてたんだ
('どこか', '儚い', '空気', 'を', '纏う', '君', 'は', '寂しい', '目', 'を', 'し', 'て', 'た', 'ん', 'だ')
[('ド', 'コ', 'カ'), ('ハ', 'カ', 'ナ', 'イ'), ('ク', 'ウ', 'キ'), ('オ',), ('マ', 'ト', 'ウ'), ('キ', 'ミ'), ('ハ',), ('サ', 'ミ', 'シ', 'イ'), ('メ',), ('オ',), ('シ',), ('テ',), ('タ',), ('ン',), ('ダ',)]

forced-alignment

ようやくforced-alignmentです。「その1」とほぼ同じですが、使い回しやすいように、グローバル部分に書いていた処理を関数化したうえで、全体をクラス化しています。

import tensorflow as tf
import torch
import torchaudio
from datetime import timedelta
from dataclasses import dataclass
from pydub import AudioSegment
import re
import num2words

class Alignment:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.cuda.empty_cache()

    torch.random.manual_seed(0)

    def force_align(self, SPEECH_FILE, transcript, start_index, start_time):
        bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        model = bundle.get_model().to(self.device)
        labels = bundle.get_labels()
        with torch.inference_mode():
            waveform, waveform_sample_rate = torchaudio.load(SPEECH_FILE)
            # waveformのサンプルレートをbundleに合わせる
            waveform = torchaudio.transforms.Resample(waveform_sample_rate, bundle.sample_rate)(waveform)
            waveform_sample_rate = bundle.sample_rate
            emissions, _ = model(waveform.to(self.device))
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        dictionary = {c: i for i, c in enumerate(labels)}

        tokens = [dictionary[c] for c in transcript]

        def get_trellis(emission, tokens, blank_id=0):
            num_frame = emission.size(0)
            num_tokens = len(tokens)

            # Trellis has extra diemsions for both time axis and tokens.
            # The extra dim for tokens represents <SoS> (start-of-sentence)
            # The extra dim for time axis is for simplification of the code.
            trellis = torch.empty((num_frame + 1, num_tokens + 1))
            trellis[0, 0] = 0
            trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
            trellis[0, -num_tokens:] = -float("inf")
            trellis[-num_tokens:, 0] = float("inf")

            for t in range(num_frame):
                trellis[t + 1, 1:] = torch.maximum(
                    # Score for staying at the same token
                    trellis[t, 1:] + emission[t, blank_id],
                    # Score for changing to the next token
                    trellis[t, :-1] + emission[t, tokens],
                )
            return trellis


        trellis = get_trellis(emission, tokens)

        @dataclass
        class Point:
            token_index: int
            time_index: int
            score: float


        def backtrack(trellis, emission, tokens, blank_id=0):
            # Note:
            # j and t are indices for trellis, which has extra dimensions
            # for time and tokens at the beginning.
            # When referring to time frame index `T` in trellis,
            # the corresponding index in emission is `T-1`.
            # Similarly, when referring to token index `J` in trellis,
            # the corresponding index in transcript is `J-1`.
            j = trellis.size(1) - 1
            t_start = torch.argmax(trellis[:, j]).item()

            path = []
            for t in range(t_start, 0, -1):
                # 1. Figure out if the current position was stay or change
                # Note (again):
                # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
                # Score for token staying the same from time frame J-1 to T.
                stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
                # Score for token changing from C-1 at T-1 to J at T.
                changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

                # 2. Store the path with frame-wise probability.
                prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
                # Return token index and time index in non-trellis coordinate.
                path.append(Point(j - 1, t - 1, prob))

                # 3. Update the token
                if changed > stayed:
                    j -= 1
                    if j == 0:
                        break
            else:
                raise ValueError("Failed to align")
            return path[::-1]


        path = backtrack(trellis, emission, tokens)

        # Merge the labels
        @dataclass
        class Segment:
            label: str
            start: int
            end: int
            score: float

            def __repr__(self):
                return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

            @property
            def length(self):
                return self.end - self.start


        def merge_repeats(path):
            i1, i2 = 0, 0
            segments = []
            while i1 < len(path):
                while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                    i2 += 1
                score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
                segments.append(
                    Segment(
                        transcript[path[i1].token_index],
                        path[i1].time_index,
                        path[i2 - 1].time_index + 1,
                        score,
                    )
                )
                i1 = i2
            return segments

        # wordレベルのalignmentは不要なのでmerge_words関数は削除

        segments = merge_repeats(path)
        subs = []
        for i,word in enumerate(segments):
            ratio = waveform.size(1) / (trellis.size(0) - 1)
            x0 = int(ratio * word.start)
            x1 = int(ratio * word.end)
            # wav2vec(bundle)ではなく入力のsample_rateで割る
            #start = timedelta(seconds=start_time + x0 / bundle.sample_rate)
            #end = timedelta(seconds=start_time + x1 / bundle.sample_rate )
            start = timedelta(seconds=start_time + x0 / waveform_sample_rate).total_seconds()
            end = timedelta(seconds=start_time + x1 / waveform_sample_rate ).total_seconds()
            # subsの要素はstr形式ではなく、dict形式にする。この後関数外でmergeするので
            #subtitle = Subtitle(start_index+i, start, end, word.label)
            #subs.append(subtitle) 
            subs.append({
                "text": word.label
                , "start": start
                , "end":end
            })
        return subs

    def align_by_phrase(self, audio_file_path, starts, ends, texts):
        print("Starting to force alignment...")
        separator = "|"
        start_index = 0
        total_subs = []
        for i, (start, end, text) in enumerate(zip(starts, ends, texts)):
            audioSegment = AudioSegment.from_wav(audio_file_path)[int(start*1000):int(end*1000)]

            audioSegment.export("output/tmp/"+str(i)+'.wav', format="wav") #Exports to a wav file in the current path.
            transcript=text.strip().replace(" ", separator)
            transcript = re.sub(r'[^\w|\s]', '', transcript)
            transcript = re.sub(r"(\d+)", lambda x: num2words.num2words(int(x.group(0))), transcript)
            print(start)
            subs = self.force_align("output/tmp/"+str(i)+'.wav', transcript.upper(), start_index, start)
            # 末尾にseparatorの音素を追加する(あとでカナ単位にスプリットするため)
            subs.append({
                "text": separator, "start": subs[-1]["end"], "end":subs[-1]["end"]
            })
            start_index += len(text)
            total_subs.extend(subs)
        #print(total_subs)
        
        # separatorの単位でmerge
        chars = []
        text, start, end = "", -1, -1
        for i,p in enumerate(total_subs):
            if text == "" and p["text"] != separator:
                text, start, end = p["text"], p["start"], p["end"]
                continue
            if p["text"] == separator or i == len(total_subs)-1:
                chars.append({
                    "text":text
                    , "start": start
                    , "end":end
                })
                text, start, end = "", -1, -1
                continue
            text += p["text"]
            end = p["end"]
        return chars

単体で使うと以下のような感じです。(テストなので)「その5」でアノテーションした正解データを入力にしています。

SPEECH_FILE = "audio/yorunikakeru_vocals.wav"
df = pd.read_csv("output/annotation/roma.csv")

alignment = Alignment()
chars = alignment.align_by_phrase(SPEECH_FILE, df["start"], df["end"], df["roma"])
print(chars)
[
  {
    "text": "SHI",
    "start": 1.602,
    "end": 1.822938
  },
  {
    "text": "ZU",
    "start": 1.843,
    "end": 2.02375
  },
  ...
]

処理を順番に実行する

要素実装が完了したので、処理を順番に実行する関数を作ります。

from pathlib import Path
import editdistance as ed
import json
import pandas as pd


def align(*
          , audio_path = None
          , vocal_path = None
          , lyric_path = None
          , kana_path = None
          , whisper_result_path = None
          , soundrecognition_path = None
          , alignment_path = None
          , output_path = None
          ):
  if not audio_path and not vocal_path:
    print("audio_path or vocal_path should exist")
    return
  
  if not output_path:
    output_path = Path("output").joinpath(Path(vocal_path or audio_path).stem)
    output_path.mkdir(exist_ok=True)
    output_path = str(output_path)

  print("Start separating vocals...")
  if not vocal_path:
    meducs = Meducs()
    vocal_path = Path(meducs.separate(audio_path, outp=output_path)).joinpath("vocals.wav")

  print("Start sound recognition...")
  if not whisper_result_path and not soundrecognition_path:
    whisper_model = whisper.load_model("base")
    soundrecognition = SoundRecognition(whisper_model)
    whisper_result = SoundRecognition.recognize_by_whisper(vocal_path)
    whisper_result_path = str(Path(output_path).joinpath("whisper_result_path.json"))
    with open(whisper_result_path) as f:
      json.dump(whisper_result, f, indent = 2, ensuire_ascii=False)
  
  if whisper_result_path and not soundrecognition_path:
    soundrecognition = SoundRecognition(use_whisper=False)
    whisper_result = soundrecognition.load_whisper_result(whisper_result_path)
    no_energy_result = soundrecognition.recognize_no_energy(vocal_path)
    soundrecognition_result = soundrecognition.update_whisper_timestamp(whisper_result, no_energy_result)
    soundrecognition_path = str(Path(output_path).joinpath("soundrecognition_result.json"))
    with open(soundrecognition_path, "w") as f:
      json.dump(soundrecognition_result, f, indent=2, ensure_ascii=False)
  
  with open(soundrecognition_path) as f:
    soundrecognition_result = json.load(f)

  print("Start calculating correspondance between recognized text and correct text") 
  # lyric_pathがなければ音声認識結果をもとに作成
  if not lyric_path:
    lyrics = [v["text"] for v in soundrecognition_result]
    lyric_path = str(Path(output_path).joinpath("lyric.txt"))
    with open(lyric_path, "w") as f:
      f.write("\n".join(lyrics))

  with open(lyric_path) as f:
    correct_surface_text = " ".join(f.read().splitlines())

  # kana_pathがなければ、correct_surface_textに基づいて作成
  if not kana_path:
    kana_path = str(Path(output_path).joinpath("correct_kana.txt"))
    correct_kana_text = "".join(Phoneme.get_pronunciation(correct_surface_text))
    with open(kana_path,"w") as f:
      f.write(correct_kana_text)
  
  with open(kana_path) as f:
    correct_kana_text = "".join(f.read().splitlines())
  
  estimated_phrases = [v["text"] for v in soundrecognition_result]

  allocater = Allocater()
  correct_phrase_words, correct_phrase_word_moras = allocater.exec(estimated_phrases, correct_surface_text, correct_kana_text)
  print("Start forced-alignment...")  
  align_info = []
  for phrase, segment in zip(correct_phrase_word_moras, soundrecognition_result):
    start, end = segment["start"], segment["end"]
    moras = [mora for word in phrase for mora in word]
    romas = Phoneme.kana_to_romaji(moras)
    align_info.append({
      "start": start
      , "end": end
      , "text": " ".join(moras)
      , "roma":  " ".join(romas)
    })
    
  align_input_df = pd.DataFrame(align_info)

  if not alignment_path:
    alignment = Alignment()
    align_result = alignment.align_by_phrase(vocal_path, align_input_df["start"], align_input_df["end"], align_input_df["roma"])
    alighment_path = str(Path(output_path).joinpath("alighment_result.json"))
    with open(alighment_path, "w") as f:
      json.dump(align_result, f, indent=2, ensure_ascii=False)
  
  with open(alignment_path) as f:
    align_result = json.load(f)
  correct_moras = [mora for phrase in correct_phrase_word_moras for word in phrase for mora in word]
  correct_words = [word for phrase in correct_phrase_words for word in phrase]
  word_ids = []
  for i, word in enumerate([word for phrase in correct_phrase_word_moras for word in phrase]):
    word_ids += [i]*len(word)

  align_result_df = pd.DataFrame(align_result)
  align_result_df["mora"] = correct_moras
  align_result_df["word_id"] = word_ids
  align_result_df["word_surface"] = [correct_words[i] for i in word_ids]
  
  result_path = str(Path(output_path).joinpath("result.csv"))
  align_result_df.to_csv(result_path, index=False)
  #print(align_result_df)
  return align_result_df

以下のように実行します。

align_result_df = align(vocal_path = "audio/yorunikakeru_vocals.wav"
      , soundrecognition_path = "output/yorunikakeru_vocals/soundrecognition_result.json"
      #, whisper_result_path = "soundrecognition/largeresult_yorunikakeru_vocal.json"
      , lyric_path = "lyric/yorunikakeru_lyric.txt"
      , kana_path = "output/yorunikakeru_vocals/correct_kana.txt"
      , alignment_path = "output/yorunikakeru_vocals/alighment_result.json"
      )

ポイントは以下です。

  • 最低限audio_pathがあれば、音源分離から音声認識、alignmentまで一連の処理を実装
  • 途中まで処理済みのものや人間が作成・修正したデータを引数指定で使うことができる。
    • ボーカル抽出は時間がかかるので1回やったら2回目以降は使い回すと良い
    • whisperの音声認識は時間がかかり、かつ、大げさに失敗する可能性もままある(明らかに曲の途中で途中で認識が終了するなど)ので、別処理で出力し必要に応じて微修正してから入力したほうが使いやすい

以下のような出力が得られます。

result.csv
text,start,end,mora,word_id,word_surface
SHI,1.602,1.822938,シ,0,沈む
ZU,1.843,2.02375,ズ,0,沈む
MU,2.084,2.405312,ム,0,沈む
YO,2.445438,2.586063,ヨ,1,よう
O,2.827,2.987687,ー,1,よう
...

おわりに

「その6」からやれることはほぼ増えていませんが、要素実装をモジュール化し、途中結果も使いながらワンパスで実行する関数を作れたので、だいぶ使い勝手が良くなった気がします。
課題はまだいくつかあって

  • whisperの音声認識がよく失敗する。安定させたい。あるいは失敗したことを検知してリトライする機能を作りたい。
  • 最終出力ではモウラと単語の対応が得られているが、できればもう少し細かく、モウラと表層系の各文字の対応を得ておきたい。また、記号など発音のない表層形の情報が最終結果から失われているので、前後の単語にくっつけるなどして、保持しておきたい(そうしないと字幕の再現ができなくなる)。
  • コードの整理。今はすべてnotebookで完結させているがクラス化した要素技術は別ファイルで管理するなどして、構造をよくしたい

あたりは気が向いたときに改善したいです。
ただ、「その3」に続いて2回めの区切りを迎えられた気がしますので、そろそろWebアプリ化とか、少し毛色の違うことにも取り組んでみようかと思います。

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6