# 0. ざっくりいうと
- Stacked LSTMをChainerで書いた
- それを使って自動作曲してみた
- こうなった → 再生 (注意!すぐに音声が流れます)
1. LSTMとは
以下を参照。
- Understanding LSTM Networks
- LSTMネットワークの概要
- わかるLSTM ~ 最近の動向と共に
- Recurrent Neural Networks
- 【深層学習】再帰ニューラルネットワークに関する良ページまとめ [DW 5日目]
2. Stacked LSTMとは
LSTMを多層に重ねたニューラルネット。多層にすることによって、各レイヤーで長い相関と短い相関を学習できると期待されている。
ちなみに、LSMTを縦横方向につなげて多次元化したGrid LSTMというネットワークもある。Wikipediaの文字予測タスクや中国語翻訳タスクで良い性能を出しているらしい。
3. Chainerのコード
下図のようなニューラルネットを作った。
入力層と出力層はOne-hotベクトルである。4つある中間層(オレンジ)がLSTM層である。
class rnn(Chain):
state = {}
def __init__(self, n_vocab, n_units):
print n_vocab, n_units
super(rnn, self).__init__(
l_embed = L.EmbedID(n_vocab, n_units),
l1_x = L.Linear(n_units, 4 * n_units),
l1_h = L.Linear(n_units, 4 * n_units),
l2_x = L.Linear(n_units, 4 * n_units),
l2_h = L.Linear(n_units, 4 * n_units),
l3_x=L.Linear(n_units, 4 * n_units),
l3_h=L.Linear(n_units, 4 * n_units),
l4_x=L.Linear(n_units, 4 * n_units),
l4_h=L.Linear(n_units, 4 * n_units),
l_umembed = L.Linear(n_units, n_vocab)
)
def forward(self, x, t, train=True, dropout_ratio=0.5):
h0 = self.l_embed(x)
c1, h1 = F.lstm(
self.state['c1'],
F.dropout( self.l1_x(h0), ratio=dropout_ratio, train=train ) + self.l1_h(self.state['h1'])
)
c2, h2 = F.lstm(
self.state['c2'],
F.dropout( self.l2_x(h1), ratio=dropout_ratio, train=train ) + self.l2_h(self.state['h2'])
)
c3, h3 = F.lstm(
self.state['c3'],
F.dropout( self.l3_x(h2), ratio=dropout_ratio, train=train ) + self.l3_h(self.state['h3'])
)
c4, h4 = F.lstm(
self.state['c4'],
F.dropout( self.l4_x(h3), ratio=dropout_ratio, train=train ) + self.l4_h(self.state['h4'])
)
y = self.l_umembed(h4)
self.state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2, 'c3': c3, 'h3': h3, 'c4': c4, 'h4': h4}
if train:
return F.softmax_cross_entropy(y, t), F.accuracy(y, t)
else:
return F.softmax(y), y.data
def initialize_state(self, n_units, batchsize=1, train=True):
for name in ('c1', 'h1', 'c2', 'h2', 'c3', 'h3', 'c4', 'h4'):
self.state[name] = Variable(np.zeros((batchsize, n_units), dtype=np.float32), volatile=not train)
4. 実験 (音楽の自動生成の性能向上を試してみた)
前回の記事(RNN + LSTMで自動作曲してみた [DW 1日目] )の自動作曲の性能を高めたい。前回は2層のLSTMsだったが、それを上述の4層のLSTMsに置き換えて再学習させてみた。
学習データ
前回と同じmidiデータを使った。ただし、midiはテキストデータの形式に直す必要がある。テキスト化するコードも以下に載せておいた。以下のコードで、python midi2text.py --midi foo.midi
とすれば、トラックを1つだけ抜き出してテキスト化することができる。
0_80_40_00 0_90_4f_64 0_90_43_64 120_80_4f_00・・・
のように、デルタタイム、ステータスバイト、ステータスバイトに続く1or2バイトを、アンダースコアで連結したもの(「チャンク」と呼ぶ)が半角スペース区切りで並ぶテキストデータが生成される。このテキストデータを学習データとして上述のLSTMで学習させた。
学習曲線
曲の生成
学習後、同じネットワークを用いてチャンクの系列を生成させた。生成したチャンク列を、後述のコードを用いてmidiファイル化した。
こうなった → 再生 (注意!すぐに音声が流れます)
感想
曲はなかなか良かった。
ただ、素になっているmidiの影響がかなり強い(残っている)と思われる。また、Stacked LSTMによって曲が構造化させることを期待していたが想像していたよりも単調だった。曲に動きをもたせるには、すべて自動で生成させるのではなく、系列の出力時にイレギュラーなノイズ(それまで続いていた系列を転移させるような音)を意図的にいれるなどしたほうが良い効果が得られると思う。
作曲という観点から見れば、音選びの他にも課題がある。音色の調整やエフェクトの付与、複数の楽器でのセッションなども課題だ。
コード
- midi→テキスト (
デルタタイム_ステータスバイト_ステータスバイトに続く実データ
の形式)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import struct
from binascii import *
from types import *
reload(sys)
sys.setdefaultencoding('utf-8')
def is_eq_0x2f(b):
return int(b2a_hex(b), 16) == int('2f', 16)
def is_gte_0x80(b):
return int(b2a_hex(b), 16) >= int('80', 16)
def is_eq_0xff(b):
return int(b2a_hex(b), 16) == int('ff', 16)
def is_eq_0xf0(b):
return int(b2a_hex(b), 16) == int('f0', 16)
def is_eq_0xf7(b):
return int(b2a_hex(b), 16) == int('f7', 16)
def is_eq_0x8n(b):
return int(b2a_hex(b), 16) >= int('80', 16) and int(b2a_hex(b), 16) <= int('8f', 16)
def is_eq_0x9n(b):
return int(b2a_hex(b), 16) >= int('90', 16) and int(b2a_hex(b), 16) <= int('9f', 16)
def is_eq_0xan(b): # An: 3byte
return int(b2a_hex(b), 16) >= int('a0', 16) and int(b2a_hex(b), 16) <= int('af', 16)
def is_eq_0xbn(b): # Bn: 3byte
return int(b2a_hex(b), 16) >= int('b0', 16) and int(b2a_hex(b), 16) <= int('bf', 16)
def is_eq_0xcn(b): # Cn: 2byte
return int(b2a_hex(b), 16) >= int('c0', 16) and int(b2a_hex(b), 16) <= int('cf', 16)
def is_eq_0xdn(b): # Dn: 2byte
return int(b2a_hex(b), 16) >= int('d0', 16) and int(b2a_hex(b), 16) <= int('df', 16)
def is_eq_0xen(b): # En: 3byte
return int(b2a_hex(b), 16) >= int('e0', 16) and int(b2a_hex(b), 16) <= int('ef', 16)
def is_eq_0xfn(b):
return int(b2a_hex(b), 16) >= int('f0', 16) and int(b2a_hex(b), 16) <= int('ff', 16)
def mutable_lengths_to_int(bs):
length = 0
for i, b in enumerate(bs):
if is_gte_0x80(b):
length += ( int(b2a_hex(b), 16) - int('80', 16) ) * pow(int('80', 16), len(bs) - i - 1)
else:
length += int(b2a_hex(b), 16)
return length
def int_to_mutable_lengths(length):
length = int(length)
bs = []
append_flag = False
for i in range(3, -1, -1):
a = length / pow(int('80', 16), i)
length -= a * pow(int('80', 16), i)
if a > 0:
append_flag = True
if append_flag:
if i > 0:
bs.append(hex(a + int('80', 16))[2:].zfill(2))
else:
bs.append(hex(a)[2:].zfill(2))
return bs if len(bs) > 0 else ['00']
def read_midi(path_to_midi):
midi = open(path_to_midi, 'rb')
data = {'header': [], 'tracks': []}
track = {'header': [], 'chunks': []}
chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}
current_status = None
"""
Load data.header
"""
bs = midi.read(14)
data['header'] = [b for b in bs]
while 1:
"""
Load data.tracks[0].header
"""
if len(track['header']) == 0:
bs = midi.read(8)
if bs == '':
break
track['header'] = [b for b in bs]
"""
Load data.tracks[0].chunks[0]
"""
# delta time
# ----------
b = midi.read(1)
while 1:
chunk['delta'].append(b)
if is_gte_0x80(b):
b = midi.read(1)
else:
break
# status
# ------
b = midi.read(1)
if is_gte_0x80(b):
chunk['status'].append(b)
current_status = b
else:
midi.seek(-1, os.SEEK_CUR)
chunk['status'].append(current_status)
# meta and length
# ---------------
if is_eq_0xff(current_status): # meta event
b = midi.read(1)
chunk['meta'].append(b)
b = midi.read(1)
while 1:
chunk['length'].append(b)
if is_gte_0x80(b):
b = midi.read(1)
else:
break
length = mutable_lengths_to_int(chunk['length'])
elif is_eq_0xf0(current_status) or is_eq_0xf7(current_status): # sysex event
b = midi.read(1)
while 1:
chunk['length'].append(b)
if is_gte_0x80(b):
b = midi.read(1)
else:
break
length = mutable_lengths_to_int(chunk['length'])
else: # midi event
if is_eq_0xcn(current_status) or is_eq_0xdn(current_status):
length = 1
else:
length = 2
# body
# ----
for i in range(0, length):
b = midi.read(1)
chunk['body'].append(b)
track['chunks'].append(chunk)
if is_eq_0xff(chunk['status'][0]) and is_eq_0x2f(chunk['meta'][0]):
data['tracks'].append(track)
track = {'header': [], 'chunks': []}
chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}
return data
def write_text(tracks):
midi = open('out.txt', 'w')
for track in tracks:
for chunks in track:
midi.write('{} '.format(chunks))
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='audio RNN')
parser.add_argument('--midi', type=unicode, default='', help='path to the MIDI file')
args = parser.parse_args()
data = read_midi(args.midi)
# extract midi track
track_list = [1] # ← 抜き出したいトラック番号
tracks = []
for n in track_list:
raw_data = []
chunks = data['tracks'][n]['chunks']
for i in range(0, len(chunks)):
chunk = chunks[i]
if is_eq_0xff(chunk['status'][0]) or \
is_eq_0xf0(chunk['status'][0]) or \
is_eq_0xf7(chunk['status'][0]) :
continue
raw_data.append('_'.join(
[str(mutable_lengths_to_int(chunk['delta']))] +
[str(b2a_hex(chunk['status'][0]))] +
[str(b2a_hex(body)) for body in chunk['body']]
))
tracks.append(raw_data)
write_text(tracks)
- テキスト(
デルタタイム_ステータスバイト_ステータスバイトに続く実データ
の形式) → midi
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import struct
from binascii import *
from types import *
reload(sys)
sys.setdefaultencoding('utf-8')
def int_to_mutable_lengths(length):
length = int(length)
bs = []
append_flag = False
for i in range(3, -1, -1):
a = length / pow(int('80', 16), i)
length -= a * pow(int('80', 16), i)
if a > 0:
append_flag = True
if append_flag:
if i > 0:
bs.append(hex(a + int('80', 16))[2:].zfill(2))
else:
bs.append(hex(a)[2:].zfill(2))
return bs if len(bs) > 0 else ['00']
def write_midi(tracks):
print len(tracks)
midi = open('out.midi', 'wb')
"""
MIDI Header
"""
header_bary = bytearray([])
header_bary.extend([0x4d, 0x54, 0x68, 0x64, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00])
header_bary.extend([int(hex(len(tracks))[2:].zfill(4)[i:i+2], 16) for i in range(0, 4, 2)])
header_bary.extend([0x01, 0xe0])
midi.write(header_bary)
for track in tracks:
track_bary = bytearray([])
for chunk in track:
# It is assumed that each chunk consists of just 4 elements
if len(chunk.split('_')) != 4:
continue
int_delta, status, data1, data2 = chunk.split('_')
if status[0] == '8' or status[0] == '9' or status[0] == 'a' or status[0] == 'b' or status[0] == 'e': # 3byte
delta = int_to_mutable_lengths(int_delta)
track_bary.extend([int(d, 16) for d in delta])
track_bary.extend([int(status, 16)])
track_bary.extend([int(data1, 16)])
track_bary.extend([int(data2, 16)])
elif status[0] == 'c' or status[0] == 'd':
delta = int_to_mutable_lengths(int_delta)
track_bary.extend([int(d, 16) for d in delta])
track_bary.extend([int(status, 16)])
track_bary.extend([int(data1, 16)])
else:
print status[0]
"""
Track header
"""
header_bary = bytearray([])
header_bary.extend([0x4d, 0x54, 0x72, 0x6b])
header_bary.extend([int(hex(len(track_bary)+4)[2:].zfill(8)[i:i+2], 16) for i in range(0, 8, 2)])
midi.write(header_bary)
"""
Track body
"""
print len(track_bary)
midi.write(track_bary)
"""
Track footer
"""
footer_bary = bytearray([])
footer_bary.extend([0x00, 0xff, 0x2f, 0x00])
midi.write(footer_bary)
if __name__ == '__main__':
# ↓ 「デルタタイム_ステータスバイト_ステータスバイトに続く実データ」の形式をスペース区切りで並べる
# ランニングステータスが含まれているとうまく動きません。。。
txt = '0_80_40_00 0_90_4f_64 0_90_43_64 120_80_4f_00 0_80_43_00 0_90_51_64 0_90_45_64 480_80_51_00 0_80_45_00 0_90_4c_64 0_90_44_64 120_80_4c_00 0_80_44_00 0_90_4f_64 0_90_43_64 60_80_4f_00 0_80_43_00 0_90_4d_64 0_90_41_64 120_80_4d_00'
tracks = [txt.split(' ')]
write_midi(tracks)