Help us understand the problem. What is going on with this article?

ニューラルネットワークでプロ野球選手の給与を査定してみる

More than 1 year has passed since last update.

概要

TensorFlowでニューラルネットワークを使い、94人のプロ野球投手の年間の成績から年俸を推定してみます。
訓練データとして89人の成績と年俸を使い、残りの5人の選手の年俸をどれだけ精確に推定できるかを検証します。

注:この記事は選手の年俸についての意見を述べるものではなく、検証の結果はいかなる選手の年俸の不当性を訴えるものでもありません。

入力

以下の33種類のデータを入力として取り扱います。

  • 球団(12個のOne-Hot Vector)
  • 防御率
  • 出場試合数
  • 勝利数
  • 敗北数
  • セーブ
  • ホールド
  • 勝率
  • 打者
  • 投球回数
  • 被安打
  • 被本塁打
  • 与四球
  • 与死球
  • 奪三振
  • 失点
  • 自責点
  • WHIP
  • DIPS
  • 所属年数
  • 年齢
  • 国内選手 or 国外選手

各データは最小0、最大1の値を取るように正規化しました。
94人の選手のうち89人のデータを訓練データとして、残りの5人の選手のデータをテストデータとして扱います。

ネットワーク構造

以下の隠れ層1層のネットワークを使用しました。
選手毎の給与のベクトルはL2Normalizeをした上で比較しています。
今回は勉強のために、隠れ層のノード数Nが4の場合と32の場合を比べてみます。

network.png

誤差関数

ネットワークの出力及び教師となる給与ベクトルをL2Normalizeした上で、両者のL2距離を最小化します。
給与はとあるウェブサイトに掲載されていた「推定年俸」を使用しましたが、選手の特定を避けるためにリンクの掲載は控えさせていただきます。

学習結果

縦軸に誤差、横軸に反復回数をプロットしたグラフを訓練データ及びテストデータについて示します。

隠れ層のノード数32の場合

npb_32.png

隠れ層のノード数4の場合

npb_4.png

最も誤差の少なかった推定

salary_comparison.png

(隠れ層ノード数4の最小誤差のケース)

考察

隠れ層ノード数の影響

隠れ層ノード数が多い方が訓練データに対する誤差が少なくなる一方で、テストデータに対する誤差は増えます。これはOver-Fittingの効果で、ネットワークが表現できる自由度が高すぎるために、訓練データに対して適応しすぎるあまりテストデータに対しては適当な「補完」ができないからであると考えられます。

過学習の影響

隠れ層ノード数に関係なく、学習の反復回数を増やし過ぎるとテストデータに対する誤差が増える「過学習」が生じています。
これも一種のOver-Fittingで、学習の早期打ち切りをしたほうが望ましい結果が得られることがわかります(詳細は参考文献の「深層学習」参照)

最適なネットワーク構造

今回は隠れ層ノード数4の場合と32の場合を比較しましたが、結局隠れ層ノード数を幾つにすれば最適な結果が得られるかは「やってみないとわからない」です。
同様に、学習回数を何回で打ち切れば最適な結果が得られるかも、「やってみないとわからない」です。
この辺のハイパーパラメータを試行錯誤しながら調整しないといけないのがニューラルネットワークの難しさだと感じました。

まとめ

給与に直結して数値化できる成果がある職業なら、給与査定を機械に任せても良くなる時代がもうすぐ来そうですね。
ヒトの主観で判断するよりもよっぽど信頼できると思います。

参考

コード

import tensorflow as tf
import numpy

SCORE_SIZE = 33
HIDDEN_UNIT_SIZE = 32
TRAIN_DATA_SIZE = 90

raw_input = numpy.loadtxt(open("input.csv"), delimiter=",")
[salary, score]  = numpy.hsplit(raw_input, [1])

[salary_train, salary_test] = numpy.vsplit(salary, [TRAIN_DATA_SIZE])
[score_train, score_test] = numpy.vsplit(score, [TRAIN_DATA_SIZE])

