18
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

マルチラベル分類におけるGrad-CAM

Posted at

#はじめに
Grad-CAMが行われているのは主にOne-hotベクトルのラベル分類問題が多く、マルチラベル分類におけるGrad-CAM実装はあまり見たことがありません。正解の値が複数の場所でゼロでないようなマルチラベル分類の場合、各ラベルにおける特徴量をGrad-CAMで視覚化するのにどうすればいいかやってみました。

探した限りではFashion-MNISTの多クラス分類の可視化などがあります。(しかし個人的にはCAMマップのバイアス項が大きいのが気にかかる)
image.png

#結果:
rdkitの構造式出力画像を入力にして各原子数を予測するモデルを流用しました(以前記事)。
Grad-CAMの解像度の問題からモデルをXceptionからVGG16に変更しました。
マルチラベル分類じゃなくてマルチ回帰問題じゃんっていう突っ込みがありそうですが、マルチラベル分類でもおそらく同じの筈だと思います。
すると、以下の通り各原子の特徴量を可視化することができました。
gradcam_test19_0.jpg gradcam_test19_1.jpg
:C(炭素),                H(水素)
gradcam_test19_2.jpg gradcam_test19_3.jpg
:N(窒素),                O(酸素)
gradcam_test19_4.jpg gradcam_test19_bias.jpg
:S(硫黄),                バイアス

#計算手順:
以下の三つが通常のGrad-CAMと異なります。
##①見たいラベルの値をゼロにする
通常Grad-CAMはOne-hotベクトルの分類問題の場合、正解ラベルにどこか1つが1で、あとすべてがゼロのOne-hotベクトルをy_trueとして与えます。
ところがマルチラベル問題で一つのラベルが正しい値で他がすべてゼロだとGrad-cam画像は正しい値になりませんでした。窒素のみを選んだ場合の画像から何となく想像がつくかもしれませんが窒素以外の全体の特徴量が足しあわされたような画像が出力されました。
これを改善するには見たいラベルのみゼロで他すべて正しい値として、y_trueを定義してやりました。
gradcam_test19_0_de.jpg gradcam_test19_2_de.jpg
:C(炭素)のみ正しい値,          N(窒素)のみ正しい値

gradcam_test19_0.jpg:C(炭素)のみゼロ、他すべて正しい値

##②損失関数は通常のGrad-CAMにマイナスを掛ける
通常のGrad-CAMの損失関数にマイナスを掛ける必要がありました。

##③バイアスを差し引く
この項は必須ではないですが、やった方が結果が改善しました。
y_trueとy_predが近い値だとよいのですが、y_predの予測が悪いとバイアスが乗ってくることがありました。
これを回避するにはあらかじめ求めたバイアス項を各Grad-CAMマップから差し引くと良いようでした。
gradcam_test19_bias.jpg

#まとめ:
マルチラベルにおけるGrad-CAMを表示しました。One-hotベクトルの単分類問題と比べ、見たいラベルの値のみをゼロにして損失関数にマイナスを掛けると良いようでした。
One-hotベクトルの場合と手順は違いますが、勾配の基準をどこに取るかで多分やってることは同じなんじゃないでしょうか。

#コード:

structural_formula_train_gradcam.py
# coding: cp932

import keras
from keras.layers import Input, Concatenate, GlobalAveragePooling2D
from keras.models import Model
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.preprocessing.image import img_to_array, array_to_img

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os.path
import re
import tensorflow as tf
from keras import backend as K
import cv2

from rdkit import Chem
from rdkit.Chem import Draw, rdDepictor, rdMolDescriptors

max_MW = 400 # MWの最大値
load_weights_path = './weight_param.hdf5' # 重みファイルのパス

#SMILESの読み込み
df = pd.read_csv('sdf_data.csv')
SMILES = df['CAN_SMILES'].values 

# arrayに変換
SMILES = np.asarray(SMILES)
SMILES_train, SMILES_test = train_test_split(SMILES, test_size=0.30, random_state=110)

print(SMILES_train.shape, SMILES_test.shape)

# 読み込みデータのメモリ開放
del df, SMILES


