複数のカテゴライズを一つのnnで行いたい
(複数出力のやり方は、たくさんあるので書きません)
ファッションサーチファンネル(ファッション画像検索エンジン)
https://funnel-service.com
のセラーデータベースを構築するために商品紹介文から売っている服の分類をしている。
男女の判別も行いたいので、複数出力でやってみた。
モデルの構築コードはこんな感じ
cat_inputs = tf.keras.Input(shape=(max_words,))
sentence_inputs = tf.keras.Input(shape=(max_words,))
# 結合
c = tf.keras.layers.concatenate([cat_inputs,sentence_inputs ])
x2 = tf.keras.layers.BatchNormalization()(c)
x2 = tf.keras.layers.Dense(128, activation='relu')(x2)
x2 = tf.keras.layers.BatchNormalization()(x2)
x2 = tf.keras.layers.Dense(128, activation='relu')(x2)
x2 = tf.keras.layers.Dense(3, activation='sigmoid',name="gend")(x2)
x = tf.keras.layers.BatchNormalization()(c)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.Dense(num_classes, activation='sigmoid',name='class')(x)
model = tf.keras.Model(inputs=[cat_inputs,sentence_inputs], outputs=[x,x2])
↑二つのval線(緑と赤)のスイートスポットがずれている。こりゃだめだ。。。別々に学習してみる。
↑商品分類だけで学習。
↑性別分類だけで学習。
結果
スイートスポットがとりやすいので別々のモデルとして学習したほうがいい。
他のケースではわからないけど、今回の複数出力は効率が悪いことが分かった。