1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[TensorFlow/Keras] Keras 3時代(TF2.16~)のカスタムモデルの書き方を調べてみた

Last updated at Posted at 2024-11-04

はじめに

TensorFlow 2系では、モデルをKerasという別のライブラリ(ただしKeras 2時点では事実上TensorFlowと一体化)を使って記述することが標準的な方法となっています。
このKeras、TensorFlowのラッパーのような立ち位置だったのですが 1、2023年にリリースされたKeras 3ではTensorFlow以外のバックエンドを選べるようになりました。TensorFlow 2.16以降はKeras 3で動作するようになり、それに伴って、カスタムモデルの書き方も少し変わってきています。このあたりの調整にコツが必要だったのでメモします。

ちなみに、Keras 3を使うとJAXというバックエンドを選べるようになり、学習を高速化できるらしいので、最後に高速化の検証もやってみます。

検証環境

  • Ubuntu 22.04.5 LTS
  • Python 3.10.12
  • TensorFlow: 以下の2バージョンで検証
    • 2.14.1(Keras 2.14.0)
    • 2.16.2(Keras 3.6.0)

基本のサンプルコード

簡単のため、MNISTの学習で試してみましょう。
以下のように、Keras標準のレイヤーと Sequential を使って記述したモデルは、基本的にTensorFlow 2.15まで(Keras 2)でも2.16以降(Keras 3)でも動作します。
入力形状の与え方としては、最初に Input レイヤーを追加する方法と、入力側から見た最初のレイヤーに input_shape 引数を与える方法の2種類がありますが、Keras 3では後者は非推奨となり、Warningが発生します。

mnist-simple.py
import keras
from keras.layers import Input, Flatten, Dense

# MNISTデータセットのロード
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# データの前処理
x_train, x_test = x_train / 255.0, x_test / 255.0  # 0-255の値を0-1に正規化