class DataGenerator():
    def __init__(self, X_input):
        self.X_input = X_input
        self.atom = ['C', 'H', 'N', 'O', 'S', 'P', 'Si', 'Na', 'F', 'Cl', 'Br', 'I']
        self.reset()
    def reset(self):
        self.X = []
        self.Y = []
    def cal_img(self, mol):
        # 構造式画像を生成
        img = Chem.Draw.MolToImage(mol, size=(300, 300))
        return img_to_array(img)[:,:,:3]
    def chem_count(self, input):
        atom = ['C', 'H', 'N', 'O', 'S', 'P', 'Si', 'Na', 'F', 'Cl', 'Br', 'I']
        count = {'C': 0, 'H': 0 ,'N': 0, 'O': 0, 'S': 0, 'P': 0, 'Si': 0, 'Na': 0, 'F': 0, 'Cl': 0, 'Br': 0, 'I': 0}
        
        input = input + 'Z' # 末尾に英語大文字を追加
        for i in range(len(atom)):
            if re.search(atom[i] + '\d{1,3}[A-Z]', input):   #原子+数字(1~3桁)+英語大文字を含む場合
                 iterator = re.finditer(atom[i] + '\d{1,3}[A-Z]', input)
                 for match in iterator:
                     count[atom[i]] = int(match.group()[len(atom[i]):-1])
            elif re.search(atom[i]  + '[A-Z]',input):       #原子+英語大文字を含む場合
                 count[atom[i]] = 1
        return count

    def make_structural_image(self, batch_size=10):
        while True:
            for i in range(len(self.X_input)):
                # SMILES形式をmol形式に変換
                mol = Chem.MolFromSmiles(self.X_input[i], sanitize=True)
                if not mol is None: # molファイルがNoneなら弾く
                    # 二次元位置を計算
                    rdDepictor.Compute2DCoords(mol)
                    
                    # mol質量を計算
                    MW = Chem.rdMolDescriptors._CalcMolWt(mol)
                    # mol形式から分子式を計算
                    chem_Formula = Chem.rdMolDescriptors.CalcMolFormula(mol)
                    # 分子式から原子数をカウント
                    count = self.chem_count(chem_Formula)
                    
                    # X,Yを追加
                    if MW < max_MW and count['C'] > 1.0: # 炭素原子が2個以上含まれるmol質量400以下の化合物であること。
                        (self.X).append(self.cal_img(mol))
                        (self.Y).append([count[self.atom[i]] for i in range(len(self.atom))])
                    # X,Yの値を返す
                    if len(self.X) == batch_size:
                        inputs = np.asarray(self.X)
                        outputs = np.asarray(self.Y)
                        self.reset()
                        yield inputs, outputs

# DataGeneratorを生成
train_datagen = DataGenerator(SMILES_train)
test_datagen = DataGenerator(SMILES_test)
batchsize = 16

# CNNを構築(VGG16)
from keras.applications.vgg16 import VGG16

base_model = VGG16(weights=None,input_shape=(300,300,3), include_top=False, pooling='avg')

x = base_model.output
y = Dense(12, activation='linear')(x)

model = Model(base_model.inputs, y)


# モデル表示
model.summary()

# コンパイル
model.compile(loss='mean_squared_error',
              optimizer='Adam',
              metrics=['accuracy'])

# 保存した重みを読み込む場合
#if os.path.isfile(load_weights_path):
#    model.load_weights(load_weights_path)
#    model.load_weights('./weights_189-0.03-0.03.hdf5')

# コールバック
callbacks = []

