Help us understand the problem. What is going on with this article?

RNN + LSTMで自動作曲してみた [DW 1日目]

More than 3 years have passed since last update.

ざっくりいうと

この記事の目的

RNNとLSTMを理解する

 まず、RNNとLSTMの理解。実は今はゴールデンウィーク真っ最中なのだが、今年は有給と合わせて10連休をもらったので、Deep Learning強化週間としてDL系の記事を1日ペースで書いていくことにした。本記事はその第一弾に位置づけられる。第一弾ということでインパクトがあるテーマが良いと思い、RNN+LSTMを使って自動作曲というテーマにしてみた。企画の概要についてはコチラ

テキスト以外の自動生成を試す

 また、世の中には文章を自動生成させて楽しむ記事がすでに大量にあり、これをやっても完全な二番煎じにしかならないと思った。一方、音楽についてはまだあまり書かれていなかったので音楽の自動生成を試そうと思った次第だ。

 最後に、割りと重要なのだが、こういう発信をしていくことで自身の社内存在価値も向上するかなと。そういう裏の目的もあったりする。

RNNとLSTMの概論

RNN

 RNNは、系列データを入力として、その系列の次に来る尤もらしいデータを出力する。単語の並びを入力とすれば、次単語を予測することができるし、音符や休符の並びを入力とすれば次の音符や休符の並びを予測することもできる。今回は、後者の音楽の自動生成にチャレンジしてみた。成果については、記事の後半で述べるとして、ここではまずRNNとLSTMの概論をまとめておく。ただし、自分の覚書用なので読み物としては成立していないと思う。読み物として読みたい人は、参考文献のリンク先を見ることをすすめる。
 
 さて、最も基本のRNNは多層パーセプトロンの中間層に、自分自身へループする重みを加えたものである。時刻 $t$ の入力層ベクトルを ${\bf x}^t$ 、中間層ベクトルを ${\bf h}^t$ 、層間の重みを $ {\bf W}^{in}$ 、中間層の自分自身へループする重みを ${\bf W}^{h}$ とすると、次の時刻 $t+1$ における同じ中間層への入力は

 inputs = {\bf W}^{in}{\bf x}^{t+1} + {\bf W}^{h}{\bf h}^{t} 

となる。つまり、1つ前の時刻 $\bf t$ の情報が必要になる。これが、「自分自身へのループ」と呼んでいる部分である。
 ちなみに、ここで「時刻」と呼んでいるのは、あくまで系列データの系列番号のこと。文章中の単語や楽譜中の音符の場合、それが前から数えて何番目にあるかを意味する。
 RNNの弱点は、その重みを得るための誤差逆伝播法の計算が困難ということである。困難である理由は、RNNに対する誤差逆伝播法を時間方向に拡張する(これをBPTT: BackPropagation Through Timeという)と、非常に深いネットワークとなり、浅い層の学習に必要な誤差勾配が消失するため。RNNではない普通の多層パーセプトロンでも、層が非常に深い場合はこの勾配消失が問題になってくる。

LSTM

 したがって、基本的なRNNは長期記憶ができないという弱点をもつ。この弱点を克服するために編み出されたのがLSTM (Long Short-Term Memory)。LSTMは、基本的なRNNがもつ中間層の各ノードをLSTMブロックと呼ばれる要素で置換したものである。LSTMブロックはメモリユニットとも呼ばれる。LSTMへの入力は基本的なRNNの中間層への入力ベクトルの4倍の次元数のベクトルとなる。これは、1つのLSTMブロックが4つの入力を受け取るためである。4つのうち、1つは通常のRNNと同じく、前の層からの入力である。一方、他の3つは、LSTMブロックに特有の「ゲート」と呼ばれる機構の動作のために消費される。3つの入力はそれぞれ、「入力ゲート」、「出力ゲート」、「忘却ゲート」と呼ばれる機構に入力される。入力ゲートは、LSTMブロックに入ってくる入力(前の層からの入力を意味する)を、通すか通さないかを判断する。出力ゲートはその出力版である。そして忘却ゲートは、LSTMが内部に持つ変数(これを内部状態と呼ぶ)を次の時刻に引き継ぐかどうかを判断する役割をもつ。例えば入力ゲートが0かつ忘却ゲートが1の状況が続くと、過去からの入力は遮断され、LSTMブロックの内部状態が長い時刻を超えて引き継がれるようになる。この性質によって、長期記憶を実現することが可能になった。
 各ゲートの詳細な動作については以下が詳しい。