def inference(score_placeholder):
  with tf.name_scope('hidden1') as scope:
    hidden1_weight = tf.Variable(tf.truncated_normal([SCORE_SIZE, HIDDEN_UNIT_SIZE], stddev=0.1), name="hidden1_weight")
    hidden1_bias = tf.Variable(tf.constant(0.1, shape=[HIDDEN_UNIT_SIZE]), name="hidden1_bias")
    hidden1_output = tf.nn.relu(tf.matmul(score_placeholder, hidden1_weight) + hidden1_bias)
  with tf.name_scope('output') as scope:
    output_weight = tf.Variable(tf.truncated_normal([HIDDEN_UNIT_SIZE, 1], stddev=0.1), name="output_weight")
    output_bias = tf.Variable(tf.constant(0.1, shape=[1]), name="output_bias")
    output = tf.matmul(hidden1_output, output_weight) + output_bias
  return tf.nn.l2_normalize(output, 0)

def loss(output, salary_placeholder, loss_label_placeholder):
  with tf.name_scope('loss') as scope:
    loss = tf.nn.l2_loss(output - tf.nn.l2_normalize(salary_placeholder, 0))
    tf.scalar_summary(loss_label_placeholder, loss)
  return loss

def training(loss):
  with tf.name_scope('training') as scope:
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
  return train_step


with tf.Graph().as_default():
  salary_placeholder = tf.placeholder("float", [None, 1], name="salary_placeholder")
  score_placeholder = tf.placeholder("float", [None, SCORE_SIZE], name="score_placeholder")
  loss_label_placeholder = tf.placeholder("string", name="loss_label_placeholder")

  feed_dict_train={
    salary_placeholder: salary_train,
    score_placeholder: score_train,
    loss_label_placeholder: "loss_train"
  }

  feed_dict_test={
    salary_placeholder: salary_test,
    score_placeholder: score_test,
    loss_label_placeholder: "loss_test"
  }

  output = inference(score_placeholder)
  loss = loss(output, salary_placeholder, loss_label_placeholder)
  training_op = training(loss)

  summary_op = tf.merge_all_summaries()

  init = tf.initialize_all_variables()

  best_loss = float("inf")

  with tf.Session() as sess:
    summary_writer = tf.train.SummaryWriter('data', graph_def=sess.graph_def)
    sess.run(init)

    for step in range(10000):
      sess.run(training_op, feed_dict=feed_dict_train)
      loss_test = sess.run(loss, feed_dict=feed_dict_test)
      if loss_test < best_loss:
        best_loss = loss_test
        best_match = sess.run(output, feed_dict=feed_dict_test)
      if step % 100 == 0:
        summary_str = sess.run(summary_op, feed_dict=feed_dict_test)
        summary_str += sess.run(summary_op, feed_dict=feed_dict_train)
        summary_writer.add_summary(summary_str, step)

    print sess.run(tf.nn.l2_normalize(salary_placeholder, 0), feed_dict=feed_dict_test)
    print best_match

追記(2016/12/20)

コメントでもご指摘いただきましたが、このプログラムだけでは機械学習モデルの汎化性能を担保したことにはなりません。
具体的には、

  • 特定の5選手を訓練データから取り除くのではなくクロスバリデーション等のデータセット分割方法によるバイアスを防ぐ手法を取り入れる
  • 反復回数や隠れ層のノード数等のハイパーパラメータを決定するための validation data set を導入し、最終的な性能評価には学習時の最良の解(プログラム中の best_matchではなく、決定された反復回数でモデルを訓練したあとに、 test_data に対する inference を一度だけする

必要があると思います。

この記事を書いた当時は「TensorFlowを使ってみた!」程度のつもりで書いたのですが紛らわしい書き方で申し訳ないです。
「この通りやれば良いモデルが作れる」という性質のものではないのでご了承ください(試していませんが精度を求めるならニューラルネットワーク以外のモデルのほうがむしろ良いこともあるかもしれません)。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away