63
71

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

深層学習で自動作曲(Stacked LSTM編) [DW 6日目]

Last updated at Posted at 2016-05-05

# 0. ざっくりいうと

1. LSTMとは

以下を参照。

2. Stacked LSTMとは

 LSTMを多層に重ねたニューラルネット。多層にすることによって、各レイヤーで長い相関と短い相関を学習できると期待されている。
 ちなみに、LSMTを縦横方向につなげて多次元化したGrid LSTMというネットワークもある。Wikipediaの文字予測タスクや中国語翻訳タスクで良い性能を出しているらしい。

3. Chainerのコード

下図のようなニューラルネットを作った。

入力層と出力層はOne-hotベクトルである。4つある中間層(オレンジ)がLSTM層である。

rnn.py
class rnn(Chain):
  state = {}

  def __init__(self, n_vocab, n_units):
    print n_vocab, n_units
    super(rnn, self).__init__(
      l_embed = L.EmbedID(n_vocab, n_units),
      l1_x = L.Linear(n_units, 4 * n_units),
      l1_h = L.Linear(n_units, 4 * n_units),
      l2_x = L.Linear(n_units, 4 * n_units),
      l2_h = L.Linear(n_units, 4 * n_units),
      l3_x=L.Linear(n_units, 4 * n_units),
      l3_h=L.Linear(n_units, 4 * n_units),
      l4_x=L.Linear(n_units, 4 * n_units),
      l4_h=L.Linear(n_units, 4 * n_units),
      l_umembed = L.Linear(n_units, n_vocab)
    )

  def forward(self, x, t, train=True, dropout_ratio=0.5):
    h0 = self.l_embed(x)
    c1, h1 = F.lstm(
      self.state['c1'],
      F.dropout( self.l1_x(h0), ratio=dropout_ratio, train=train ) + self.l1_h(self.state['h1'])
    )
    c2, h2 = F.lstm(
      self.state['c2'],
      F.dropout( self.l2_x(h1), ratio=dropout_ratio, train=train ) + self.l2_h(self.state['h2'])
    )
    c3, h3 = F.lstm(
      self.state['c3'],
      F.dropout( self.l3_x(h2), ratio=dropout_ratio, train=train ) + self.l3_h(self.state['h3'])
    )
    c4, h4 = F.lstm(
      self.state['c4'],
      F.dropout( self.l4_x(h3), ratio=dropout_ratio, train=train ) + self.l4_h(self.state['h4'])
    )
    y = self.l_umembed(h4)
    self.state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2, 'c3': c3, 'h3': h3, 'c4': c4, 'h4': h4}
    if train:
      return F.softmax_cross_entropy(y, t), F.accuracy(y, t)
    else:
      return F.softmax(y), y.data

  def initialize_state(self, n_units, batchsize=1, train=True):
    for name in ('c1', 'h1', 'c2', 'h2', 'c3', 'h3', 'c4', 'h4'):
      self.state[name] = Variable(np.zeros((batchsize, n_units), dtype=np.float32), volatile=not train)

4. 実験 (音楽の自動生成の性能向上を試してみた)

 前回の記事(RNN + LSTMで自動作曲してみた [DW 1日目] )の自動作曲の性能を高めたい。前回は2層のLSTMsだったが、それを上述の4層のLSTMsに置き換えて再学習させてみた。

学習データ

 前回と同じmidiデータを使った。ただし、midiはテキストデータの形式に直す必要がある。テキスト化するコードも以下に載せておいた。以下のコードで、python midi2text.py --midi foo.midiとすれば、トラックを1つだけ抜き出してテキスト化することができる。
 0_80_40_00 0_90_4f_64 0_90_43_64 120_80_4f_00・・・のように、デルタタイム、ステータスバイト、ステータスバイトに続く1or2バイトを、アンダースコアで連結したもの(「チャンク」と呼ぶ)が半角スペース区切りで並ぶテキストデータが生成される。このテキストデータを学習データとして上述のLSTMで学習させた。

学習曲線

曲の生成

 学習後、同じネットワークを用いてチャンクの系列を生成させた。生成したチャンク列を、後述のコードを用いてmidiファイル化した。

こうなった → 再生 (注意!すぐに音声が流れます)

感想

 曲はなかなか良かった。
 ただ、素になっているmidiの影響がかなり強い(残っている)と思われる。また、Stacked LSTMによって曲が構造化させることを期待していたが想像していたよりも単調だった。曲に動きをもたせるには、すべて自動で生成させるのではなく、系列の出力時にイレギュラーなノイズ(それまで続いていた系列を転移させるような音)を意図的にいれるなどしたほうが良い効果が得られると思う。
 作曲という観点から見れば、音選びの他にも課題がある。音色の調整やエフェクトの付与、複数の楽器でのセッションなども課題だ。

コード

  • midi→テキスト (デルタタイム_ステータスバイト_ステータスバイトに続く実データの形式)