実験1: 自動「作文」をしてみた

Neural Networkのコード

自動作曲に先立ち、まずは自動「作文」をやってみる。実は、この自動作文で扱う「単語」を、音楽の要素に差し替えるだけで自動作曲に応用できるようになる。

 以下にChainer(1.8)を用いて実装したコードを示す。

rnn.py
"""
Recurrent Neural Network with two LSTM layers.
"""
class rnn(Chain):
  state = {}

  def __init__(self, n_vocab, n_units):
    super(rnn, self).__init__(
      l0 = L.EmbedID(n_vocab, n_units),
      l1_x = L.Linear(n_units, 4 * n_units),
      l1_h = L.Linear(n_units, 4 * n_units),
      l2_x = L.Linear(n_units, 4 * n_units),
      l2_h = L.Linear(n_units, 4 * n_units),
      l3 = L.Linear(n_units, n_vocab)
    )

  def forward(self, x, t, train=True, dropout_ratio=0.5):
    h0 = self.l0(x)
    c1, h1 = F.lstm(
      self.state['c1'],
      F.dropout( self.l1_x(h0), ratio=dropout_ratio, train=train ) + self.l1_h(self.state['h1'])
    )
    c2, h2 = F.lstm(
      self.state['c2'],
      F.dropout( self.l2_x(h1), ratio=dropout_ratio, train=train ) + self.l2_h(self.state['h2'])
    )
    y = self.l3(h2)
    self.state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2}
    if train:
      return F.softmax_cross_entropy(y, t), F.accuracy(y, t)
    else:
      return F.softmax(y), y.data

  def initialize_state(self, n_units, batchsize=50, train=True):
    for name in ('c1', 'h1', 'c2', 'h2'):
      self.state[name] = Variable(np.zeros((batchsize, n_units), dtype=np.float32), volatile=not train)

 ネットワークの構造は、n_vocab次元の入力層、n_units次元の単語分散表現層、n_units次元の第一LSTMブロック層、n_units次元の第二LSTMブロック層、最後にn_vocab次元の出力層で構成される5層構造である。ただし、入力層と出力層も含めて5層と呼んでいる。
 2層目の単語分散表現層は、単語の分散表現ベクトルを扱う層である。単語の分散表現とは、単語を固定長のベクトルで表現したもので、今ではshifted PMIやword2vecが主流である。自然言語処理でよくやるTipsとして、あらかじめ単語の分散表現をword2vec等で獲得しておき、ネットワークへの単語の入力として獲得済みの単語分散表現を与えるという方法があるが、今回はそうではない。1層目と2層目のリンクとしてChainerが提供するlinks.EmbedID()を採用している。これは分散表現自体を学習していく初期値ランダムのリンクである。
 3、4層目のLSTMブロックへの入力が4倍になっているのは、通常の入力以外に上述の3つのゲートへの入力が存在するためである。
stateという名前のハッシュでは、LSTMブロックの内部状態cと、LSTMブロックの出力値hを保持している。
3層目、4層目の計算にはDropoutを導入している。

学習過程のコード

train.py
# Init model
model = rnn(n_vocab, n_units)
model.initialize_state(n_units, batchsize=batchsize, train=True)

# Setup optimizer
optimizer = optimizers.Adam()
optimizer.setup(model)

for i in xrange(N):
  # Prepare mini batch
  x = chainer.Variable(x_batch.astype(np.int32), volatile=False)
  y = chainer.Variable(y_batch.astype(np.int32), volatile=False)

  # Forward step
  loss, acc = model.forward(x, y, dropout_ratio=0.5, train=True)
  accum_loss += loss

  # Backward step
  if (i + 1) % bprop_len == 0:  # Run truncated BPTT
    optimizer.zero_grads()
    accum_loss.backward()
    accum_loss.unchain_backward()  # truncate
    accum_loss = chainer.Variable(xp.zeros((), dtype=np.float32))

    # L2 regularization
    optimizer.clip_grads(grad_clip)
    optimizer.update()

