Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
47
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

updated at

闇のkerasに対する防衛術

他のDLライブラリを勉強するのが面倒という理由でkerasで実装されていないネットワークを組んではいけないというアンチパターンの話。つまり闇のkerasに対する防衛術の話です。

kerasでゴリゴリ学習コードを書いてはいけない

kerasはLSTMが数行で書けたり、vggなどの有名なモデルが揃っている、便利なラッパーとなっていますが、kerasで実装されていない論文のコードを書くことは極力避けましょう

keras以外に慣れていなくてもです。 とっととtensorflowかpytorchを勉強してください。

理由

通常のフィードフォワードな分類or回帰のネットワーク(全結合、CNN、autoencoderなど)や既にラッパーが用意されているLSTMは瞬時に実装できますし、パラメータチューニングも簡単なので是非kerasを使いましょう

ただし以下のような場合は絶対にkerasを使わないでください。闇の技術に祟られます。

  1. 特殊な入力と複数の損失関数から最適化をするもの(ex.GAN、マルチタスクラーニング全般)
  2. adversal dataを作るなど、通常のフィードフォワードネットワークでは行わない処理をするもの(ex.VAT
  3. その他kerasでラッパーが用意されていない層のあるネットワークを組むこと(ライブラリ開発者は例外)

全て語ると相当長くなるので、今回は1と2に関してGANの学習にフォーカスして説明します。

ネットワークの構築がしんどい

model_g = generator()
model_d = discrimeter()

# dに対するfakedataの学習
model_d.trainable = False
model_d_g = Sequential()
model_d_g.add(model_g)
model_d_g.add(model_d)
model_d_g.compile(...) #パラメータは省略
model_d.trainable = True

さらにfeature matchingという手法があって(こちらがわかりやすい説明でした)、それを使う場合は中間層の出力に関するラッパーがないので

def discrimeter():
    ...
    y = ...
    feature_matching_layer = Dense(4096, name='feature_matching')(y)
    y = Dense(4096)(feature_matching_layer)
    ...
    return y, feature_matching_layer

model_g = generator()
model_d, feature_matching_layer = discrimeter()

# dに対するfakedataの学習
# 同上なので省略

# `model_d_g` と `model_d` と同様のものを `feature_matching_layer` でも作成
model_d.trainable = False
model_d_fm = feature_matching_layer
model_d_g_fm = Sequential()
model_d_g_fm.add(model_g)
model_d_g_fm.add(feature_matching_layer)
model_d_g.compile(...) #パラメータは省略
model_d.trainable = True

# feature matching
feature_matching_predict = feature_matching_layer.predict(z)

とまあ分岐する度に compile を書かないといけないです。つまり、油断したらdiscrimetorからgeneratorへの逆伝播を何度も書かないといけないうえにコードがさらに冗長になります。だるくないですか?

また、trainable パラメータを一時的に変更する必要がありますが。こいつを何かの間違えで消した時は、「あれ?学習がうまくいかないなあ」と、手違いで消したことになかなか気づけず、 レガシーの闇に飲まれるリスクがあります。

さらに、kerasの場合ネットワークが分岐したらoptimizerを分けないといけないので、 論文に対して学習の再現性が取れません 。幸いにも微分式的には対して変わりませんが、adamのパラメータが共有できないのは注意点です。気にならない人もいるかもしれませんが。

以上より冗長になり難読化してしまうので、kerasは避けましょう。設計の練度が高い人でもpytorchやchainerなどを使ってください。いい感じに抽象化されていますので。

trainのコードがだるい

kerasの場合、基本的には学習の実行は以下のように一行で書けるはずです

model.fit(X, y, batch_size=64, epochs=300, validation_split=0.33)

バリデーションデータの分割もパラメータ指定するだけで勝手にやってくれるのはとてもありがたいですね。

tensorboardを使いたい人などはコールバック関数の設定もしないといけないですが、さほど複雑ではないかとは思います。

さて、GANを書いた場合は学習の処理が特殊になるので以下のようになります。

for X in batch_generator(train):
    model_d.train_on_batch(X, [1] * batch_size)
    z = np.random.uniform(-1, 1, size=(batch_size, 100))
    generated = model.predict(z)
    model_d.train_on_batch(generated, [0] * len(X))
    model_d_g.train_on_batch(z, [1] * batch_size)

モデルごとに1バッチごとの学習を行うメソッドがありますが、従来のkerasの綺麗なログで学習状況を眺められませんし、tensorboard対応をするにも、自分で頑張って書く必要があります。当然kerasに対応させたがる物好きは少ないので、交通整備されていないリサーチの闇に飲み込まれます。

私も実際に実装しようとしましたが時間がかかりました。

まとめ

kerasは基本的な学習を抽象度低めに提供しています。通常の分類問題は秒で書けてとても頼もしいですが、実装されていないネットワークを組むのは非常に時間がかかるうえに助けてくれる人が非常に少ないです。(やりたがる物好きが少ないから)

ネットワークを自前で実装する時は、抽象度がネットワークの実装に適してあるtensorflowやpytorch、chainerを使いましょう。決してこれらを覚えるのがだるいという理由で、kerasを使っていはいけません。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
47
Help us understand the problem. What are the problem?