0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchで微分可能なMFCCの実装

Posted at

はじめに

PyTorchで微分可能なMFCCを実装することで、様々な研究に役立つと考え、執筆いたしました。私が把握する限り、微分可能なMFCCを計算するライブラリはPyTorchにはなかったと思います。

MFCCで使用する関数、クラス

関数:メルフィルタバンク

def hz2mel(f):
    """Hzをmelに変換"""
    return 2595 * torch.log(f / 700.0 + 1.0)

def mel2hz(m):
    """melをhzに変換"""
    return 700 * (torch.exp(m / 2595) - 1.0)
    
def melFilterBank(fs, N, numChannels):
    """メルフィルタバンクを作成"""
    # ナイキスト周波数(Hz)
    fmax = fs / 2
    # ナイキスト周波数(mel)
    melmax = hz2mel(fmax)
    # 周波数インデックスの最大数
    nmax = N // 2
    # 周波数解像度(周波数インデックス1あたりのHz幅)
    df = fs / N
    # メル尺度における各フィルタの中心周波数を求める
    dmel = melmax / (numChannels + 1)
    melcenters = torch.arange(1, numChannels + 1) * dmel
    # 各フィルタの中心周波数をHzに変換
    fcenters = mel2hz(melcenters)
    # 各フィルタの中心周波数を周波数インデックスに変換
    indexcenter = torch.round(fcenters / df)
    # 各フィルタの開始位置のインデックス
    indexstart = torch.hstack((torch.tensor([0]), indexcenter[0:numChannels - 1]))
    # 各フィルタの終了位置のインデックス
    indexstop = torch.hstack((indexcenter[1:numChannels], torch.tensor([nmax])))
    filterbank = torch.zeros((numChannels, nmax))
    for c in range(0, numChannels):
        # 三角フィルタの左の直線の傾きから点を求める
        increment= 1.0 / (indexcenter[c] - indexstart[c])
        for i in range(int(indexstart[c]), int(indexcenter[c])):
            filterbank[c, i] = (i - indexstart[c]) * increment
        # 三角フィルタの右の直線の傾きから点を求める
        decrement = 1.0 / (indexstop[c] - indexcenter[c])
        for i in range(int(indexcenter[c]), int(indexstop[c])):
            filterbank[c, i] = 1.0 - ((i - indexcenter[c]) * decrement)

    return filterbank, fcenters

# メルフィルタバンクを作成
numChannels = 20  # メルフィルタバンクのチャネル数
df = rate / fft_n   # 周波数解像度(周波数インデックス1あたりのHz幅)
filterbank, fcenters = melFilterBank(torch.tensor(rate), fft_n, numChannels)

# メルフィルタバンクのプロット
for c in torch.arange(0, numChannels):
    plt.plot(torch.arange(0, fft_n / 2) * df, filterbank[c])
    
plt.title('Mel filter bank')
plt.xlabel('Frequency[Hz]')
plt.show()

クラス:離散コサイン変換

class DCT:
    def __init__(self,N):
        self.N = N	# データ数.
        
        # 1次元離散コサイン変換の変換行列を予め作っておく
        self.phi_1d = self.phi(0).unsqueeze(0)
        for i in range(1, self.N):
            self.phi_1d = torch.cat((self.phi_1d, self.phi(i).unsqueeze(0)), 0)


    def dct(self,data):
        """ 1次元離散コサイン変換を行う """
        return torch.matmul(self.phi_1d, data)

    def idct(self,c):
        """ 1次元離散コサイン逆変換を行う """
        return torch.sum(self.phi_1d.T * c, dim=1)

    def phi(self,k):
        """ 離散コサイン変換(DCT)の基底関数 """
        # DCT-II
        if k == 0:
            return torch.ones(self.N) / torch.sqrt(torch.tensor(self.N))
        else:
            return torch.sqrt(torch.tensor(2.0 / self.N)) * torch.cos((k * torch.pi/(2*self.N)) * (torch.arange(self.N) * 2 + 1))

