0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

数字の読み上げデータセットを学習

Last updated at Posted at 2021-06-01

0から9の数字を複数人が読み上げた音声データである spoken digit に対して LSTM で学習を行った。

以下のコードを google colab にコピペすると動くようになっている。colabのランタイムをGPUにすると1分ですべて終わる。

import glob
import requests
import IPython
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tqdm.notebook as tqdm
import tensorflow as tf
import itertools
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import pandas as pd
import tarfile
  

def sound_download():
  # urlからデータをダウンロードする
  download_url = "https://github.com/Jakobovski/free-spoken-digit-dataset/archive/v1.0.9.tar.gz"
  filename = "free-spoken-digit-dataset-1.0.9.tar.gz"
  content_ = requests.get(download_url).content
  with open(filename, 'wb') as f: # wb でバイト型を書き込める
    f.write(content_)

  # 解凍すると free-spoken-digit-dataset-1.0.9 ディレクトリが現れる
  with tarfile.open(filename, 'r:gz') as t:
      t.extractall(path="./")
    

def sound_preprocess(path):
  # 音声の読み込み
  y, sr = librosa.load(path, sr=None)
  # メルスペクトログラム(人間の聴覚に適したスペクトログラム)
  S = librosa.feature.melspectrogram(y=y, sr=sr)
  S_dB = librosa.power_to_db(S, ref=np.max)
  return S_dB


# 音声データのダウンロード
sound_download()


# recordingsに入っている数字の音声データ2500個をサウンドスペクトログラムとして取り出す
db_list = []
path_list = sorted(glob.glob("free-spoken-digit-dataset-1.0.9/recordings/*"))
for path in path_list:
  db_list.append(sound_preprocess(path))


# pad_sequencesを用いて1でパディング(デシベルはー80〜0の範囲であるため、重複しない1を用いる)
padded_inputs = []
for db in db_list:
  _padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(db, padding="post", maxlen=36, value=1.0)
  padded_inputs.append(_padded_inputs)


# 後でLSTMを用いるために軸の入れ替えを行う (time, hidden)
X = np.array(padded_inputs)
X = np.transpose(X, (0, 2, 1))

# ディレクトリ内はラベル0からラベル9まで250個ずつ並んでいる
_y = [[i] * 250 for i in range(10)]
y = list(itertools.chain.from_iterable(_y))
y = np.array(y)  # [0, 0, ..., 1, 1, ... ... 9, 9]


# ラベルのone hot encodeを行う
enc = OneHotEncoder(sparse=False)
y = enc.fit_transform(y[:, np.newaxis])

# データの分割
X_train, X_test, y_train, y_test = train_test_split(X, y)

次に加工したデータを元にモデルの学習を行う。

import tensorflow as tf
from keras.layers import Masking, LSTM, Input, Dense, Activation, Dropout
from keras.models import Model


inputs = Input(shape=X.shape[1:])
# Maskingで対応する時間ステップを処理中に無視する
x = Masking(mask_value=1.0)(inputs)
x = LSTM(128)(x)
x = Dense(10)(x)
x = Activation('softmax')(x)

model = Model(inputs, x)
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# validation_dataはシャッフルしていないとaccuracyが0と出てしまう
history = model.fit(X_train, y_train, batch_size=32, epochs=100, validation_data=(X_test, y_test))
pd.DataFrame(history.history).loc[:, ['accuracy', 'val_accuracy']].plot(ylim=(0, 1))

精度はだいたい98%くらいだった。引き続き
image.png

0
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?