0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

 前回は、二値分類モデルで経済成長しているかどうかでディープラーニングしました。
こちらの記事を参照。
AIチャレンジ2(経済成長率の予測)
https://qiita.com/horiivalue/items/97563e4a647d071419cc

 今回は、同じデータで多クラス分類モデルで学習させてみたいと思います。

前回との違い

 二値分類モデルと異なり、成長率を0.5%づつ8クラスに分けて分類。

for i in range(117):
    s = dataset.iloc[i,0]
    if s >= 2:
        dataset.iloc[i,0] = 2
    elif s>= 1.5:
        dataset.iloc[i,0] = 1.5
    elif s>= 1:
        dataset.iloc[i,0] = 1
    elif s>= 0.5:
        dataset.iloc[i,0] = 0.5
    elif s>= 0:
        dataset.iloc[i,0] = 0
    elif s>= -0.5:
        dataset.iloc[i,0] = -0.5
    elif s>= -1:
        dataset.iloc[i,0] = -1
    elif s>= -1.5:
        dataset.iloc[i,0] = -1.5
    else:
        dataset.iloc[i,0] = -2
dataset.head()

dataset2 = pd.get_dummies(data=dataset, columns=['keizai'])
dataset2.head()

Y = np.array(dataset2[['keizai_-2.0', 'keizai_-1.5', 'keizai_-0.5','keizai_0.0','keizai_0.5','keizai_1.0','keizai_1.5','keizai_2.0']])
X = np.array(dataset2[['kakei','export','minkan','import']])

学習と効果

 学習も以下のように変更。

# モデルの初期化
model = keras.Sequential()

# 入力層
model.add(Dense(24, activation='relu', input_shape=(4,)))
# 隠れ層
model.add(Dense(24, activation='relu'))
# 出力層
model.add(Dense(8, activation='softmax'))

# モデルの構築
model.compile(optimizer = "rmsprop", loss='binary_crossentropy', metrics=['accuracy'])

%%time
# 学習の実施
log = model.fit(X_train, Y_train, epochs=5000, batch_size=32, verbose=True,
                callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss',
                                                         min_delta=0, patience=100,
                                                         verbose=1)],
         validation_data=(X_valid, Y_valid))

# predictで予測を行う
Y_pred_0 = model.predict(X_test)

# 多クラス分類は予測結果の各リストにおける
# 最大値のインデックスを取得するようにする
Y_pred = np.argmax(Y_pred_0, axis=1)

# モデルの評価
from sklearn.metrics import classification_report

print(classification_report(Y_test_, Y_pred))

image.png

結果、前回よりかなり悪い値が出てきました。もう少しデータが必要のようです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?