Help us understand the problem. What is going on with this article?

Deep Learningで集合を扱う方法

More than 1 year has passed since last update.

入力としての集合

例えば複数枚の画像を入力して、何か推論させたいとしましょう。
色々方法はあるのですが、LSTMを使ってエンコードしたり、画像をつなげたりして入力する場合は、画像を入れる順番が推論に影響を与えることにならないように、色々と並び替えをする必要があります。

色々な方法についてはこちら

できればネットワークの構造の中に、画像を入れる順番に対して動作が不変となるようなアーキテクチャを取り入れたいはずです。
今回取り上げるDLはそんなアーキテクチャを取り入れたネットワークになります。

Deep Sets

つ「論文

とても難しいですが、この論文をサーベイしてまとめてくれている資料がありました。

https://www.slideshare.net/TomohiroTakahashi2/deep-sets-86057010

要は、閉区間[0, 1]上の任意の連続関数$f(x)$は、$x$の多項式で表せるとかいうWeierstrassの多項式近似定理が本論文の理論を支えています。

f(x) \fallingdotseq a_0 + a_1x + a_2x^2 + a_3x^3 + ...

難しいことはわからないのでとりあえず盲目的に信じてみます。
テイラー展開みたいですね。

今回は集合を扱うので$x_1, x_2, x_3 ...$と複数個あるということになります。
この$x$を入れる順番に対して不変性を持たせる場合には、多項式に対称性を導入すればよいということになります。

多項式の対称性

$x_1, x_2, x_3 ...$に関して、$e_i$を以下のように定義します。下の例では3つまで書いてますが、集合の数によって書けることのできる式に制限があったり、またはどこかで打ち切ったりすることもあります。

\begin{array}{l}
e_1 = \sum_{i} x_i & \\
e_2 = \sum_{i, j} x_i x_j  & (i \neq j) \\
e_3 = \sum_{i, j, k} x_i x_j x_k & (i \neq j \neq k) \\
\end{array}

この$e_i$はそれぞれ、$x_i$の並びに対して値が変わることがありません。したがって、順番に対する不変性が保証されます。

最後に$e_i$を以下のようにしてまとめます。

f(x_1, x_2, x_3, ...) \fallingdotseq \sum_{m}c_m e_m

$c_m$はそれぞれの$e_m$に対する重みです。ハイパーパラメータとして設定しても良いですし、学習で調節するパラメータとして扱っても良いと思います。
ここら辺については論文をちゃんと読んでいないので、何とも言えないです。

実装

MNISTで実装し、評価します。

タスク

複数の画像を受け取り、そこに書かれた数字の和を推定する。
なお、実装を簡単化するため、入力として受け取る集合の最大要素数は3とする。

処理の流れ

  1. 複数の画像を畳み込み層に入力し、特徴ベクトルを得る。なお、各画像に対して、畳み込み層の重みは共有される。
  2. 得られた複数の特徴ベクトルに対して前の節で述べたように対称性をもつ多項式の計算を行い、1つの特徴ベクトルに統合する
  3. 全結合層に入力し、推定値を得る

ネットワークの外観

image.png

いたってシンプルな構造です。
これをmodel_1 : Modelとして、複数の画像を入力として1つの出力を吐き出すネットワーク(Deep Sets)は以下のようになります。

image.png

今回は集合の要素数は3つなので、Inputが3つあります。
要素数が2のときは、余った1つにはゼロを入力すればよいです。

Jupyterで書くとどうなるのか?

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from tensorflow.python.keras.layers import Input, Flatten, Dense, Activation, Lambda
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import plot_model

%matplotlib inline
plt.gray()

MNISTのデータを読み込んで、0から1までの値に正規化します。

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train[:, :, :, np.newaxis].astype(np.float32) / 255.
X_test = X_test[:, :, :, np.newaxis].astype(np.float32) / 255.

画像から特徴ベクトルを抽出するレイヤーです。
タスクが難しい場合はVGGとかResNetとかで置き換えることもできます。

