LoginSignup
9
14

More than 3 years have passed since last update.

Tensorflow 2.x で出来ること・出来ないこと・お作法

Last updated at Posted at 2020-01-13

Tensorflow 2.x なんもわからん

TF 2.x で手法実装を行なった結果、炎上している人です。TF 2.x が出てきてから一年位経ちましたが、現在までに TF 2.x で出来たこと、出来てないこと、また最近見えてきたお作法についてまとめます。

これ出来るよ?とかこのお作法間違ってない?みたいなのがあったらご指摘なり修正なり下さい。今後この クソ フレームワークで苦しむ人が救われます。

Tensorflow 2.x で出来ること

PyTorch、Keras っぽい書き方

Eager Execution が出来るようになったので、 tf.session は廃止され、インタプリタで tf の関数を実行できるようになりました。これによってデバッグが容易になった他、実行が簡単になったそうです。

また keras と統合する流れのために、 tf.keras.layer を用いた実装 (PyTorch でいう nn.Module ?)が推し進められるようになりました。その一方で既存の関数的な実装は tf.nn 以下に集約されるようになり、例えば BatchNormalization は tf.nn.batch_normalization といった風になっています。

TPUの利用

TPU の利用がより便利になっているようです。このあたりの記事 がすごく読みやすいのでぜひご一読下さい。

Tensorflow 2.x で出来ないこと

BatchNormalization

Batch Normalization 出来ないなんてそんなわけwwwと思うかもしれませんが、案外このレイヤーは穴が多くて、ちゃんとした実装を探すのが困難です(というかちゃんとした実装まだないんじゃないかなぁ…)

特に何が問題かというと、 複数GPU / TPU での演算に対応できない 、という問題でして、こいつは PyTorchMxNet なんかでは実装されているんですが、TF 2.x はまだ実装できていません。MultiGPU/TPU を horovod に頼っていた時期 (TF 1.x 系)はそっちの方のサポートから実現していましたが、変に TF 内部で実装しようとした結果 後回しになっています

この問題は TF 2.x 系の界隈全体で頭を悩ませているらしく、PyTorch へ脱出する研究者が現れるほど 厳しい問題になっています。

現在開いている最新の Issue は多分ここ なので奮ってご参加下さい(自分も抉っています)。

  • 2020/2/3 一応 レイヤー自体は追加されました。しかしテストがされているのか怪しいのであまり期待しないほうが良いです。

  • 2020/03/17
    Tensorflow 2.2.0 の目玉機能として、SyncBatchNormalization Layerが追加されました!!!

Multi GPU/TPU の演算 (WIP)

 一応 TF 2.x では複数 GPU / TPU で演算できるよ!と言ってはいますが、その実情は 公式 にあるように Experimental support だらけです。なのでコロコロ関数が変わる可能性が高く、後方互換性を期待することが難しいです。

変数の初期バッチデータによる初期化

 深層学習では稀に用いられる初期バッチデータでの初期化ですが、こちらも観測している限り出来ません。やる方法としては、外部から初期バッチに応じて直接レイヤーの重みを書き換える、というような手法で、これには初期化を行う外部関数の設計と、訓練時にその外部関数を呼び出す仕組みを設計する必要があるため、通常の tf.keras.Modeltf.keras.Sequential を用いることが難しくなります (少なくとも fit() が使えないので custom training loop で訓練する必要があります)。()

  • 2020/03/17 できないので、自分でやる方法を作りました。 但し訓練パラメータに無理やり初期化に関する変数を組み込んでいるので、モデルサイズを正確に計算できなくなる、というバグを抱えています。
class IdentityWithInit:
   def build(self, input_shape: tf.TensorShape):
        self.initialized = self.add_weight(
            name="initialized",
            dytpe=tf.bool,
            trainable=False 
        )
        self.initialized.assign(False)
        self.built = True
   def initialize_parameter(self, x: tf.Tensor):
        tf.print("initialized {}".format(self.name)) 
        pass

   def __init__(self):
       super()__init__()

   def call(self, x:tf.Tensor):
       if not self.initialized:
           self.initialize_parameter(x)
           self.initialized.assign(True)
       return x

勾配計算 (2.0.0)

2.0.0では、特殊なケースで勾配計算が 事故るようです。(正気か?と思いましたがどうやらガチらしいです)

しかし 2.1.0 では修正されました。(手元でもテストしてみましたが修正されていました。)

スラッシュの利用

稀に Python ですごく長い文を書いた時に用いる "\" ですが、Tensorflow は Python の構文解析木を使っていない(?)らしいので、これを用いることは出来ません。https://github.com/tensorflow/tensorflow/issues/35765

駄目な例
variable * decay * \ 
 lr
良い例
(variable * decay * 
lr)

Python 3.8 の利用

Python 3.8 は 2020/02 の中旬からサポートが始まるようです。(ソースビルドをすることで試すことはできるようです)

Tensorflow 2.x のお作法

Task クラスの作成

これは最近気づいたことなのですべてを把握はしていないんですが、どうやら公式のいくつかの実装を見る限り、モデルのクラスとは別に、訓練や推論を行うクラスを別に作るのがおすすめらしいです。

基本的に実装する関数は、訓練、テストである train,test(eval) そして xxx_step となっており、 xxx_step に関しては1バッチを処理する関数で tf.function デコレータで囲っておくパターンがよく見られます。

class MyTask:
  def __init__(self, args):
     ...
     self.loss = tf.metrics.Mean(name='loss', dtype=tf.float32)

  def train(self):
    @tf.function
    def train_step(x: tf.Tensor, y: tf.Tensor):
      ...
      _y = self.model(x)
      loss = loss_fn(_y, y)
      self.loss(loss)

    for epoch in range(self.epochs):
      for x, y in tqdm(self.train_dataset):
        train_step(x, y)
      for x, y in self.val_dataset:
        val_step(x, y)
      print('EPOCH {} train: loss {} / val: loss {}'.format(epoch + 1, 
         self.loss.result(), self.val_loss.result()))
      self.loss.reset_states()
      self.val_loss.reset_states()    

  def test(self):
    @tf.function
    def test_step(x: tf.Tensor, y: tf.Tensor):
      ...

custom Layer の **kwargs

Tensorflow のクラスを書いていてなんだこの引数?と思っていたんですが、こいつは思った以上に重要な引数で、例えば BatchNormalization 層や Dropout 層の推論・訓練を切り替える場合に非常に重宝します。

例えば

class CustomLayer(Layer):
  def __init__(self):
    self.conv = Conv2D(...)
    self.bn = BatchNormalization(...)
    super().__init__()

  def build(self, input_shape):
    super().build(input_shape)

  def call(self, x: tf.Tensor, **kwargs):
    y = self.conv(x, **kwargs)
    y = self.bn(y, **kwargs)

cl = CustomLayer()

とした時に cl(x, training=False) とすることで BatchNormalization を推論モードで実行することが出来ます。

9
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
14