重み最適化には確率的急降下法の一種のAdamを用いている。Adamは勾配の平均と分散をオンライン推定して学習率を変更していく。重み成分ごとに学習率を持っており、特に重み行列が疎であるときに強いという性質をもつ。Adamは多くの問題で最適な収束を達成することが経験上わかっているので、とりあえずAdamを使っておけば問題ない。ただし、重み行列の要素数に比例してメモリを消費するため、現実的にはメモリ不足を回避するためAdamを採用しないこともある。
 bprop_lenはtruncated BPTTと呼ばれるTipsに使う変数である。過去のどの時点まで履歴を遡るかを指定しする。
 optimizer.clip_grads()ではノルムを上限値に抑えこむ制約を与えている。これによって勾配爆発問題に対処できるようになる。詳しくは以下を参照。

実験1: 自動作文の結果

実験台として、青空文庫で公開されているシャーロック・ホームズシリーズの『緋色の研究』(日本語訳)を学習データとして与えた。(正式名は『緋のエチュード』っていうのか!?)
- コナン・ドイル作品群

前処理として、章番号、ヘッダ、フッタ、《三角かっこ》を削除し、MeCabで分かち書きを済ませておいた。

学習が十分収束したあと、「ホームズ」という単語を入力として与えたところ、以下の様な出力結果を得た。

出力結果
ホームズが 、 「 ほら 、 マーチャ が いわゆる 幻覚 が 、 マーチャ が 掟 を 破る のを 黙っ て いる こと を 実感 し た 。
ノーヴー という 地名 に ジョン ・ フェ リア は 思い当たる もの が あっ た 。 血の気 の 多い 男 が たいへん そう だ な 。

鍵カッコが閉じていないなど、突っ込みどころはあるが、ひとまず自動的に文を生成することに成功したことにする。次はこれを音楽に応用させる。

実験2: 自動作曲してみた

次は曲の自動生成である。先にも書いたとおり、単語を音楽の要素に置き換えるだけで良い。ここで、「音楽の要素」とはmidiファイルのデータチャンクである。midiのデータ形式は以下のページが詳しい。

データチャンクは「デルタタイム」「イベントコード」「メタイベントコード」「データ長」「実データ」で構成されており、実際の演奏に関係するのは「デルタタイム」、「9もしくは8から始まるイベントコード」そして「実データ」のみである。今回は以下のコードで、midiデータを構造化して出力し、演奏に関係するデルタタイムとイベントコードと実データを抽出して適当な文字(アンダースコア等)で連結させた。たとえば、00_90_4080のようになる。これを1つの擬似的な単語とみなし、それをスペース区切りに並べたテキストを作った。例えば以下のようになる。

… 00_90_4080 00_90_5080 00_90_0000 …

これを作文のときと同じように学習データとして学習させた。また、ここでは問題を簡単にするため、1つのトラック(つまり1つの楽器)についてのみ学習の対象としている。元のmidiは、以下を使わせていただいた。

read_midi.py
def is_eq_0x2f(b):
  return int(b2a_hex(b), 16) == int('2f', 16)

def is_gte_0x80(b):
  return int(b2a_hex(b), 16) >= int('80', 16)

def is_eq_0xff(b):
  return int(b2a_hex(b), 16) == int('ff', 16)

def is_eq_0xf0(b):
  return int(b2a_hex(b), 16) == int('f0', 16)

def is_eq_0xf7(b):
  return int(b2a_hex(b), 16) == int('f7', 16)

def is_eq_0xcn(b):
  return int(b2a_hex(b), 16) >= int('c0', 16) and int(b2a_hex(b), 16) <= int('cf', 16)

def mutable_lengths_to_int(bs):
  length = 0
  for i, b in enumerate(bs):
    if is_gte_0x80(b):
      length += ( int(b2a_hex(b), 16) - int('80', 16) ) * pow(int('80', 16), len(bs) - i - 1)
    else:
      length += int(b2a_hex(b), 16)
  return length


