8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ニューラルネットで回帰を行い、身長+3サイズでセクシー女優のカップ数を推定させる

Last updated at Posted at 2018-04-29

はじめに

前回https://qiita.com/ruteshi_SI_shiteru/items/fe3fcf96bdd322407e1c
の続きとなります。
前回はNNで分類を行わせて推定させていたのですが、精度が悪かったため、今回は回帰で推定を行いました。

環境

前回と同じです。

使用データ

前回及び先行研究と同じです。
※先行研究…https://qiita.com/kkdmgs110/items/593b9a2a270734d06070

ちなみに、cupからcup numを生み出す加工を行い、以下のようなデータとしております。(実は前回もですが)
image.png

コード全体像

code
import pandas as pd

df = pd.read_csv('dataset.csv')

# 確認
df.head(3)

# 教師データ(cupnum列)
t = df.iloc[:, 6]

# 入力変数(3~6番目)
x = df.iloc[:, 2:6]

# 確認
t.head(3)

# 確認
x.head(3)

# Numpyにデータ型を変換
# 回帰バージョン
tn = (t.values - 1).reshape(t.size,1).astype('float32')
xn = x.values.astype('float32')

# 確認
tn.dtype

# 確認
tn

# 確認
xn.dtype

# 確認
xn

import chainer
import chainer.links as L
import chainer.functions as F

class NN(chainer.Chain):

    # モデルの構造を明示
    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Linear(4, 4)
            self.l2 = L.Linear(4, 4)
            self.l3 = L.Linear(4, 1)

    # 順伝播
    def __call__(self, x):
        u1 = self.l1(x)
        z1 = F.relu(u1)
        u2 = self.l2(z1)
        z2 = F.relu(u2)
        u3 = self.l3(z2)
        return u3

import numpy as np

# 再現性確保のためシードを固定
np.random.seed(1)

# NNモデルをインスタンス化
model = NN()

from chainer import optimizers

# 最適化にはAdamを使用
optimizer = optimizers.Adam()
optimizer.setup(model) # modelと紐付ける

# 損失関数の計算
#   損失関数には自乗誤差(MSE)を使用
def forward(x, y, model):
    t = model.__call__(x)
    loss = F.mean_squared_error(t, y)
    return loss


# パラメータの学習を繰り返す
results = []
for i in range(0,30000):
    loss = forward(xn, tn, model)
    print(loss.data)  # 現状のMSEを表示
    results.append(loss.data)
    optimizer.update(forward, xn, tn, model)
    
# 学習完了

# 誤差の推移をプロットしインライン表示
import matplotlib.pyplot as plt
%matplotlib inline
plt.yscale("log")
plt.xscale("log")
plt.plot(results)
plt.show()

# 予測 数値の見方:1 = Aカップ、2 = Bカップ… 5 = Eカップ… 
# Rioさん(C cup)
xq = np.array([[154,84,58,83]], 'f')
t1 = model.__call__(xq)+1
print(t1)
# variable([[ 4.11120224]])

# 蒼井そらさん(G cup)
xq = np.array([[155,90,58,83]], 'f')
t1 = model.__call__(xq)+1
print(t1)
# variable([[ 5.91459274]])

# 石原莉奈さん(C cup)
xq = np.array([[155,85,56,84]], 'f')
t1 = model.__call__(xq)+1
print(t1)
# variable([[ 4.56367683]])

# 野中あんりさん(A cup)
xq = np.array([[154,74,56,84]], 'f')
t1 = model.__call__(xq)+1
print(t1)
# variable([[ 2.42423582]])

# さくら柚木さん(K cup)
xq = np.array([[157,115,68,94]], 'f')
t1 = model.__call__(xq)+1
print(t1)
# variable([[ 12.39152241]])

手法

学習は4層ニューラルネットワークで行います。複雑なモデルのほうが適しているのか…と思い、前回の3層から1層増やしてみました。

結果

精度

誤差の推移を以下にプロットします。10000回を過ぎた頃から安定していることが伺えます。なお、今回はすべてのデータを学習対象としました。
誤差の推移

予測について

予測結果はコード内に示しておりますので、ここでの再掲は省かせていただきます。

考察

前回、「学習に回帰ではなく分類を用いたことが精度低下の主因か」と触れましたが、概ねその通りだと考えられます。

所感

予測値が意外と近い値を示してくれました。さすがに、完璧に予測することは不可能なので、これくらいが関の山かもしれません。

最後までお付き合いいただき有難うございました。

8
8
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?