def build_extract_layer():
    _input = Input(shape=(28, 28, 1))

    x = Conv2D(8, (3, 3), padding='same')(_input)
    x = MaxPooling2D((2, 2))(x)
    x = Activation('relu')(x)

    x = Conv2D(16, (3, 3), padding='same')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Activation('relu')(x)

    x = Conv2D(64, (3, 3), padding='same')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Activation('sigmoid')(x)

    _output = Flatten()(x)

    return Model(_input, _output)

対称性をもつ多項式を記述します。
集合の要素数を3つと決めているので、以下のような実装にしています。
今回は$c_m$はすべて1としました。考えるのが面倒くさかった

def set_func(args):
    _vec_1, _vec_2, _vec_3 = args
    _e_1 = _vec_1 + _vec_2 + _vec_3
    _e_2 = _vec_1 * _vec_2 + _vec_2 * _vec_3 + _vec_3 * _vec_1
    return _e_1 + _e_2

モデルのインスタンスを生成し、損失関数を定義します。
推定値と正解値の絶対誤差をとるmean_absolute_errorがおすすめです。

_input_1 = Input(shape=(28, 28, 1))
_input_2 = Input(shape=(28, 28, 1))
_input_3 = Input(shape=(28, 28, 1))

extract_model = build_extract_layer()

_vec_1 = extract_model(_input_1)
_vec_2 = extract_model(_input_2)
_vec_3 = extract_model(_input_3)

_set = Lambda(set_func, output_shape=(576,))([_vec_1, _vec_2, _vec_3])

_predict = Dense(1)(_set)

model = Model([_input_1, _input_2, _input_3], _predict)
opt = SGD(lr=0.001, momentum=0.9, decay=1e-6, nesterov=True)
model.compile(
    optimizer=opt,
    loss='mean_absolute_error'
)
model.summary()

データセットを作成する処理です。

def create_dataset(X, y, n_new_data=10000):
    n_original_data = X.shape[0]
    new_X = [np.zeros((n_new_data, 28, 28, 1)), np.zeros((n_new_data, 28, 28, 1)), np.zeros((n_new_data, 28, 28, 1))]
    new_y = np.zeros((n_new_data,))
    for i in range(n_new_data):
        _index_1, _index_2, _index_3 = np.random.choice(np.arange(n_original_data), 3)

        new_X[0][i, :, :, :] = X[_index_1, :, :, :]
        new_X[1][i, :, :, :] = X[_index_2, :, :, :]
        new_X[2][i, :, :, :] = X[_index_3, :, :, :]

        new_y[i] = float(y[_index_1] + y[_index_2] + y[_index_3])

    return new_X, new_y

実際に作成します。n_new_dataでデータの数を指定できます。

new_X_train, new_y_train = create_dataset(X_train, y_train, n_new_data=10000)
new_X_test, new_y_test = create_dataset(X_test, y_test, n_new_data=1000)

すべての準備が整ったので、実際に学習させます。
私の環境では50epoch目くらいで訓練誤差が1.0を下回っていました。

model.fit(new_X_train, new_y_train,
    validation_data=[new_X_test, new_y_test],
    epochs=100,
    batch_size=32,
    shuffle=True,
    verbose=1
)

学習が終われば、以下の関数で推論させます。

y_preds = model.predict(new_X_test)

以下のコードで結果を可視化します。

index = 50
plt.subplot(1, 3, 1)
plt.imshow(new_X_test[0][index, :, :, 0])
plt.subplot(1, 3, 2)
plt.imshow(new_X_test[1][index, :, :, 0])
plt.subplot(1, 3, 3)
plt.imshow(new_X_test[2][index, :, :, 0])
print(y_preds[index, 0], new_y_test[index])

さいごに

数学わからん。。。

仲間外れ検出とかもできるそうです。

その場合、順番に対して不変な部分と、順番をちゃんと保持する部分が必要になりますね。
どうやってるんでしょう。私はここで筆を置きます。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした