ディープラーニングでMixture of Experts
Mixture of Expertsは複数のニューラルネットワークを組み合わせる手法のひとつで、階層型ネットワーク構造をとります。
Mixture of Expertsはゲート・ネットワークとエキスパート・ネットワークで構成され、ゲート・ネットワークがエキスパート・ネットワークの重要性を示し、エキスパート・ネットワークが判定をします。
具体的なダイアグラムは以下のようになっています。
ゲート・ネットワーク、エキスパート・ネットワークはいずれも同じ入力データを取ります。
エキスパート・ネットワークは入力データに対してターゲット変数を推論します。
(画像分類であれば、猫画像が入力されたら猫を推論する)
ゲート・ネットワークは入力データに対して、どのエキスパート・ネットワークが正しい推論をしそうか、エキスパート・ネットワークを取捨選択するための推論をします。
(猫画像を入力されたら、どのエキスパート・ネットワークが正しく猫と推論するか、推論する)
以下の図では、1つのゲート・ネットワークと4つのエキスパート・ネットワークで構成されたMixture of Expertsを考えます。
エキスパート・ネットワークは「猫犬分類が得意なエキスパート」、「哺乳類分類が得意なエキスパート」、「鳥類分類が得意なエキスパート」、「魚類分類が得意なエキスパート」が用意されています。
ゲート・ネットワークでは入力画像を大まかに猫っぽいと推論し、猫関連の分類が得意な「猫犬分類が得意なエキスパート」と「哺乳類分類が得意なエキスパート」を重要視し、他のエキスパートは低評価して、画像分類を行います。
この例ではゲート・ネットワークは全エキスパート・ネットワークに推論させていますが、Mixture of Expertsはヒエラルキー構造(ツリー構造)を取ることで、推論を行うエキスパート・ネットワークを選択することも可能です。
以下の図では3つのゲート・ネットワークと4つのエキスパート・ネットワークで構成されたMixture of Expertsを考えます。
最初の親ゲート・ネットワークが入力画像を哺乳類っぽいと推論し、哺乳類の分類を行う「哺乳類ゲート・ネットワーク」にルーティングします。
哺乳類ゲート・ネットワークは猫犬分類の得意なエキスパートと哺乳類分類が得意なエキスパートに推論を依頼し、他のエキスパートは使いません。
Mixture of Expertsの利点は多様なエキスパート・ネットワークを組み合わせることで、単一ネットワークよりも汎用的で高度な推論ができるようになることです。
全種類の分類が可能な単一モデルが用意できなくても、複数のエキスパート・ネットワークに分割し組み合わせることが可能になります。
また、特定のエキスパート・ネットワークだけ切り出して学習、再学習し、パフォーマンスを向上させることも可能です。
やったこと
今回はKerasでCifar10の分類にMixture of Expertsを実装してみました。
コード全体は以下にあります。
https://github.com/shibuiwilliam/mixture_of_experts_keras
アーキテクチャ
全体像は以下になります。
最初のゲート・ネットワークで人工物、自然物の判定をします。
ここで人工物と判定されれば人工物判定のゲート・ネットワークへ進み、自然物と判定されれば自然物判定のゲート・ネットワークへ進みます。
エキスパート・ネットワークは3種類用意しています。
- Base VGG:10種類の画像を判定可能なベーシカルなCifar10分類器
- Artificial VGG:人工物(ラベル番号:0,1,8,9)のみ学習した分類器
- Natural VGG:自然物(ラベル:2,3,4,5,6,7)のみ学習した分類器
2層目のゲート・ネットワークでは、人工物判定であればBase VGGとArtificial VGGで分類を行い、自然物判定であればBase VGGとNatural VGGで分類を行います。
2層目のゲート・ネットワークは各モデルの重要度もsoftmaxで確率を判定し、その確率をBase VGGの出力(softmaxで10種類の確率)と各エキスパート・ネットワークの出力に乗算し、合計値をMixture of Expertsの出力とします。
Mixture of Expertsではエキスパート・ネットワークを単独で事前に学習を済ましておき、Mixture of Experts自体の学習時は推論だけを行います。
つまり、Mixture of Expertsの学習中はエキスパート・ネットワークの学習(バックプロパゲーション)は行わず、フィード・フォワードだけを実行します。
Mixture of Expertsの学習ではゲート・ネットワークの学習(フォワードプロパゲーション、バックプロパゲーション)を行い、エキスパート・ネットワークへのパスを学習します。
ただし、今回は最初のゲート・ネットワーク(人工物・自然物判定)も事前学習を済ませておきました。
人工物・自然物ゲート・ネットワークは2値分類でルート判定を行うだけであるため、他のゲート・ネットワークのようにエキスパート・ネットワークの重要性を学習しないため、単独で学習することが可能です。
人工物・自然物ゲート・ネットワーク
人工物・自然物ゲート・ネットワークはVGGで2値判定を行います。
0: 人工物
1: 自然物
学習データはCifar10の画像のうちラベル番号が0,1,8,9のものは人工物(0)、ラベル番号が2,3,4,5,6,7のものは自然物(1)としたものを用意し、おおまかな人工物・自然物判定だけを行います。
人工物と判定すれば人工物ゲート・ネットワークへ、自然物と判定すれば自然物ゲート・ネットワークへパスします。
# simple VGG-like model for the first and gating neural networks
def simpleVGG(cifarInput, num_classes, name="vgg"):
name = [name+str(i) for i in range(12)]
# convolution and max pooling layers
vgg = Conv2D(32, (3, 3), padding='same', activation='relu', name=name[0])(cifarInput)
vgg = Conv2D(32, (3, 3), padding='same', activation='relu', name=name[1])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[2])(vgg)
vgg = Dropout(0.25, name=name[3])(vgg)
vgg = Conv2D(64, (3, 3), padding='same', activation='relu', name=name[4])(vgg)
vgg = Conv2D(64, (3, 3), padding='same', activation='relu', name=name[5])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[6])(vgg)
vgg = Dropout(0.25, name=name[7])(vgg)
# classification layers
vgg = Flatten(name=name[8])(vgg)
vgg = Dense(512, activation='relu', name=name[9])(vgg)
vgg = Dropout(0.5, name=name[10])(vgg)
vgg = Dense(num_classes, activation='softmax', name=name[11])(vgg)
return vgg
# first gating network, to decide artificial or natural object
gate0VGG = simpleVGG(cifarInput, gate0_classes, "gate0")
gate0Model = Model(cifarInput, gate0VGG)
# prepare dataset for the first gating network
# generate target labels with 0 or 1, with 0 for artificial object and 1 for natural object
y_trainG0 = np.array([0 if i in [0,1,8,9] else 1 for i in y_train])
y_testG0 = np.array([0 if i in [0,1,8,9] else 1 for i in y_test])
y_trainG0 = keras.utils.to_categorical(y_trainG0, 2)
y_testG0 = keras.utils.to_categorical(y_testG0, 2)
エキスパート・ネットワーク
3種類のエキスパート・ネットワークを用意しています。
Base VGGはCifar10の全学習データで10種類の物体分類が可能なように学習しています。
# simple VGG-like model for the first and gating neural networks
def simpleVGG(cifarInput, num_classes, name="vgg"):
name = [name+str(i) for i in range(12)]
# convolution and max pooling layers
vgg = Conv2D(32, (3, 3), padding='same', activation='relu', name=name[0])(cifarInput)
vgg = Conv2D(32, (3, 3), padding='same', activation='relu', name=name[1])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[2])(vgg)
vgg = Dropout(0.25, name=name[3])(vgg)
vgg = Conv2D(64, (3, 3), padding='same', activation='relu', name=name[4])(vgg)
vgg = Conv2D(64, (3, 3), padding='same', activation='relu', name=name[5])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[6])(vgg)
vgg = Dropout(0.25, name=name[7])(vgg)
# classification layers
vgg = Flatten(name=name[8])(vgg)
vgg = Dense(512, activation='relu', name=name[9])(vgg)
vgg = Dropout(0.5, name=name[10])(vgg)
vgg = Dense(num_classes, activation='softmax', name=name[11])(vgg)
return vgg
# base VGG
baseVGG = simpleVGG(cifarInput, orig_classes, "base")
baseModel = Model(cifarInput, baseVGG)
人工物エキスパート・ネットワークと自然物エキスパート・ネットワークはそれぞれのデータのみを学習しています。
つまり、人工物エキスパート・ネットワークはラベル番号0,1,8,9のデータのみを学習し、自然物エキスパート・ネットワークはラベル番号2,3,4,5,6,7のデータのみを学習しています。
それぞれがそれぞれの専門とするものだけを判定できるようにしています。
# fat and long VGG-like model for expert neural networks
def fatVGG(cifarInput, num_classes, name="vgg"):
name = [name+str(i) for i in range(17)]
# convolution and max pooling layers
vgg = Conv2D(32, (3, 3), padding='same', activation='relu', name=name[0])(cifarInput)
vgg = Conv2D(32, (3, 3), padding='same', activation='relu', name=name[1])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[2])(vgg)
vgg = Dropout(0.25, name=name[3])(vgg)
vgg = Conv2D(64, (3, 3), padding='same', activation='relu', name=name[4])(vgg)
vgg = Conv2D(64, (3, 3), padding='same', activation='relu', name=name[5])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[6])(vgg)
vgg = Dropout(0.25, name=name[7])(vgg)
vgg = Conv2D(128, (3, 3), padding='same', activation='relu', name=name[8])(vgg)
vgg = Conv2D(128, (3, 3), padding='same', activation='relu', name=name[9])(vgg)
vgg = Conv2D(128, (3, 3), padding='same', activation='relu', name=name[10])(vgg)
vgg = MaxPooling2D(pool_size=(2,2), name=name[11])(vgg)
vgg = Dropout(0.25, name=name[12])(vgg)
# classification layers
vgg = Flatten(name=name[13])(vgg)
vgg = Dense(512, activation='relu', name=name[14])(vgg)
vgg = Dropout(0.5, name=name[15])(vgg)
vgg = Dense(num_classes, activation='softmax', name=name[16])(vgg)
return vgg
# artificial expert VGG
artificialVGG = fatVGG(cifarInput, orig_classes, "artificial")
artificialModel = Model(cifarInput, artificialVGG)
# natural expert VGG
naturalVGG = fatVGG(cifarInput, orig_classes, "natural")
naturalModel = Model(cifarInput, naturalVGG)
# get the position of artificial images and natural images in training and test dataset
artTrain = [i for i in range(len(y_train)) if y_train[i] in [0,1,8,9]]
natureTrain = [i for i in range(len(y_train)) if y_train[i] in [2,3,4,5,6,7]]
artTest = [i for i in range(len(y_test)) if y_test[i] in [0,1,8,9]]
natureTest = [i for i in range(len(y_test)) if y_test[i] in [2,3,4,5,6,7]]
# get artificial dataset and natural dataset
x_trainArt = x_train[artTrain]
x_testArt = x_test[artTest]
y_trainArt = y_train[artTrain]
y_testArt = y_test[artTest]
# for artificial dataset
y_trainArt = keras.utils.to_categorical(y_trainArt, orig_classes)
y_testArt = keras.utils.to_categorical(y_testArt, orig_classes)
# for natural dataset
x_trainNat = x_train[natureTrain]
x_testNat = x_test[natureTest]
y_trainNat = y_train[natureTrain]
y_testNat = y_test[natureTest]
# get natural dataset
y_trainNat = keras.utils.to_categorical(y_trainNat, orig_classes)
y_testNat = keras.utils.to_categorical(y_testNat, orig_classes)
ここまでに作ってきた人工物・自然物ゲート・ネットワークと各エキスパート・ネットワークはMixture of Expertsでは学習を行いません。
というわけで、各モデルのレイヤーを学習不可に設定します。
for l in baseModel.layers:
l.trainable = False
for l in gate0Model.layers:
l.trainable = False
for l in artificialModel.layers:
l.trainable = False
for l in naturalModel.layers:
l.trainable = False
人工物ゲート・ネットワーク、自然物ゲート・ネットワーク
人工物ゲート・ネットワーク、自然物ゲート・ネットワークは各エキスパート・ネットワークの重要度を推論し、出力に重み付けする役割を担います。
ゲート・ネットワークの出力する重要度(softmaxで推論された各エキスパート・ネットワークの出力への重み)をエキスパート・ネットワークの出力に乗じて、その和がMixture of Experts全体の出力(分類結果)になります。
Kerasで各ネットワークの出力を四則演算するためにはLambdaレイヤーを使い、四則演算自体を1レイヤーとします。
https://qiita.com/Mco7777/items/158296ed7f66aed2ffc3
Mixture of Experts全体では、人工物・自然物ゲート・ネットワークが推論した後、人工物ゲート・ネットワーク"または"自然物ゲート・ネットワークへルーティングします。
AndではなくOrです。
Kerasでケース判定をする場合はKerasバックエンドのSwitchを使います。
Switchの使い方は以下のようになります。
import keras.backend as K
K.switch(条件<boolean>, Trueの場合の処理, Falseの場合の処理)
KerasバックエンドのSwitchを使い、人工物・自然物ゲート・ネットワークのSoftmax出力で2値のうち高いほうに進みます。
# define sub-Gate network, for the second gating network layer
def subGate(cifarInput, orig_classes, numExperts, name="subGate"):
name = [name+str(i) for i in range(5)]
subgate = Flatten(name=name[0])(cifarInput)
subgate = Dense(512, activation='relu', name=name[1])(subgate)
subgate = Dropout(0.5, name=name[2])(subgate)
subgate = Dense(orig_classes*numExperts, activation='softmax', name=name[3])(subgate)
subgate = Reshape((orig_classes, numExperts), name=name[4])(subgate)
return subgate
# the artificial gating network
artGate = subGate(cifarInput, orig_classes, 2, "artExpertGate")
# the natural gating network
natureGate = subGate(cifarInput, orig_classes, 2, "natureExpertGate")
# define inference calculation with Keras Lambda layer with base VGG, expert network and the second gating network of corresponding expert as input
# the inference is calculated as sum of multiplications of base VGG inference output and its importance, and expert network inference output and its importance
def subGateLambda(base, expert, subgate):
output = Lambda(lambda gx: (gx[0]*gx[2][:,:,0]) + (gx[1]*gx[2][:,:,1]), output_shape=(orig_classes,))([base, expert, subgate])
return output
# connecting the overall networks.
# the Keras backend switch works as deciding with the first gating network, leading to artificial or natural gate
output = Lambda(lambda gx: K.switch(gx[1][:,0] > gx[1][:,1],
subGateLambda(gx[0], gx[2], gx[4]),
subGateLambda(gx[0], gx[3], gx[5])),
output_shape=(orig_classes,))([baseVGG, gate0VGG, artificialVGG, naturalVGG, artGate, natureGate])
# the mixture of experts model
model = Model(cifarInput, output)
これでMixture of Expertsネットワーク全体が定義され、各ゲート・ネットワークとエキスパート・ネットワークがつながりました。
全体像は以下のようになります。
可読性を損なうデカさです。
評価
Mixture of Expertsの学習では2層目のゲート・ネットワーク(人工物ゲート・ネットワークと自然物ゲート・ネットワーク)のみ学習します。
Mixture of Expertsの評価は以下のとおりです。
うまく動いているか確認
ゲート・ネットワークのルーティングがちゃんと動いているか(switchしてるか)、計算してみましょう。
テストデータの0番目はラベル番号3の猫画像なのですが、Mixture of Expertsの出力値が、各ゲート・ネットワーク、エキスパート・ネットワークを合わせた出力値と同じかどうか確認します。
猫画像なので、最初の人工物・自然物ゲート・ネットワークは自然物ゲート・ネットワークへ誘導します。
自然物ゲート・ネットワークはBase VGGと自然物エキスパート・ネットワークの重要度を算出し、Base VGGと自然物エキスパート・ネットワークは画像のラベルを判定します。
最後に自然物エキスパート・ネットワークの重要度をBase VGG、自然物エキスパート・ネットワークの出力と乗算し合計したものが判定結果になります。
想定では以下のようにゲート・ネットワークとエキスパート・ネットワークを通っているはずです。
KerasではKerasバックエンドのfunctionを使うことでネットワークの途中を切り出して使うことができます。
K.functionを使い、Mixture of Expertsから人工物・自然物ゲート・ネットワークの出力、自然物ゲート・ネットワークの出力、Base VGGの出力、自然物エキスパート・ネットワークの出力を切り出します。
(該当するレイヤー番号を探して指定しなければならないのが玉にキズですが・・・)
# the first y_test data is labeled 3, so it should be natural object
print(y_test0[0])
# array([ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
# Keras function to get the BaseVGG output
getBase = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[63].output])
# Keras function to get the first gating network output
getGate0 = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[64].output])
# Keras function to get the artificial expert network output
getArt = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[65].output])
# Keras function to get the natural expert network output
getNature = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[66].output])
# Keras function to get the output for artificial gating network
getArtGate = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[67].output])
# Keras function to get the output for natural gating network
getNatureGate = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[68].output])
# output from the first gating network for x_test[0]
# since the second index is higher than the first, it is classified as a natural object
# the first gating should lead the inference to the natural gating network
xtest0Gate = getGate0([[x_test[0]], 0])[0]
print(xtest0Gate)
# array([[ 0.00687723, 0.99312276]], dtype=float32)
# output from natural gating network
# the list of softmax output for the importances of base VGG and natural expert network
xtest0NatureGate = getNatureGate([[x_test[0]], 0])[0]
print(xtest0NatureGate)
# array([[[ 0.08541576, 0.03585474],
# [ 0.08810793, 0.03100363],
# [ 0.01911297, 0.03751711],
# [ 0.04628449, 0.04165833],
# [ 0.02868748, 0.06180721],
# [ 0.04721113, 0.07710078],
# [ 0.02533226, 0.0840624 ],
# [ 0.04218154, 0.05244576],
# [ 0.08518016, 0.02692058],
# [ 0.07291996, 0.01119581]]], dtype=float32)
# inferenct from base VGG for x_test[0]
xtest0Base = getBase([[x_test[0]], 0])[0]
print(xtest0Base)
# array([[ 1.30025754e-02, 4.88831941e-03, 6.07487699e-03,
# 6.73251569e-01, 3.08401213e-04, 2.47065753e-01,
# 2.06855088e-02, 2.67533562e-03, 2.77825613e-02,
# 4.26515006e-03]], dtype=float32)
# inference from natural expert network for x_test[0]
xtest0Nature = getNature([[x_test[0]], 0])[0]
print(xtest0Nature)
# array([[ 2.08449508e-11, 4.72127615e-11, 1.22241955e-02,
# 6.25748694e-01, 7.82116503e-03, 3.36175740e-01,
# 6.49287738e-03, 1.15373293e-02, 2.38116731e-11,
# 2.11326633e-11]], dtype=float32)
# multiply the base VGG inference with the natural gating inference for the importance of base VGG
# multiply the natural expert inference with the natural gating inference for the importance of natural expert
# get the sum
(xtest0Base * xtest0NatureGate[:,:,0]) + (xtest0Nature * xtest0NatureGate[:,:,1])
# array([[ 0.00111062, 0.0004307 , 0.00057473, 0.05722876, 0.00049225,
# 0.03758366, 0.00106982, 0.00071793, 0.00236652, 0.00031101]], dtype=float32)
# the mixture of experts prediction for x_test[0]
# showing exactly the same inferences!
model.predict(x_test[:1])
# array([[ 0.00111062, 0.0004307 , 0.00057473, 0.05722876, 0.00049225,
# 0.03758366, 0.00106982, 0.00071793, 0.00236652, 0.00031101]], dtype=float32)
Mixture of Expertsの推論結果とネットワークを切り出した計算結果が合致しているようなので、正常にswitchできているようです。
最後に
今回はサーバ1台で作りましたが、ゲート・ネットワークとエキスパート・ネットワークを専用のGPUやサーバに設置するというアーキテクチャも考えられているようです。
そのうち気が向いたらKubernetesと組み合わせて、各ネットワークを専用コンテナに分散してみたいと思いますが、予定は未定です。
あと、Noisy top-K gating softmaxでゲート・ネットワークの分岐を有機的にしたいですが、予定は未定です。
参考
https://www.cs.toronto.edu/~hinton/csc321/notes/lec15.pdf
https://arxiv.org/pdf/1701.06538.pdf
https://github.com/krishnakalyan3/MixtureOfExperts
http://www.cis.twcu.ac.jp/~asakawa/waseda2002/ME.pdf
https://people.cs.pitt.edu/~milos/courses/cs2750-Spring04/lectures/class22.pdf
http://yamaimo.hatenablog.jp/entry/2016/02/29/200000