LoginSignup
1
0

More than 3 years have passed since last update.

複数出力のnnを学習させたら、効率悪かった。機械学習で商品分類と対象性別判断

Last updated at Posted at 2019-07-09

複数のカテゴライズを一つのnnで行いたい

(複数出力のやり方は、たくさんあるので書きません)

ファッションサーチファンネル(ファッション画像検索エンジン)
https://funnel-service.com
のセラーデータベースを構築するために商品紹介文から売っている服の分類をしている。

男女の判別も行いたいので、複数出力でやってみた。

モデルの構築コードはこんな感じ


cat_inputs = tf.keras.Input(shape=(max_words,))
sentence_inputs = tf.keras.Input(shape=(max_words,))

# 結合

c = tf.keras.layers.concatenate([cat_inputs,sentence_inputs ])

x2 = tf.keras.layers.BatchNormalization()(c)
x2 = tf.keras.layers.Dense(128, activation='relu')(x2)
x2 = tf.keras.layers.BatchNormalization()(x2)
x2 = tf.keras.layers.Dense(128, activation='relu')(x2)
x2 = tf.keras.layers.Dense(3, activation='sigmoid',name="gend")(x2)

x = tf.keras.layers.BatchNormalization()(c)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.Dense(num_classes, activation='sigmoid',name='class')(x)


model = tf.keras.Model(inputs=[cat_inputs,sentence_inputs], outputs=[x,x2])


Figure_12.png

↑二つのval線(緑と赤)のスイートスポットがずれている。こりゃだめだ。。。別々に学習してみる。

11Figure_1.png

↑商品分類だけで学習。

123Figure_1.png

↑性別分類だけで学習。

結果

スイートスポットがとりやすいので別々のモデルとして学習したほうがいい。
他のケースではわからないけど、今回の複数出力は効率が悪いことが分かった。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0