N = 10	# データ数を100とします
dct = DCT(N)	# 離散コサイン変換を行うクラスを作成
rands = torch.rand(N)	# N個の乱数データを作成
c = dct.dct(rands)					# 離散コサイン変換を実行
ic = dct.idct(c)					# 離散コサイン逆変換を実行

# 元のデータ(x)と復元したデータ(y)をグラフにしてみる
plt.plot(rands,label="original")
plt.plot(ic,label="restored")
plt.legend()
plt.title("original data and restored data")
plt.show()

関数:スプライシング

splice=0の場合、スプライシングは行われません。

def splicing(feature, splice=0):
    # splicing: 前後 n フレームの特徴量を結合する
    org_feat = feature.clone()
    for n in range(-splice, splice + 1):
        # 元々の特徴量を n フレームずらす
        tmp = torch.roll(org_feat, n, dims=0)
        if n < 0:
            # 前にずらした場合は
            # 終端nフレームを0にする
            tmp[n:] = 0
        elif n > 0:
            # 後ろにずらした場合は
            # 始端nフレームを0にする
            tmp[:n] = 0
        else:
            continue
        # ずらした特徴量を次元方向に
        # 結合する
        feature = torch.hstack((feature, tmp))

    return feature

実装

def torch_mfcc(x, 
               rate, 
               frame_size_ms = 25, 
               frame_shift_ms = 10, 
               numChannels = 20, # メルフィルタバンクのチャネル数
               nceps = 12, 
               splice = 0, 
               ):
    """_summary_

    Args:
        x (_type_): 音声波形
        rate (_type_): 周波数
        numChannels (int, optional): メルフィルタバンクのチャネル数. Defaults to 20.

    Returns:
        torch.tensor: `(num_timestep, nceps)`
    """
    frame_size_sample = int(rate * frame_size_ms * 10e-4)
    frame_shift_sample = int(rate * frame_shift_ms * 10e-4)
    
    fft_n = 1
    while fft_n < frame_size_sample:
        fft_n *= 2
    
    num_frames = (x.shape[0] - frame_size_sample) // frame_shift_sample + 1
    # num_frames = 1
    
    mfcc = torch.zeros(num_frames, nceps)
    
    for frame_idx in range(num_frames):
        start_idx = frame_idx * frame_shift_sample
        frame = x[start_idx : start_idx + frame_shift_sample]

        hamm_window = torch.hamming_window(len(frame))

        frame_hamm = frame*hamm_window
        
        spec = torch.abs(torch.fft.fft(frame_hamm, fft_n))[:fft_n//2]

        # メルフィルタバンクを作成
        filterbank, _ = melFilterBank(torch.tensor(rate), fft_n, numChannels)
        
        mspec = torch.matmul(spec, filterbank.T)

        dct = DCT(mspec.shape[0])
        ceps = dct.dct(10 * torch.log10(mspec))
        mfcc[frame_idx, :] = ceps[:nceps]
    
    splicing_mfcc = splicing(mfcc, splice=splice)
    
    return splicing_mfcc

確認

うまく微分できている。

rate, x = scipy.io.wavfile.read("data/original/hello.wav")

x = torch.tensor(x)
x = x / 65535
x.requires_grad = True

feat = torch_mfcc(x, rate, splice=0)
feat[0][0].backward()
x.grad, x.grad.sum()
# (tensor([ 17.5747, -34.3700, -20.1528,  ...,   0.0000,   0.0000,   0.0000]), tensor(-555.3243))

本実装が役に立つ研究

Audio Adversarial Example (音声の敵対的サンプル攻撃)という研究分野がある。これは、人間が聞いたときと音声認識モデルが書き起こすときとで異なるテキストとなる音声を生成する研究を表す。詳細は割愛するが、MFCCを入力するタイプの音声認識モデルに対してこの攻撃をおこなおうとするとき、MFCCから元の音声波形までの勾配グラフが必要となり、本実装が役に立つ。

参考

メルフィルタバンク

離散コサイン変換

その他全般

※実装完了から時間がたっての投稿ですので参考に漏れがある可能性があります。漏れを発見された場合、ご一報いただけますと幸いです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?