# モデルの作成
model = keras.models.Sequential([
    Input((28, 28)),
    Flatten(),
    # Flatten(input_shape=(28, 28)), # input_shapeを与える方法はKeras 3では非推奨
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# モデルの学習
model.fit(x_train, y_train, epochs=5)

Lambda レイヤーで処理をユーザ定義する場合

Lambda レイヤーを使って自前で定義した処理がある場合、少し話が変わります。
今回はわざわざやる意味はありませんが、活性化関数のReLUを分離して以下のように Lambda で書いたとしましょう。

mnist-lambda.py
import tensorflow as tf
import keras
from keras.layers import Input, Flatten, Dense, Lambda

# MNISTデータセットのロード
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# データの前処理
x_train, x_test = x_train / 255.0, x_test / 255.0  # 0-255の値を0-1に正規化

# モデルの作成
model = keras.models.Sequential([
    Input((28, 28)),
    Flatten(),
    Dense(128),
    Lambda(lambda x: tf.maximum(x, 0)),
    Dense(10, activation='softmax')
])

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# モデルの学習
model.fit(x_train, y_train, epochs=5)

このモデル、一見問題なく動作するように見えるのですが、TensorFlow 2.16以降(Keras 3)で後述するJAXのバックエンドを使って実行するとエラーが発生します。

KERAS_BACKEND=jax python mnist-lambda.py # インストール方法は後述
# ログ中略
RuntimeError: Unable to automatically build the model. Please build it yourself before calling fit/evaluate/predict. A model is 'built' when its variables have been created and its `self.built` attribute is True. Usually, calling the model on a batch of data is the right way to build it.

これは Lambda の中をKerasではなくTensorFlowの演算で書いているためです。バックエンドが変わるとTensorの実体が変わってしまう(tf.Tensor ではなくなる)ので、TensorFlowの演算を直接使うことはできないのです。Kerasの演算はバックエンドの差を吸収するので、以下のOKパターン3種類はすべてJAXバックエンドでも使えます。

# TensorFlowバックエンド以外ではNG
Lambda(lambda x: tf.maximum(x, 0))
# OK
Lambda(keras.activations.relu)
Lambda(lambda x: keras.activations.relu(x))
Lambda(lambda x: keras.ops.maximum(x, 0))

もちろん以下でもOKです(活性化関数を分けたい場合、こう書くのが一般的でしょう)。

# OK
keras.layers.Activation('relu')

バックエンド間の互換性・移植性を考えると、なるべくKerasの演算を使って書くのが望ましいといえます(tf を使わないで済むならば、それに越したことはないということです)。Keras 2ではバックエンドがTensorFlow一択だったので、混ぜて書いても特に問題にはなりませんでしたが、今後はぼちぼち意識していく必要があるでしょう。

Model のサブクラスとしてモデルを定義する場合

複雑なモデルを定義したり、モデルの再利用をしやすくしたりする目的で、keras.Model クラスのサブクラスとしてカスタムモデルを定義することができます。call() の中で独自の演算を使うことができますが、その場合の注意点は Lambda の場合と同じです。

mnist-subclass.py
import keras
from keras.layers import Flatten, Dense

# MNISTデータセットのロード
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# データの前処理
x_train, x_test = x_train / 255.0, x_test / 255.0  # 0-255の値を0-1に正規化

# サブクラス化されたモデルの定義
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        x = self.flatten(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

# モデルの作成
model = MyModel()

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# モデルの学習
model.fit(x_train, y_train, epochs=5)

これでモデル自体は使うことができます。

モデルのsummaryが正しく表示されない

先ほどのコードでは、model.summary() を実行した場合に各レイヤーの出力形状やパラメータ数が表示されないという問題が起こります。

model.summary()
Model: "my_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ flatten (Flatten)                    │ ?                           │     0 (unbuilt) │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense)                        │ ?                           │     0 (unbuilt) │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_1 (Dense)                      │ ?                           │     0 (unbuilt) │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘

これは MyModel のインスタンスを作成しただけでは Dense などのパラメータの形状が確定せず、初期化されていないためです(入力形状が分かって初めて、パラメータの形状が確定し、初期化することができる)。つまり、入力形状を与えてあげればよいです。2

MyModel クラスのコンストラクタを以下のように変更します。Inputレイヤーを作って、モデルに一度通してあげるとよいです。

# サブクラス化されたモデルの定義
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')
        # モデルの構築: 以下2行を追加
        input_layer = Input((28, 28))
        self(input_layer)

    def call(self, inputs, training=False):
        # 略

# モデルの作成
model = MyModel()

model.summary()
Model: "my_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ flatten (Flatten)                    │ (None, 784)                 │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense)                        │ (None, 128)                 │         100,480 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_1 (Dense)                      │ (None, 10)                  │           1,290 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘

ちなみに、Keras 2ではこの方法だとOutput Shapeに具体的な値が表示されず multiple となるようです。以下のようにすると良さそうです。

# Keras 2で動作する方法
class MyModel(keras.Model):
    def __init__(self):
        # super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')
        # モデルの構築
        input_layer = Input((28, 28))
        super().__init__(inputs=[input_layer], outputs=self.call(input_layer)) # ここを変える

カスタムメトリクスの定義

カスタムモデルでは、独自のロス(損失)関数やメトリクス 3 を定義することができます。Keras 2では、モデルやレイヤーの処理の中でこれらを定義できましたが、Keras 3で定義できるのはロス関数のみで、メトリクスを定義することができなくなりました。

例えば、以下のようにモデル内でロスを追加するコードはKeras 3でも動作します。4

# サブクラス化されたモデルの定義
class MyModel(keras.Model):
    def __init__(self):
        # 略

    def call(self, inputs, training=False):
        if self.built:
            self.add_loss(1e-2 * keras.ops.sum(keras.ops.square(self.dense1.weights[0])) / 2) # L2ロスを追加
        x = self.flatten(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

一方、仮にL2ロスにあたる値を学習途中で確認したいとしましょう。Keras 2では model.add_metrics() 5 というメソッドがあり、

self.add_metric(1e-2 * tf.math.reduce_sum(tf.math.square(self.dense1.kernel)) / 2, name="l2_loss")

のようにしてカスタムメトリクスを登録することができました(TensorFlowのドキュメントへの記載は、一足早くTensorFlow 2.14から消えているようです)。しかし、Keras 3ではこのメソッドが削除されてしまい、メトリクスは model.compile() でしか指定できないようになってしまいました。

モデルパラメータに依存するメトリクスの登録方法

カスタムメトリクスは (y_pred, y_true) を引数に取る関数か、keras.metrics.Metric のサブクラスを指定することになります。例えば、あるレイヤーのL2ロスをメトリクスとして登録するには、以下のようになります。

# dense1のL2ロスを返す関数。y_true, y_predは使わない
def l2_loss(y_true, y_pred):
    return 1e-2 * keras.ops.sum(keras.ops.square(model.dense1.kernel)) / 2

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', l2_loss])

ただし、この方法はJAXバックエンドだとうまくいかないようです。根本的な解決ではありませんが、JAXでも動く方法を次節で述べます。

複雑なメトリクスの定義

先ほどは重みパラメータだけに依存するものでしたが、時には中間層の値の統計量など、入力に依存するメトリクスを見たいケースもあるでしょう。このような場合でも、Keras 2であればモデルやレイヤーの call() の中で self.add_metrics() を呼び出せば一発だったのですが、Keras 3の場合は工夫が要ります。

一つの方法として、メトリクス用のモデル出力を追加することが挙げられます。例えば、中間層の出力から各次元の最大値を計算し、メトリクスとして得る場合を考えてみます。以下のように、その値を本来の出力とともにモデルの出力にして、model.compile() では各出力に対するロス・メトリクスを記述します。

mnist-custom-loss.py
import keras
from keras.layers import Input, Flatten, Dense, Lambda

# MNISTデータセットのロード
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# データの前処理
x_train, x_test = x_train / 255.0, x_test / 255.0  # 0-255の値を0-1に正規化

# サブクラス化されたモデルの定義
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')
        # モデルの構築
        input_layer = Input((28, 28))
        self(input_layer)

    def call(self, inputs, training=False):
        x = self.flatten(inputs)
        x = self.dense1(x)
        # カスタムメトリクス: 中間層のユニットの値の最大値
        custom_metrics = keras.ops.max(x, axis=-1, keepdims=True)
        x = self.dense2(x)
        return [x, custom_metrics] # カスタムメトリクスを出力に追加

model = MyModel()
model.summary()

def custom_metrics(y_true, y_pred):
    return y_pred

# モデルのコンパイル
model.compile(optimizer='adam',
              loss=['sparse_categorical_crossentropy', None], # 各出力に対応するロス
              metrics=[['accuracy'], [custom_metrics]]) # 各出力に対応するメトリクス

# モデルの学習
model.fit(x_train, [y_train, y_train], epochs=5) # 2つ目の y_train はダミー

この方法は、JAXバックエンドの場合でも動作します。先ほどのL2ロスを出力する問題にも応用できますが、最初の次元はバッチサイズになっている必要があるので、broadcast_to() で勝手に形状を作ってしまいます。

    def call(self, inputs, training=False):
        x = self.flatten(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        # カスタムメトリクス: L2ロス
        custom_metrics = 1e-2 * keras.ops.sum(keras.ops.square(self.dense1.weights[0])) / 2
        # バッチサイズの次元を追加
        custom_metrics = keras.ops.broadcast_to(custom_metrics, [keras.ops.shape(inputs)[0], 1])
        return [x, custom_metrics] # カスタムメトリクスを出力に追加

JAXバックエンドのパフォーマンス計測

最初に話題にしていたJAXのパフォーマンスを確認してみます。最初に挙げた mnist-simple.py をもとにして、学習の開始・終了時の時刻情報を表示する処理を追加します。

import time

# モデルの学習 一応バッチサイズも明示的に指定しておく
model.fit(x_train, y_train, epochs=5, batch_size=32, callbacks=[
    keras.callbacks.LambdaCallback(
        on_train_begin=lambda logs: print("BEGIN", time.time()),
        on_train_end=lambda logs: print("END", time.time()),
    )
])

CPU

VirtualBox上にUbuntuの環境を作って動作させました。ホストOSのCPUは AMD Ryzen 5 7530U です。

バージョン 時間 [sec]
Keras 2 20.80
Keras 3 (TensorFlow) 22.80
Keras 3 (JAX) 14.79

JAXを使うことで圧倒的に速くなっています。

GPU

Google Colab上で動かしてみます。「ランタイムのタイプを変更」から、T4 GPU を選択します。

!pip install tensorflow==2.14.1  # Keras 2の検証用
!pip install tensorflow==2.16.2  # Keras 3の検証用

のようにして、TensorFlowのバージョンを調整します。そのうえで、mnist-simple.py のコードを実行します(内容をColabのコードセルにコピペして実行)。

JAXバックエンドを使うには、セッションを一度リセットして

import os
os.environ['KERAS_BACKEND'] = 'jax'

を実行してから、Kerasをインポートします。

バージョン 時間 [sec]
Keras 2 26.75
Keras 3 (TensorFlow) 20.30
Keras 3 (JAX) 21.50

Colab環境に限れば、JAXが最速ではないようです。実機のGPUだとまた話が変わるかもしれませんし、演算の種類や環境依存の要素もありそうなので、実際のタスクで比較してみるのがよさそうです。

その他

Keras 3で定義される演算は、以下のドキュメントにまとまっています。まずは、ここに含まれる演算でモデルを書けるか検討するとよいでしょう。
https://keras.io/api/ops/

  1. もっと遡ると、別のバックエンド(Theano)のラッパーだった時期もありました。

  2. もっとも、これはKeras 2にもあった話です。

  3. 学習・推論に直接関与しない評価指標。正解率・精度(Accuracy)など。

  4. L2ロスを指定したければ、Dense のパラメータに kernel_regularizer=keras.regularizers.L2(l2=1e-2) のように指定するのが一般的でしょう。あくまで例ととらえていただければ幸いです。

  5. より正確には Layer に対して定義されており、ModelLayer のサブクラスであるため Model のメソッドとしても使用可能。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?