# 学習率の低減
base_lr = 1e-3 / 10.0
lr_decay_rate = 1 / 3
lr_steps = 4
epochs = 200
callbacks.append(keras.callbacks.LearningRateScheduler(lambda ep: float(base_lr * lr_decay_rate ** (ep * lr_steps // epochs))))

# モデル出力
fpath = './weights_{epoch:03d}-{loss:.2f}-{val_loss:.2f}.hdf5'
callbacks.append(ModelCheckpoint(filepath = fpath, monitor='val_loss', verbose=0, save_best_only=True, mode='auto'))


# 学習
history = model.fit_generator(
    generator=train_datagen.make_structural_image(batch_size=batchsize),
    steps_per_epoch=int(len(SMILES_train) / batchsize / 100),
    epochs=epochs,
    verbose=1,
    validation_data=test_datagen.make_structural_image(batch_size=batchsize),
    validation_steps=100,
    callbacks=callbacks)


# モデルと重みの保存
model.save_weights('weight_param.hdf5')
model.save("model.h5")

# 学習グラフのプロット
plt.figure()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['loss', 'val_loss'], loc='upper right')
plt.ylim([0,0.5])
plt.show()

#model.load_weights('./weights_189-0.03-0.03.hdf5')

def grad_cam(input_model, image, y_true, layer_name):

    y_pred = input_model.output[0]
    y_true = tf.convert_to_tensor(y_true.astype(np.float32))

    loss = K.mean(K.square(y_pred - y_true), axis=-1)
    conv_output = input_model.get_layer(layer_name).output
    grads = K.gradients(loss, conv_output)[0]

    output, grads_val =  K.function([input_model.input], [conv_output, grads])([image])
    output, grads_val = output[0, :], grads_val[0, :, :, :]
    
    weights = np.mean(grads_val, axis=(0, 1))
    cam = np.dot(output, weights)

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (300, 300), cv2.INTER_LINEAR)
#    cam = cam / cam.max()
    return cam

# 予測
X_train, Y_train = next(train_datagen.make_structural_image(batch_size=500))
X_test, Y_test = next(test_datagen.make_structural_image(batch_size=500))

# Grad-CAMの表示
Y_predict_test = model.predict(X_test)
atom = ['C', 'H', 'N', 'O', 'S', 'P', 'Si', 'Na', 'F', 'Cl', 'Br', 'I']
np.set_printoptions(precision=1, floatmode='fixed', suppress=True)

for k in range(20):

    y_true = np.array(Y_test[k])

    gradcam_bias = grad_cam(model, X_test[k:k+1], y_true, layer_name='block5_conv3')
    print(gradcam_bias.shape)

    jetcam = cv2.applyColorMap(np.uint8(255 * (gradcam_bias/gradcam_bias.max()) ), cv2.COLORMAP_JET)
    jetcam = (np.float32(jetcam) + X_test[k]) / 2

    cv2.putText(jetcam, 'Atom=' + str(atom), (4,14), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
    cv2.putText(jetcam, 'Y_pred=' + str(Y_predict_test[k]), (4,28), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
    cv2.putText(jetcam, 'Y_true=' + str(y_true),         (4,42), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
    cv2.putText(jetcam, 'loss=' + 'K.mean(K.square(y_pred - y_true), axis=-1)',         (4,288), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
    cv2.imwrite('gradcam_test%02d_bias.jpg' % k, np.uint8(jetcam))

    for i in range(5):

        y_true = np.array(Y_test[k])
        y_true[i] = 0

        gradcam = grad_cam(model, X_test[k:k+1], y_true, layer_name='block5_conv3')
        gradcam = gradcam - gradcam_bias
        gradcam = np.maximum(gradcam, 0)
        gradcam = gradcam / gradcam.max()
        print(gradcam.shape)

        jetcam = cv2.applyColorMap(np.uint8(255 * gradcam), cv2.COLORMAP_JET)
        jetcam = (np.float32(jetcam) + X_test[k]) / 2

        cv2.putText(jetcam, 'Atom=' + str(atom), (4,14), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
        cv2.putText(jetcam, 'Y_pred=' + str(Y_predict_test[k]), (4,28), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
        cv2.putText(jetcam, 'Y_true=' + str(y_true),         (4,42), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
        cv2.putText(jetcam, 'loss=' + 'K.mean(K.square(y_pred - y_true), axis=-1)',         (4,288), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,255,255), 1, cv2.LINE_AA)
        cv2.imwrite('gradcam_test%02d_%d.jpg' % (k,i), np.uint8(jetcam))


Y_predict_train = model.predict(X_train)
Y_predict_test = model.predict(X_test)

# 予測のプロット
atom = ['C', 'H', 'N', 'O', 'S', 'P', 'F', 'Cl']
for i in range(len(atom)):
    plt.figure()
    plt.scatter(Y_train[:,i], Y_predict_train[:,i], label = 'Train', c = 'blue')
    plt.title('Neural Network Predictor')
    plt.xlabel('Real ' + atom[i] + ' Number')
    plt.ylabel('Predicted ' + atom[i] + ' Number')
    plt.scatter(Y_test[:,i], Y_predict_test[:,i], c = 'lightgreen', label = 'Test', alpha = 0.8)
    plt.legend(loc = 4)
    plt.show()

# 50件のテスト結果を出力
atom = ['C', 'H', 'N', 'O', 'S', 'P', 'Si', 'Na', 'F', 'Cl', 'Br', 'I']
for i in range(50):
    img = array_to_img(X_test[i])
    real = ''
    pred = ''
    for j in range(len(atom)):
        if round(Y_test[i,j]) > 0:
            real += atom[j] + str(round(Y_test[i,j]))
        if round(Y_predict_test[i,j],1) > 0.5:
            pred += atom[j] + str(round(Y_predict_test[i,j],1))
    img.save('test_result%04d_%s_%s.png' % (i, real, pred))

18
15
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?