更新履歴
2023/12/19 文章を補足。タイトルを変更。
初めに
前回、過学習できちんと可視化出来なかったのでリベンジです。
今回は、昨今注目のefficientNETに医用画像を学習させ、gradCAMで予測根拠の可視化の実験をしていきます。
google colaboratory環境を想定した解説です。
以下のコードで、colaboratory にドライブをマウントできます。
from google.colab import drive
drive.mount('/content/drive/')
全体のコードを参照しながら、一部抜粋し解説していきます。
全体のコードを参照しながら、読んで頂くと、分かりやすいかもしれません。
全体のコード:
https://github.com/medical-ai-project/ai_notebook/blob/main/efficientNet_grad-CAM.ipynb
使用するデータセット:
https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia/data
参考にするネットワークアーキテクチャ:
https://www.kaggle.com/code/momomooo/efficientnet-chest-x-ray-classification/notebook
最終的に以下のように、CAMのヒートマップと入力画像を合成したような画像を出力します。
学習の背景を整理する
今回のデータ数は約5700枚で、これをさらに学習データ検証データ、テストデータに分割するので、比較的学習に使えるデータ数は少ないと言える。
よって、モデルはimagenetで学習済みのモデルをベースとし、新たに出力層を追加し転移学習させる。
# 転移学習に使用する、efficientNetの(imagenetで)学習済みモデルをインストールする。
!pip install efficientnet
データ拡張、データ全処理
先程も言ったとおり、学習に使用するデータセットが少ない。このような場合には、データ拡張(Data Augumation)をし過学習を防ぐ。以下一部コード抜粋。
# TensorFlowのKerasライブラリからImageDataGeneratorをインポートする。
# ImageDataGeneratorは画像データの前処理やデータ拡張のために使用される。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# トレーニングデータ用のImageDataGeneratorインスタンスを生成する。
# このインスタンスは、トレーニングデータに対して特定のデータ拡張や前処理を行う。
train_datagen = ImageDataGenerator(
rescale=1/255., # 画像のピクセル値を0から1の範囲に正規化する。
zoom_range = 0.1, # 画像をランダムに拡大縮小する(ここでは10%の範囲で変更)。
# rotation_range = 0.1, # 画像をランダムに回転させる。
width_shift_range = 0.1, # 画像を水平方向にランダムにシフトする(10%の範囲で)。
height_shift_range = 0.1 # 画像を垂直方向にランダムにシフトする(10%の範囲で)。
)
# 検証データ用のImageDataGeneratorインスタンスを生成する。
# 通常、検証データにはデータ拡張を適用しないため、ここでは正規化のみを行う。
val_datagen = ImageDataGenerator(
rescale=1/255. # 画像のピクセル値を0から1の範囲に正規化する。
)
# トレーニングデータセットを生成する。
# ImageDataGeneratorを使用して、データフレームから画像データを読み込み、前処理を適用する。
ds_train = train_datagen.flow_from_dataframe(
train_df, # 使用するトレーニングデータフレーム。
# directory=train_path, # 画像のパスがデータフレームに含まれているため、コメントアウトされている。
x_col = 'image', # 画像ファイルのパスが格納されている列の名前。
y_col = 'class', # ターゲット(ラベル)が格納されている列の名前。
target_size = (IMG_SIZE, IMG_SIZE), # 画像のサイズを指定サイズにリサイズ。
class_mode = 'binary', # クラスモードをバイナリ(2クラス分類)に設定。
batch_size = BATCH, # 1回のバッチで使用するサンプル数を指定。
seed = SEED # データ拡張をランダムに適用する際のシード値。
)
転移学習のモデルを準備する
コンピュータでどの様に微分をするのかは、一つの研究分野として成立する程奥が深い。今回、tensorflowのtf.GradientTape機能を使用すし微分するが、この機能はモデルのレイヤーがネストされると、計算グラフが繋がらなくなる問題がある。
よって、tensorflowで転移学習を行う際には、grad-camを適応するレイヤーがネストされないように気をつける必要がある。以下のように 「pretrained_model_B7.input」と「pretrained_model_B7.output」でメタ情報を取り出し、それを繋げることによりefficientNetモデルがネストされないようになる。
# EfficientNetB7 モデルをロード
pretrained_model_B7 = efn.EfficientNetB7(weights='imagenet', input_shape=input_shape, include_top=False)
# インプットとアウトプットのレイヤーのメタ情報を抜き取る。
input_tensor = pretrained_model_B7.input
output_tensor = pretrained_model_B7.output
# efficientNetのアウトプットとカスタマイズで追加した層とを繋げる。
x = Conv2D(filters=128, kernel_size=(3,3), activation='relu')(output_tensor)
x = GlobalAveragePooling2D()(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
x = Dense(1)(x)
x = BatchNormalization()(x)
x = Activation('sigmoid')(x)
...一部省略。後に解説する。
# efficientNetのインプットのメタ情報をModelに渡す。
model = Model(inputs=input_tensor, outputs=x)
model.compile(loss=binary_crossentropy,
optimizer=Adam(learning_rate=1e-3),
metrics=[metrics.AUC(name='auc'), 'accuracy'])
難しく感じると思うが、model.summary()を利用してモデルがネストされていないか確認できる。model.summary()した際に、数行だけしか表示されない場合、モデルはネストされてしまっている。efficientnet-b7が展開されて表示されている場合は成功で、添付した全体コードの出力のように冗長に表示される。
学習済みのモデルの特定のレイヤーだけを、再学習させる
予め学習されたパラメータは、新たに学習させるデータの特徴量を抽出するとは限らない。imagenetは約1500万枚の画像を数千種以上のカテゴリに分類する為に最適化されているが、通常専門的な医用画像の様な画像は含まない。
そこでcamを適応するレイヤーに対して新たなデータを再学習し、新たなデータ(今回の場合は医用画像)の特徴を抽出する様にする。
せっかくimagenetで学習させた既存のパラメータを再学習させる事は、転移学習の利点を破壊する事であるから、再学習させるレイヤーはCAMを適応するレイヤーのみに限りたい。
今回の様に、新たなデータセットが少ないと、imagenetで学習されたパラメータを再学習させる程、基本的には精度が低下する現象が確認される。
(通常こうしたファインチューニングは、新たなデータセットに対する認識精度を向上させる目的で行われる。今回の様に、データセットが少ない場合には認識精度がむしろ低下する為、通常行われない。認識精度の低下と引き換えに、CAMの可視化を改善させる目的で行う事は特異なケースで注目に値する。)
以下はimagenetで学習済みのefficientNetの1層のみを再学習させた場合の認識精度である。
以下はimagenetで学習済みのefficientNetの2層を再学習させた場合の認識精度である。上記の1層の場合と比較して、数%ほど認識精度が低下している。なお、6層以上のレイヤーを再学習した場合には、精度は5割を切った。再学習させるレイヤーは最小限に抑えたい。
レイヤーをforで回して、再学習したいレイヤーが見つかったら、trainableをTrueにする。
他のレイヤーはFalseにし、再学習しないように設定する。
今回は、"block7c_project_conv"を再学習可能に設定した。
# 先程省略されたコード。
set_trainable = False
for layer in pretrained_model_B7.layers:
if layer.name == 'block7c_project_conv':
set_trainable = True
# if layer.name == 'top_conv':
# set_trainable = True
layer.trainable = set_trainable
【pythonでの微分】gradientの計算にtf.GradientTapeを用いる。
gradCAMなので、勾配計算をしないといけない。ここで先ほども言った通り、tf.GradientTapeの仕組みを使用する。これは自動微分の仕組みで、以下の様にwith内で行った計算を、計算グラフとして記録する。この計算グラフを使用して、誤差を伝播し勾配を計算する(誤差逆伝播という。詳しくはゼロから作るDeep Learningで学べる)。
入力画像はNumpyではなく、テンソルに変換する。
# 入力する画像のインデックス指定
idx = 556
input_img = ds_test[idx][0]
# NumPy配列をテンソルに変換
input_img_tensor = tf.convert_to_tensor(input_img.reshape(1, 600, 600, 3))
以下で、grad-CAMを適応するレイヤーを指定している。指定したレイヤーの出力(特徴マップ)を出力する中間モデルを作成している。
target_layer = model.get_layer("block7c_project_conv")
intermediate_model = Model(inputs=[model.inputs], outputs=[target_layer.output, model.output])
その後、with tf.GradientTape() as tape:内で行われた計算は、計算グラフとしてtapeに記録される。
with tf.GradientTape() as tape:
tape.watch(input_img_tensor) # ここでテンソルをwatch
conv_output, predictions = intermediate_model(input_img_tensor)
class_idx = np.argmax(predictions[0])
loss = predictions[:, class_idx]
gradCAMの計算
先程tapeに記録した情報を使用して、tape.gradientで勾配を計算することが出来る。
最後にmatplotlibで表示するために正規化(最大値で割る)を行っている。
grads = tape.gradient(loss, conv_output)[0]
# グローバル平均プーリング
# ここでのweights(重み)は、通常の機械学習の文脈における重みではなく、
# Grad-CAMの文脈で特定の特徴マップがどれだけ重要であるかを示す量。
weights = np.mean(grads, axis=(0, 1))
# Grad-CAMの計算
# conv_output[0]はバッチのindex0番目に対応している。今回は1枚のみなので[0]のみ存在している。
cam = np.dot(conv_output[0], weights)
cam = np.maximum(cam, 0) # ReLU
cam = cam / cam.max() # 正規化
CAM(ヒートマップ)を出力
先程、計算したcamをmatplotlibで可視化し、結果を確認する。
plt.subplot(121)
plt.imshow(cam, cmap="jet")
plt.subplot(122)
plt.imshow(input_img[0], cmap='gray')
plt.show()
CAMと入力画像を合成
生成されたCAMは、入力画像よりも画素数が落ちる。そこでscipyライブラリのzoomモジュールを使用して、簡易的なアップサンプリングを行う。
from scipy.ndimage import zoom
# 50x50の特徴マップを224x224にアップサンプリング
zoom_factor = 600 / 19 # 600: ターゲットのサイズ, 19: 元のサイズ
cam_resized = zoom(cam, zoom_factor)
# プロット
plt.subplot(131)
plt.title("Original Image")
plt.imshow(input_img[0], cmap="gray")
plt.subplot(132)
plt.title("grad_CAM")
plt.imshow(cam_resized, cmap="jet")
plt.subplot(133)
plt.title("Combined")
plt.imshow(input_img[0], cmap="gray")
plt.imshow(cam_resized, cmap='jet', alpha=0.5)
plt.show()
カスタマイズしたカラーマップの使用
CAMが反応しなかった箇所を無色透明にし、見やすくする。
そのためにカラーマップをカスタマイズする。
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
cmap = plt.get_cmap('jet')
cmaplist = [cmap(i) for i in range(cmap.N)]
# 範囲を変更
shift_amount = 10
new_cmaplist = [(1, 1, 1, 0.0) for _ in range(shift_amount)] + cmaplist[:-shift_amount]
for i in range(45):
new_cmaplist[i] = (1, 1, 1, 0.0) # (R, G, B, Alpha)
cmap_custom = mcolors.LinearSegmentedColormap.from_list('custom_cmap', new_cmaplist, cmap.N)
ある一定の値以下を、無色透明にするカラーマップが出来た。
作成したカラーマップを適応する。
import matplotlib.colors as mcolors
# 元の画像
plt.imshow(input_img[0], cmap='gray')
# CAMを透過してオーバーレイ表示
plt.imshow(cam_resized, cmap=cmap_custom, alpha=1, vmin=0)
plt.show()
まとめ
最終的に出力された画像は、横隔膜あたりに反応しているのでしょうか?僕は放射線科医じゃないのでよく分かりませんが、肺炎の診断にCPangle(肋横隔膜角)が使われることがあるので、そのへんに反応しているのかと思われます。
こういう医用画像データセットってオープン化が進んでいるとはいえ、まだまだ少ないような気がします。
医用画像診断は中国などが強いですが、個人情報の取り扱いが比較的緩いのが後押ししているのだと思われますね。
僕も病気でCT画像を撮ったことがありますが、画像データを蓄積、活用している様子はなかったですね。そのへんを日本でももっと整備してくれると、研究しやすいのですが。。。
参考文献等
[1].EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,Mingxing Tan, Quoc V. Le.https://arxiv.org/abs/1905.11946 .28 May 2019.
[2].Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra.https://arxiv.org/abs/1610.02391 .3 Dec 2019 .
[3].【初学者必読】Google Colaboratory とは?使い方・便利な設定などをわかりやすく解説!.AI Academy運営事務局.https://aiacademy.jp/media/?p=1037 .参照日2023/11/15.
[4].【Google Colaboratory】Google ドライブにマウントし、ファイルへアクセスする方法.二ノ宮.https://blog.kikagaku.co.jp/google-colab-drive-mount .参照日2023/11/15.
[5].転移学習とは?AI実装でよく聞くファインチューニングとの違いも紹介.https://aismiley.co.jp/ai_news/transfer-learning/ .AIsmiley編集部.参照日2023/11/15.
[6].自動微分と勾配テープ.TensorFlow.https://www.tensorflow.org/tutorials/customization/autodiff?hl=ja .参照日2023/11/15.
[7].A Complete Guide to Data Augmentation.datacamp.https://www.datacamp.com/tutorial/complete-guide-data-augmentation .参照日2023/11/15.