midi2text.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import struct
from binascii import *
from types import *
reload(sys)
sys.setdefaultencoding('utf-8')

def is_eq_0x2f(b):
  return int(b2a_hex(b), 16) == int('2f', 16)

def is_gte_0x80(b):
  return int(b2a_hex(b), 16) >= int('80', 16)

def is_eq_0xff(b):
  return int(b2a_hex(b), 16) == int('ff', 16)

def is_eq_0xf0(b):
  return int(b2a_hex(b), 16) == int('f0', 16)

def is_eq_0xf7(b):
  return int(b2a_hex(b), 16) == int('f7', 16)

def is_eq_0x8n(b):
  return int(b2a_hex(b), 16) >= int('80', 16) and int(b2a_hex(b), 16) <= int('8f', 16)

def is_eq_0x9n(b):
  return int(b2a_hex(b), 16) >= int('90', 16) and int(b2a_hex(b), 16) <= int('9f', 16)

def is_eq_0xan(b): # An: 3byte
  return int(b2a_hex(b), 16) >= int('a0', 16) and int(b2a_hex(b), 16) <= int('af', 16)

def is_eq_0xbn(b): # Bn: 3byte
  return int(b2a_hex(b), 16) >= int('b0', 16) and int(b2a_hex(b), 16) <= int('bf', 16)

def is_eq_0xcn(b): # Cn: 2byte
  return int(b2a_hex(b), 16) >= int('c0', 16) and int(b2a_hex(b), 16) <= int('cf', 16)

def is_eq_0xdn(b): # Dn: 2byte
  return int(b2a_hex(b), 16) >= int('d0', 16) and int(b2a_hex(b), 16) <= int('df', 16)

def is_eq_0xen(b): # En: 3byte
  return int(b2a_hex(b), 16) >= int('e0', 16) and int(b2a_hex(b), 16) <= int('ef', 16)

def is_eq_0xfn(b):
  return int(b2a_hex(b), 16) >= int('f0', 16) and int(b2a_hex(b), 16) <= int('ff', 16)

def mutable_lengths_to_int(bs):
  length = 0
  for i, b in enumerate(bs):
    if is_gte_0x80(b):
      length += ( int(b2a_hex(b), 16) - int('80', 16) ) * pow(int('80', 16), len(bs) - i - 1)
    else:
      length += int(b2a_hex(b), 16)
  return length

def int_to_mutable_lengths(length):
  length = int(length)
  bs = []
  append_flag = False
  for i in range(3, -1, -1):
    a = length / pow(int('80', 16), i)
    length -= a * pow(int('80', 16), i)
    if a > 0:
      append_flag = True
    if append_flag:
      if i > 0:
        bs.append(hex(a + int('80', 16))[2:].zfill(2))
      else:
        bs.append(hex(a)[2:].zfill(2))
  return bs if len(bs) > 0 else ['00']

def read_midi(path_to_midi):
  midi = open(path_to_midi, 'rb')
  data = {'header': [], 'tracks': []}
  track = {'header': [], 'chunks': []}
  chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}
  current_status = None

  """
  Load data.header
  """
  bs = midi.read(14)
  data['header'] = [b for b in bs]

  while 1:
    """
    Load data.tracks[0].header
    """
    if len(track['header']) == 0:
      bs = midi.read(8)
      if bs == '':
        break
      track['header'] = [b for b in bs]

    """
    Load data.tracks[0].chunks[0]
    """
    # delta time
    # ----------
    b = midi.read(1)
    while 1:
      chunk['delta'].append(b)
      if is_gte_0x80(b):
        b = midi.read(1)
      else:
        break

    # status
    # ------
    b = midi.read(1)
    if is_gte_0x80(b):
      chunk['status'].append(b)
      current_status = b
    else:
      midi.seek(-1, os.SEEK_CUR)
      chunk['status'].append(current_status)

    # meta and length
    # ---------------
    if is_eq_0xff(current_status): # meta event
      b = midi.read(1)
      chunk['meta'].append(b)
      b = midi.read(1)
      while 1:
        chunk['length'].append(b)
        if is_gte_0x80(b):
          b = midi.read(1)
        else:
          break
      length = mutable_lengths_to_int(chunk['length'])
    elif is_eq_0xf0(current_status) or is_eq_0xf7(current_status): # sysex event
      b = midi.read(1)
      while 1:
        chunk['length'].append(b)
        if is_gte_0x80(b):
          b = midi.read(1)
        else:
          break
      length = mutable_lengths_to_int(chunk['length'])
    else: # midi event
      if is_eq_0xcn(current_status) or is_eq_0xdn(current_status):
        length = 1
      else:
        length = 2

    # body
    # ----
    for i in range(0, length):
      b = midi.read(1)
      chunk['body'].append(b)

    track['chunks'].append(chunk)


    if is_eq_0xff(chunk['status'][0]) and is_eq_0x2f(chunk['meta'][0]):
      data['tracks'].append(track)
      track = {'header': [], 'chunks': []}
    chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}

  return data

