1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Synthesize Human Speech with WaveNet の tensorflow実装(モデル定義)

Posted at

前回のデータ前処理に続いてモデルの定義を行います。モデルの構造は以下の図の通りです。


詳細はここを。

##Resnet&UpsampleNet

class ResidualNet:
    def __init__(self, n_loop=2, layers_per_loop=10, k_size=2, res_channel=64, skip_channel=256, conditioning=True, name='ResNet'):
        self.dilations = [2**i for i in range(layers_per_loop)]*n_loop
        self.k_size = k_size
        self.res_channel = res_channel
        self.skip_channel = skip_channel
        self.conditioning = conditioning
        self.name = name
    
    def resblock(self, input, dilation, cond=None, name='ResBlock'):
        with tf.variable_scope(name):
            length = input.shape[1]
            x = tf.pad(input, [[0, 0],[dilation * (self.k_size-1), 0], [0, 0]])
            x = tf.layers.conv1d(x, self.res_channel*2, self.k_size, dilation_rate=dilation)
            x = x[:,:length,:]
            if self.conditioning and cond is not None:
                x += cond
            tanh_z, sig_z = tf.split(x, 2, 2)
            z = tf.tanh(tanh_z)*tf.sigmoid(sig_z)
            res = tf.layers.conv1d(z, self.res_channel, 1) + input
            skip_connection = tf.layers.conv1d(z, self.skip_channel, 1)
        return res, skip_connection
    
    def __call__(self, input, condition=None, activation=tf.nn.relu):
        with tf.variable_scope(self.name):
            x = input
            for idx, (r, c) in enumerate(zip(self.dilations, condition)):
                x, skip = self.resblock(x, r, c, name='ResBlock_%d'%(idx+1))
                if idx == 0:
                    skip_connection = skip
                else:
                    skip_connection += skip
            if activation:
                skip_connection = activation(skip_connection)
        return skip_connection

class UpsampleNet:
    def __init__(self, layers, out_channels, channels=[128, 128], scales=[16, 16], name='Upsample'):
        self.layers = layers
        self.out_channels = out_channels
        self.channels = channels
        self.scales = scales
        self.name = name
        assert len(self.channels) == len(self.scales)
    
    def upsampling(self, input):
        with tf.variable_scope(self.name):
            conditions = []
            with tf.variable_scope('Deconvolution'):
                x = tf.expand_dims(input, 1)
                for c, s in zip(self.channels, self.scales):
                    x = tf.layers.conv2d_transpose(x, c, (1, s), (1, s))
                    x = tf.nn.relu(x)
                x = tf.squeeze(x, 1)
            with tf.variable_scope('Encode_feature'):
                for _ in range(self.layers):
                    conditions.append(tf.layers.conv1d(x, self.out_channels, 1))
            return conditions
    
    def __call__(self, input):
        return self.upsampling(input)

基本的に図の通りに実装しただけなので詳細は割愛。
Upsampleはconditioningの入力であるスペクトログラムのデータ長が異なるため転置畳み込みを行っています。

コードはhttps://github.com/Sakai0127/Wavenetにあります。
 
##おまけ
https://sakai0127.github.io/Wavenet/
このモデルのグラフをtensorboardで表示したページのリンクです。Chromeでしか動作確認をしていないので他のブラウザだと表示されないかもしれませんが。(htmlのwidth, height指定に -webkit-fill-availableを使っているから?)

これを作っていたせいで力尽きたところはありますが。私の場合、各層でのinputとoutputはソースコードを読めばわかりますけどネットワーク全体がどうなっているのかはソースからだけだとイメージし辛いんですよね。幸いtensorboardという素晴らしいツールがあるので視覚的に理解しやすいと思いこれを作成しました。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?