はじめに
玉手箱の空欄推測問題を解いてみたのソルバー編です.前回までで問題を解く上で必要な値の取得までできているので,今回は実際に問題を解いていきます.
問題設定
空欄推測問題は以下のような問題で,テーブルの数値が二次元配列で与えられ,予測する値がNaNになっていることを想定しています.
(https://jyosiki.com/spi/tama31_a.html より引用)
運賃を予測するので,運賃を目的変数,その他の値を説明変数とした線形回帰問題として解きます.数式で書くと
y_{運賃} = w_{フェリーの本数} \times x_{フェリーの本数} + w_{乗客数} \times x_{乗客数} + w_{フェリーの乗車時間} \times x_{フェリーの乗車時間}
のような感じです.それぞれの変数の値$x$に重み$w$を掛けて,足し合わせることで$y$を求めます.ここでの$w$をA$\sim$D市のデータから求めて,E市の運賃を求めます.バイアス項を考慮する場合は$x_{バイアス}=1$を追加して考えます.
解法
重み$w$の値は最小二乗法を使って求めます.つまり$y$の値と$\sum w\times x$の値の差が最も小さくなるような$w$を求めます.線型回帰問題の場合は最適解が計算によって求められて,今回の問題であれば,
X =
\left(
\begin{array}{cccc}
x_{フェリーの本数}^A & x_{フェリーの本数}^B & x_{フェリーの本数}^C & x_{フェリーの本数}^D\\
x_{乗客数}^A & x_{乗客数}^B & x_{乗客数}^C & x_{乗客数}^D \\
x_{フェリーの乗車時間}^A & x_{フェリーの乗車時間}^B & x_{フェリーの乗車時間}^C & x_{フェリーの乗車時間}^D
\end{array}
\right)
y = \left(
\begin{array}{cccc}
y_{運賃}^A & y_{運賃}^B & y_{運賃}^C & y_{運賃}^D
\end{array}
\right)
w = \left(
\begin{array}{ccc}
w_{フェリーの本数} & w_{乗客数} & w_{フェリーの乗車時間}
\end{array}
\right)
のような行列の形で定義すると,
w = (X X^T)^{-1} X y
というような感じで求められます.(一般的なデータ分析で使うときの表現と行と列の向きが逆ですが,行列の形式を問題の表の値と対応付けるためにこの順番にしています.このあたりは自分で調整してください)
最小二乗法の最適解を求めるコードは以下のようになります.基本的には今までの話をプログラムに落とし込む感じです.まず,はじめに二次元配列のNaNの位置から,重み$w$を求めるために使うデータと予測に使うためのデータを分離しています.その後,最適解を求める数式にしたがって重みを計算します.最後に求めた重みから予測を行います.
nan_row, nan_col = self.nan_position
train_data = np.delete(table, nan_col, axis=1)
X, y = np.delete(train_data, nan_row, axis=0), train_data[nan_row]
if use_bias:
X = np.concatenate([X, np.ones((1, X.shape[1]))], axis=0)
test_data = table[:, nan_col]
test_x = np.delete(test_data, nan_row, axis=0)
if use_bias:
test_x = np.concatenate([test_x, [1]])
w = np.linalg.inv(X @ X.T) @ X @ y
result = test_x @ w
結果
上の問題をバイアス項なしで解くと結果は499.63になります.
コード全体
最終的な実装内容としては以下のようになりました.値入力の計算機が正しいかの判定も行っているので,それらの処理全体をクラスとしてまとめました.
import numpy as np
class Solver:
def __init__(self, frontend):
self.frontend = frontend
def solve(self):
table = self.frontend.table
table = np.array(table)
use_bias = self.frontend.use_bias.get()
if not self.check_table_format(table):
self.frontend.set_result_str(self.message)
else:
nan_row, nan_col = self.nan_position
train_data = np.delete(table, nan_col, axis=1)
X, y = np.delete(train_data, nan_row, axis=0), train_data[nan_row]
if use_bias:
X = np.concatenate([X, np.ones((1, X.shape[1]))], axis=0)
test_data = table[:, nan_col]
test_x = np.delete(test_data, nan_row, axis=0)
if use_bias:
test_x = np.concatenate([test_x, [1]])
w = np.linalg.inv(X @ X.T) @ X @ y
self.output_result(test_x, w)
def output_result(self, x, w):
result = x @ w
result_str = '{:.2f}'.format(result) + '\n = '
total_row_len = 0
for i, (w_elem, x_elem) in enumerate(zip(w, x)):
temp = '{:.2f}'.format(w_elem) + '*' + str(x_elem)
total_row_len += len(temp)
result_str += temp
if total_row_len > 40:
total_row_len = 0
result_str += '\n'
if i != len(w) - 1:
result_str += ' + '
self.frontend.set_result_str(result_str)
def check_table_format(self, table):
# no input
if len(table) == 0:
self.message = '入力が少ないです'
return False
row_length = [len(table[i]) for i in range(len(table))]
row_length = set(row_length)
# not in table format
if len(row_length) != 1:
self.message = '入力が少ないです'
return False
is_nan = np.isnan(table)
nan_row = np.where(is_nan.any(axis=1))[0]
nan_col = np.where(is_nan.any(axis=0))[0]
# too many unknowns
if len(nan_row) != 1 or len(nan_col) != 1:
self.message = '未知数の数が不正です'
return False
self.nan_position = (nan_row[0], nan_col[0])
return True
おわりに
問題を解く部分の実装はあんまり難しくないですね.
玉手箱の空欄推測問題をPythonで解いてみた
玉手箱の空欄推測問題をPythonで解いてみた 〜GUI編(Tkinter)〜
玉手箱の空欄推測問題をPythonで解いてみた 〜文字認識編(pyocr)〜
玉手箱の空欄推測問題をPythonで解いてみた 〜ソルバー編〜