def write_text(tracks):
  midi = open('out.txt', 'w')
  for track in tracks:
    for chunks in track:
      midi.write('{} '.format(chunks))

if __name__ == '__main__':
  from argparse import ArgumentParser
  parser = ArgumentParser(description='audio RNN')
  parser.add_argument('--midi', type=unicode, default='', help='path to the MIDI file')
  args = parser.parse_args()
  
  data = read_midi(args.midi)

  # extract midi track
  track_list = [1] # ← 抜き出したいトラック番号

  tracks = []
  for n in track_list:
    raw_data = []
    chunks = data['tracks'][n]['chunks']
    for i in range(0, len(chunks)):
      chunk = chunks[i]
      if is_eq_0xff(chunk['status'][0]) or \
         is_eq_0xf0(chunk['status'][0]) or \
         is_eq_0xf7(chunk['status'][0]) :
        continue
      raw_data.append('_'.join(
        [str(mutable_lengths_to_int(chunk['delta']))] +
        [str(b2a_hex(chunk['status'][0]))] +
        [str(b2a_hex(body)) for body in chunk['body']]
      ))
    tracks.append(raw_data)

  write_text(tracks)
  • テキスト(デルタタイム_ステータスバイト_ステータスバイトに続く実データの形式) → midi
text2midi
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import struct
from binascii import *
from types import *
reload(sys)
sys.setdefaultencoding('utf-8')

def int_to_mutable_lengths(length):
  length = int(length)
  bs = []
  append_flag = False
  for i in range(3, -1, -1):
    a = length / pow(int('80', 16), i)
    length -= a * pow(int('80', 16), i)
    if a > 0:
      append_flag = True
    if append_flag:
      if i > 0:
        bs.append(hex(a + int('80', 16))[2:].zfill(2))
      else:
        bs.append(hex(a)[2:].zfill(2))
  return bs if len(bs) > 0 else ['00']

def write_midi(tracks):
  print len(tracks)
  midi = open('out.midi', 'wb')

  """
  MIDI Header
  """
  header_bary = bytearray([])
  header_bary.extend([0x4d, 0x54, 0x68, 0x64, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00])
  header_bary.extend([int(hex(len(tracks))[2:].zfill(4)[i:i+2], 16) for i in range(0, 4, 2)])
  header_bary.extend([0x01, 0xe0])
  midi.write(header_bary)

  for track in tracks:
    track_bary = bytearray([])
    for chunk in track:
      # It is assumed that each chunk consists of just 4 elements
      if len(chunk.split('_')) != 4:
        continue
      int_delta, status, data1, data2 = chunk.split('_')

      if status[0] == '8' or status[0] == '9' or status[0] == 'a' or status[0] == 'b' or status[0] == 'e': # 3byte
        delta = int_to_mutable_lengths(int_delta)
        track_bary.extend([int(d, 16) for d in delta])
        track_bary.extend([int(status, 16)]) 
        track_bary.extend([int(data1, 16)])  
        track_bary.extend([int(data2, 16)])  
      elif status[0] == 'c' or status[0] == 'd':
        delta = int_to_mutable_lengths(int_delta)
        track_bary.extend([int(d, 16) for d in delta])
        track_bary.extend([int(status, 16)]) 
        track_bary.extend([int(data1, 16)])  
      else:
        print status[0]

    """
    Track header
    """
    header_bary = bytearray([])
    header_bary.extend([0x4d, 0x54, 0x72, 0x6b])
    header_bary.extend([int(hex(len(track_bary)+4)[2:].zfill(8)[i:i+2], 16) for i in range(0, 8, 2)])
    midi.write(header_bary)

    """
    Track body
    """
    print len(track_bary)
    midi.write(track_bary)

    """
    Track footer
    """
    footer_bary = bytearray([])
    footer_bary.extend([0x00, 0xff, 0x2f, 0x00])
    midi.write(footer_bary)

if __name__ == '__main__':

  # ↓ 「デルタタイム_ステータスバイト_ステータスバイトに続く実データ」の形式をスペース区切りで並べる
  # ランニングステータスが含まれているとうまく動きません。。。
  txt = '0_80_40_00 0_90_4f_64 0_90_43_64 120_80_4f_00 0_80_43_00 0_90_51_64 0_90_45_64 480_80_51_00 0_80_45_00 0_90_4c_64 0_90_44_64 120_80_4c_00 0_80_44_00 0_90_4f_64 0_90_43_64 60_80_4f_00 0_80_43_00 0_90_4d_64 0_90_41_64 120_80_4d_00'
  tracks = [txt.split(' ')]
  write_midi(tracks)

リンク

Deep Learning を徹底的に勉強してみる [DW 0日目]

63
71
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
63
71

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?