学習データの前処理にデータの実数範囲を変更する正規化と呼ばれる処理があります。
正規化の実装はscikit-learn(以下sklearn)にfit_transformと呼ばれる関数が用意されています。
今回は学習データと検証データに対して正規化を行う実装をサンプルコードと共に共有します。
sklearn正規化関数
sklearnに用意されている正規化関数は主に3種類、2段階のプロセスがあります。
- パラメータの算出
- パラメータを用いた変換
fit()
入力データから標準偏差や最大・最小値を算出しパラメータを保存
transform()
fit関数から算出されたパラメータを用いてデータを変換
fit_transform()
上記の処理を連続的に実行する
なぜ3種類の関数があるか?
あるデータに対して正規化をするのであればfit_transorm関数を用いて同時にパラメータの算出とデータ変換を行えばよいはず。。
しかし、学習の際に前処理としてデータを変換する場合、学習用データと検証用データで同様のパラメータ(fit関数の結果)を用いる必要があります。※サンプルコードで簡易な例を表示します。
そのため、あるデータに対してパラメータを算出するfit()と算出されたパラメータを用いて変換を行うtransform関数が用意されています。
正規化種類
sklearnのリファレンスで調べたところ27種類もあるようです。私は2、3種類しか用いたことがありませんが興味があればご参照ください。
API Reference sklearn.preprocessing scikit-learn 0.19.2 documentation
よく利用される変換手法
・MinMaxScaler() # データの最大・最小値を定義
・StandardScaler() # 標準化
サンプルコード
以下、sklearnを用いた正規化のサンプルです。各行に処理の内容をコメントしています。
サンプルコードの手順としては、
- 正規化手法、テスト用データ定義
- fit_transformによる変換
- パラメータ保存->読み込み
- テスト用データ定義
- 保存パラメータによるテストデータ変換(transform)
- テストデータに対してデータ変換(fit_transform)
# importしていない場合都度、pip installしてください。
from sklearn import preprocessing
import numpy as np
import pickle
# 正規化手法定義 MinMaxScaler(0<=data<=1)
mmscaler = preprocessing.MinMaxScaler()
# 学習用生データ定義
train_raw = np.array(list(range(11)))
print (train_raw) # [ 0 1 2 3 4 5 6 7 8 9 10]
# 学習用データでfit_transform
train_transed = mmscaler.fit_transform(train_raw.reshape(-1,1))
# 変換結果表示
# 0から10のデータが0から1に変換されている
print (train_transed.flatten()) # [0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
# fitによるパラメータをバイナリ形式で保存
# 通常、学習コードと検証コードは分けて実装するためパラメータを保存する方法として紹介
pickle.dump(mmscaler, open('./scaler.sav', 'wb'))
# 上記が別関数での実装と仮定して学習データで保存したfitパラメータを読み込み(バイナリファイル)
save_scaler = pickle.load(open('./scaler.sav', 'rb'))
# パラメータ詳細確認
print(save_scaler,type(save_scaler)) # MinMaxScaler() <class 'sklearn.preprocessing._data.MinMaxScaler'>
# テスト用データ定義
test_raw = np.array(list(range(100)))
print (test_raw)
'''print (test_raw)
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
96 97 98 99]
'''
# 保存したパラメータを用いて変換(tranform)
save_scaler_transed = save_scaler.transform(test_raw.reshape(-1,1))
print (save_scaler_transed.flatten())
# 学習データの重みを用いているのでデータ範囲が0から9.9になる
'''print (save_scaler_transed.flatten())
[0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. 1.1 1.2 1.3 1.4 1.5 1.6 1.7
1.8 1.9 2. 2.1 2.2 2.3 2.4 2.5 2.6 2.7 2.8 2.9 3. 3.1 3.2 3.3 3.4 3.5
3.6 3.7 3.8 3.9 4. 4.1 4.2 4.3 4.4 4.5 4.6 4.7 4.8 4.9 5. 5.1 5.2 5.3
5.4 5.5 5.6 5.7 5.8 5.9 6. 6.1 6.2 6.3 6.4 6.5 6.6 6.7 6.8 6.9 7. 7.1
7.2 7.3 7.4 7.5 7.6 7.7 7.8 7.9 8. 8.1 8.2 8.3 8.4 8.5 8.6 8.7 8.8 8.9
9. 9.1 9.2 9.3 9.4 9.5 9.6 9.7 9.8 9.9]
'''
# テストデータを用いてパラメータ算出+変換(fit_tranform)
test_fit_transed = mmscaler.fit_transform(test_raw.reshape(-1,1))
# テスト用データからパラメータを算出しているためデータ範囲が0から1になる
print (test_fit_transed.flatten())
'''print (test_fit_transed.flatten())
[0. 0.01010101 0.02020202 0.03030303 0.04040404 0.05050505
0.06060606 0.07070707 0.08080808 0.09090909 0.1010101 0.11111111
0.12121212 0.13131313 0.14141414 0.15151515 0.16161616 0.17171717
0.18181818 0.19191919 0.2020202 0.21212121 0.22222222 0.23232323
0.24242424 0.25252525 0.26262626 0.27272727 0.28282828 0.29292929
0.3030303 0.31313131 0.32323232 0.33333333 0.34343434 0.35353535
0.36363636 0.37373737 0.38383838 0.39393939 0.4040404 0.41414141
0.42424242 0.43434343 0.44444444 0.45454545 0.46464646 0.47474747
0.48484848 0.49494949 0.50505051 0.51515152 0.52525253 0.53535354
0.54545455 0.55555556 0.56565657 0.57575758 0.58585859 0.5959596
0.60606061 0.61616162 0.62626263 0.63636364 0.64646465 0.65656566
0.66666667 0.67676768 0.68686869 0.6969697 0.70707071 0.71717172
0.72727273 0.73737374 0.74747475 0.75757576 0.76767677 0.77777778
0.78787879 0.7979798 0.80808081 0.81818182 0.82828283 0.83838384
0.84848485 0.85858586 0.86868687 0.87878788 0.88888889 0.8989899
0.90909091 0.91919192 0.92929293 0.93939394 0.94949495 0.95959596
0.96969697 0.97979798 0.98989899 1. ]
'''
おわりに
正規化について学んだ内容を備忘録としてこちらの記事にしました。調べる前はfit_transform()とtransform()の違いに全く気が付きませんでした。。前処理の重要な変換であり、検証するデータの精度にも影響する箇所です。パラメータが不手際で再利用されるケースなどがないことを願っております。
この記事を作成するに当たり、先人の知恵をお借りしました。
後日、記載させていただきます。
ご一読ありがとうございました。LGTMも良ければお願いします!