はじめに
pythonの機械学習ライブラリであるKerasは、単体では動かず、2023年8月現在はTensorflowをバックエンドとして動くようになっています。ですが2023年秋リリース予定の3.0で、再びTensorflow以外のバックエンドを使えるようになるらしいです。
この記事の要約
- Keras 3.0でバックエンドが選べるようになる
- Tensorflow、PyTorch、Jax、Numpyのいずれかを選択可能(デフォルトはTensorflow)
- Keras3.0で書いたコードは、(特定のバックエンドに依存するコードを書いてない限り)どのバックエンドに切り替えても動作する
Kerasについて
ここはKerasって何ぞやの方へ説明です。
KerasはGoogle社のFrançois Chollet氏を中心として開発されている機械学習ライブラリです。それまで複雑な処理の実装が必要であった深層学習モデルの構築や学習などを、シンプルかつ簡単にすることを目的の1つとしており、mnistなど、入門的な学習は数十行のコードで書いたりできます。近年はPyTorchに多いですが、2019年くらいまでは、生成AIを含め、多くの論文の実装でKerasが用いられてきました。
そんなKerasについて、次の項目でその歴史みたいなものを紹介します。
Kerasの変革
初期 ~ v2.3
Kerasの最初のリリースは2015年です。
初期のKerasでは、TensorflowとTheanoをバックエンドで選べるようになっており、後にはCNTK (現 Microsoft Cognitive Toolkit)が対応、オフィシャルではないですが、AmazonのMxNetなど、様々なライブラリをバックエンドで選べるようになっていました。
Tensorflowへの統合 (v2.4 ~ )
KerasはTensorflow2.0のリリースに伴って、複数のバックエンドを廃止、そしてTensorflowと統合されるようになりました1。v2.4ではGithubのKerasリポジトリはメンテナンスされなくなれ、Tensorflowのリポジトリ内で機能の更新が行われていた時もありました2が、その後のv2.6で再度分割されました。
分割後のv2.6以降ではマルチバックエンドの対応がなくなり、「Tensorflowで書く <- -> Kerasで書く」のような形で、パッケージは分かれているものの、2つで1つのライブラリのような状態が現在も続いています。
ちなみに、Keras v2.4リリースと同時期に、Kerasがバックエンドで対応していたTheanoが開発終了、CNTKも2019年に開発が止まっています3。
再びバックエンドを使えるように (v3.0 ~ )
最近、keras.ioでKeras Coreなるものがリリースされるというページが公開されました。
ぜひ公式を読んでいただきたいのですが、まとめてしまうと、以下のようなことを書いています。
- 2023年秋に新しいくKeras Coreをするよ。これは後にKeras 3.0になる
- フレームワークで使われるTensorflow、PyTorch、JAXは異なる領域で活躍しているから、異なるフレームワークでも互換性や移行が楽になる
- Keras Coreは異なるバックエンドでも書き方は(ほぼ)1通りだけになる
Kerasの思想がv3で帰ってくるなーという思いです。
1はそのままなので、2と3について背景を説明します。少し専門的な話なので、都度用語などは調べたりしていただければと思います。
フレームワークで使われるTensorflow、PyTorch、Jaxは異なる領域で活躍しているから、異なるフレームワークでも互換性や移行が楽になる
まず、元ページでは、Tensorflowは生産領域で半数以上、PyTorchは研究分野、そしてJAXは生成AIで占めているといいます。それぞれについて補足すれば、次のような感じです
- Tensorflowは生産領域についてこのページで事例紹介をしており、Googleは勿論、PayPalやIntelなどの企業で応用例があります4。
- PyTorchについては、CVPRなどの先端領域の論文でみられるリポジトリをたどっても、ほとんどがTorch実装だったりで5、ディープラーニングの領域を研究する人ならほぼ必須といっていいほど理解する必要のある言語です。
- Jaxは性能を要求する演算や大規模トレーニングで使われるイメージです。扱い自体は若干上級者向けではあります。
フレームワークが異なることのデメリットには、あるフレームワークで作成した重みは他のフレームワークとの互換性がない、という点です。例えばTensorflowで学習された重みはデフォルトでh5
というフォーマットですが、PyTorchはこれを読み込めません。逆にPyTorchではpth
というフォーマットですが、Tensorflowではこれを読み込めません6。
Keras Coreではこの互換性を解決することができ、.keras
というフォーマットに統合することで依存関係をなくすことができます。
Keras CoreはTensorflow、PyTorch、Jaxをバックエンドとして対応していて、異なるバックエンドでも書き方は(ほぼ)1通りだけになる
それぞれのフレームワークではモデル構築から学習まで書き方がかなり変わります7、そうしたものがKerasスタイル、つまり↓のような書き方に統一できます。
import os
- os.environ["KERAS_BACKEND"] = "tensorflow"
+ os.environ["KERAS_BACKEND"] = "torch"
import keras_core as keras
from keras_core.layers import Conv2D, Input, MaxPooling2D, Flatten, Dense
# MNISTデータを分割
(train_images, train_labels), (test_images,
test_labels) = keras.datasets.mnist.load_data()
# パラメータ
classes = 10
batch_size = len(train_images) // 16
epochs = 20
# 簡単なモデルを作成
model = keras.Sequential(
[
Input(shape=(28, 28, 1)),
Conv2D(32, (3, 3), activation="relu"),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation="relu"),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation="relu"),
Flatten(),
Dense(64, activation="relu"),
Dense(classes, activation="softmax"),
]
)
# 前処理
train_images = train_images / 255.0
test_images = test_images / 255.0
train_labels = keras.utils.to_categorical(train_labels, classes)
test_labels = keras.utils.to_categorical(test_labels, classes)
model.compile(
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
)
# 学習開始
model.fit(train_images, train_labels, epochs=epochs, batch_size=batch_size)
# 結果をもとに評価
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print("Test accuracy:", test_acc)
os.environ["KERAS_BACKEND"] = ...
の部分でpytorchならtorch
、jaxならjax
とすればバックエンドが変わったまま使えるわけです。上記はMNISTの例ですが、kerasの公式はそれ以外にいろいろな例が提示されているので、興味ある方はぜひ見てみてください。
終わりに
以上がKeras Coreについてでした。最近はWEBエンジニアになったのであまり機械学習関係には触れてなかったのですが、偶然Keras Coreを見つけたので趣味とノリと勢いでこの記事を書きました。
TensorflowとKeras好きの自分にとってはうれしいニュースだったのですが、みなさんにとってはどうでしょうか。互換性という点と初心者へ向けた機能としては、Keras Coreは良いものになるのではないでしょうか。
-
https://github.com/keras-team/keras/commit/b5cb82c689eac0e50522be9d2f55093dadfba24c ↩
-
日本のPreferred Networksが開発していたChainerも2020年に開発終了したので、この頃にTensorflowとPyTorchがライブラリの二大巨頭になりました。 ↩
-
昔ですが、昔私がやっていたアルバイト先のプロダクトではTensorflowが使われていました。 ↩
-
研究領域でいうと、近年はPyTorch or JAX or PaddlePaddleが3台巨頭です。 ↩
-
onnxというライブラリを使って変換処理をかければ不可能ではないですが、工数がかかり面倒です。 ↩
-
コーディングスタイル以外に、TensorflowはDefine and Run、PyTorchとJaxはDefine by Runという方式をとっているのも原因の1つです。これらの違いについては各自調べてください。 ↩