LoginSignup
32

More than 1 year has passed since last update.

TCN (Temporal Convolutional Network) とは?

Last updated at Posted at 2020-07-10

前提

一般的なCNNの知識は既知とします、padding, kernel_size, stride, channelなどのCNN単語はいきなり出てきますが詳しくはCNNで調べてください。

概要

時系列データに対して汎用的かつシンプルで強力なCNN

主にこの論文で言及され始めたネットワーク (正確にはこれよりも前にこの名前は使われているらしいが)で、
TCN : Temporal Convolutional Networkという名前の通り、系列データに対してCNNを用いたネットワークです。
この論文は時系列データのタスクに対して**LSTMやGRUを脳死で使うのではなくCNNを使うことも考えてみたら?**というabstractから始まっていて、TCNが自然言語や音楽などの時系列データに対するタスクでLSTMなどのRNNよりも精度がよくなるという結果が述べられています!
このTCNは、この論文で試されたいくつかのタスクに対して、そのアーキテクチャを変えることなく数々のタスクでより良い精度を示していることから、見出しの通り汎用的でパワフルなネットワークと言えましょう!

アーキテクチャ

実際のアーキテクチャ自体は完全に新しいものというわけではなく、過去のCNNのいいところを、シンプルなアーキテクチャに蒸留するのが目的だ!と書かれてあります。
要するに過去のCNNのいいところを集めて汎用的にしたモデルと言えます。

それでは実際に見ていきましょう!
image.png

音声合成の分野で流行ったWaveNetを知っているひとであればTCNとそっくりですね!
TCNの一層を表しているのが、上の図の(b)でこれを4層積み上げたのが上の図の(a)です。
図(a)の各層はただのCNNの一層のようにしか見えませんが、実際は一つ一つの層が図(b)のようになっていて図(a)は図(b)のブロックを4個直列につなげた形です。上でCNNのいいところを集めたと書いたので実際にどのような利点を継承しているのか見ていきます。

1. dilated convolutional network

図(a)を見ると、畳み込む対象の要素が上に行くにつれてまばらになっていっているのがわかると思います。
このようにdilated convolutional networkは隣接した要素を畳み込むのではなく、間をあけて畳み込んだCNNのことをいいます。
そしてさらにWaveNetやTCNでは図(a)のように層を重ねるにつれてその畳み込む隙間が大きくなっていっています。
実際にはi(i=0,1,2,3....)番目の層では2のi乗個の要素ごとに畳み込まれているのが見て取れます。こうすることでLSTMのような長期的な視点を畳み込むことができるようになります。ちなみにdilatedとは「広げられた」、「膨張された」などを意味します。

実装的にはpytorchであればnn.Conv1dなどにdilationというパラメータがあるので、例えばそこに4と入力すると4要素にひとつずつ畳み込まれていきます。デフォルトは1のようですね。


self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, 
                           dilation=1, groups=1, bias=True, padding_mode='zeros')

2. causal convolutonal network

畳み込む対象の要素が時系列的に今見ている要素以前のものであるCNNです。普通CNNは対象が1次元データなら見ている対象の左右両方を畳み込みますが、今回扱っているのは時系列データなので見ているデータより左、つまり時間的に前のもののみを畳み込んでいます。これをcausal convolutonal networkといいます。与えられた入力から逐次的に(ストリーム的に)出力を予測する場合はこの構造は必須ですが、最後までのデータが一度にすべて与えられている場合は無理にこのcausalにする必要はなく、non-causalつまり時間的に前後の要素を畳み込んでもいいでしょう。実際に図(a)では青い線が右下には伸びておらず、真下と左下にのみ伸びています。

実装的にはnon-causalは以下で簡単にできます。

non-causal
self.conv = torch.nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    padding=(kernel_size - 1) // 2 * dilation_size,
    dilation=dilation_size,
    groups=1,
    bias=True,
    padding_mode='zeros',
)

causalに関しては一度多めにpaddingをとりその後余分な箇所を削る形で実装します。

causal
padding_size = (kernel_size - 1) * dilation_size
self.conv = torch.nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    padding=padding_size,
    dilation=dilation_size,
    groups=1,
    bias=True,
    padding_mode='zeros',
)
self.chomp1d = Chomp2d(padding_size)


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

ここでdilation_sizeというのが1で説明したdilated convolutonal networkにおけるdilationの値です。

3. residual convolutonal network

いわゆる、スキップコネクション、残余ネットがあるCNNで、上の図(b)でいう1x1 convだけを通っているパスがあるということです。
もとのResNetはこの例のように1x1 convすら通さないですが、これも広義のスキップコネクションといえると思います。
これがあることで勾配消失を防げるみたいな話がありますね。ただ実際の学習はこのスキップコネクションが多く適用されたパスに対して主に行われていて、あまり層を深くしても意味がなく勾配は関係ないのかな?とも思ったりします。

実装

なんと元の論文で実装したgitのリポジトリのurlを紹介してくれているのでここにそれを張ります。
やはり実装を見るのが一番腑に落ちますね。

以下は参考にしたWaveNetの実装です。

まとめ

僕は今回、音楽情報処理の論文でこのTCNに出会ったのですが、この論文の主旨の通り、時系列データに対してCNNを使いたいとなったときにこのTCNをベースにモデルを考えれば比較的いい精度が得られるのではないでしょうか!ちなみにLSTMとかに比べると驚くほど学習が早く済みます。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32