"""
出力される形式
{
  'header': [77, 84, ...] # 14 byte
  'tracks': [
    {   # track No.1
      header: [77, 84, ...], # 8 byte
      chunks: [
        {
          delta:  [129, 134, 1],   # mutable
          status: [255],           # 1 byte
          meta:   [1],             # 1 byte
          length: [5],             # mutable
          body:   [1, 2, 3, 4, 5]  # mutable
        },{
          ...
        },{
          ...
        }
      ]
    },{ # track No.2
      ...
    },{ # track No.3
      ...
    }
  ]
}

"""
def read_midi(path_to_midi):
  midi = open(path_to_midi, 'rb')
  data = {'header': [], 'tracks': []}
  track = {'header': [], 'chunks': []}
  chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}
  current_status = None

  """
  Load data.header
  """
  bs = midi.read(14)
  data['header'] = [b for b in bs]

  while 1:
    """
    Load data.tracks[0].header
    """
    if len(track['header']) == 0:
      bs = midi.read(8)
      if bs == '':
        break
      track['header'] = [b for b in bs]

    """
    Load data.tracks[0].chunks[0]
    """
    # delta time
    # ----------
    b = midi.read(1)
    while 1:
      chunk['delta'].append(b)
      if is_gte_0x80(b):
        b = midi.read(1)
      else:
        break

    # status
    # ------
    b = midi.read(1)
    if is_gte_0x80(b):
      chunk['status'].append(b)
      current_status = b
    else:
      midi.seek(-1, os.SEEK_CUR)
      chunk['status'].append(current_status)

    # meta and length
    # ---------------
    if is_eq_0xff(current_status): # meta event
      b = midi.read(1)
      chunk['meta'].append(b)
      b = midi.read(1)
      while 1:
        chunk['length'].append(b)
        if is_gte_0x80(b):
          b = midi.read(1)
        else:
          break
      length = mutable_lengths_to_int(chunk['length'])
    elif is_eq_0xf0(current_status) or is_eq_0xf7(current_status): # sysex event
      b = midi.read(1)
      while 1:
        chunk['length'].append(b)
        if is_gte_0x80(b):
          b = midi.read(1)
        else:
          break
      length = mutable_lengths_to_int(chunk['length'])
    else: # midi event
      if is_eq_0xcn(current_status):
        length = 1
      else:
        length = 2

    # body
    # ----
    for i in range(0, length):
      b = midi.read(1)
      chunk['body'].append(b)

    track['chunks'].append(chunk)


    if is_eq_0xff(chunk['status'][0]) and is_eq_0x2f(chunk['meta'][0]):
      data['tracks'].append(track)
      track = {'header': [], 'chunks': []}
    chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}

  return data

実験2: 自動作曲の結果

十分学習させたあと、適当なチャンクを初期値として与え、データチャンクの系列を生成した。それをmidiの形式に再変換させmidiファイルを作成した。
ここをクリックすると聞ける。(※注意!音が出ます!)

心地良くもないが心地悪くもないといった結果になった。課題としては、単調(短調ではない)すぎるメロディになってしまうという点があげられる。これは自動作文にも言えることで、起承転結がない文がダラダラと続いてしまう傾向がある。この課題は、Stacked LSTMのような構造的なLSTMを用いることで解決できるかもしれない。
 また、改善点としては単一ではなく複数の楽器で学習させるなどもあげられる。

実験を終えて

とにかく、長くなってしまった。文書くのに5時間くらいかかっており、肝心のDLの勉強のスピードが出ないという本末転倒な事態に。しかし、第一発目なので自動作曲という面白みのあるテーマで、一応結果出力までできたので良しとしよう。
 今後はStacked LSTMで今回の自動作曲を強化するか、他のテーマ(画像分類)なんかを試してみようと画策してる。以上。

リンク

Deep Learning を徹底的に勉強してみる[DW 0日目]

komakomako
NTTデータ辞めました
https://www.mahirokazuko.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away