6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

shogunの使い方

Last updated at Posted at 2015-07-07

こちらで入手することのできる機械学習ライブラリ、SHOGUNの覚え書き。
導入方法はこちらに書いてあります。

###1.ラベル


2値ラベルであるBinaryLabelの例。-1か1で表される。配列かCSVファイルから作成することができる。

from modshogun import BinaryLabels

#ランダムに5つのラベルを生成
label = BinaryLabels(5)

label.get_num_labels() 
→ 5

label.get_values()
→ array([  2.00000000e+000,   2.00000000e+000,   1.38338381e-322,0.00000000e+000,   0.00000000e+000])

from modshogun import CSVFile

#あらかじめ用意したCSVファイルからでも作成可能
label_from_csv = BinaryLabels(CSVFile(file_path))

###2.特徴量


numpyの行列かCSVファイルから作成することができる。1特徴量は1行ではなく、1列で表現することに注意が必要

from modshogun import RealFeatures
import numpy as np

#3x3のランダム行列
feat_arr = np.random.rand(3, 3)
→ array([[ 0.02818103,  0.72093824,  0.92727711],
       [ 0.66853622,  0.14594737,  0.90522684],
       [ 0.97941639,  0.14188234,  0.80854797]])

#RealFeaturesの初期化
features = RealFeatures(feat_arr)

#特徴量の表示
features.get_feature_matrix(features)
→ array([[ 0.02818103,  0.72093824,  0.92727711],
       [ 0.66853622,  0.14594737,  0.90522684],
       [ 0.97941639,  0.14188234,  0.80854797]])

#特定の列の特徴量を取得
features.get_feature_vector(1)
→array([ 0.72093824,  0.14594737,  0.14188234])

#特徴の種類(行の数)
features.get_num_features()
→3

#特徴の個数(列の数)
features.get_num_vectors()
→3

from modshogun import CSVFile

#もちろんこれもCSVファイルから読み込みが可能。
feats_from_csv = RealFeatures(CSVFile(file_path))

###3.カーネル


カイ2乗カーネルでの例。

from modshogun import Chi2Kernel, RealFeatures, CSVFile

#訓練用データ
feats_train = RealFeatures(CSVFile(file_path))

#テスト用データ
feats_test = RealFeatures(CSVFile(file_path))

#カーネルの幅
width = 1.4

#size_cacheの設定
size_cache = 10

#カーネルの生成
kernel = Chi2Kernel(feats_train, feats_train, width, size_cache)

#カーネルの訓練
kernel.init(feats_train, feats_test)

###4.SVMLight


SVMLightを用いたサポートベクトルマシンによる分類

from modshogun import SVMLight, CSVFile, BinaryLabels, RealFeatures, Chi2Kernel

feats_train = RealFeatures(CSVFile(train_data_file_path))
feats_test = RealFeatures(CSVFile(test_data_file_path))

kernel = Chi2Kernel(feats_train, feats_train, 1.4, 10)

labels = BinaryLabels(CSVFile(label_traindat_path))
 
C = 1.2
epsilon = 1e-5
num_threads = 1
svm = SVMLight(C, kernel, labels)
svm.set_epsilon(epsilon)
svm.parallel.set_num_threads(num_threads)
svm.train()

kernel.init(feats_train, feats_test)
res = svm.apply().get_labels()

res
→array(結果のラベル)

###5.交差検定


CrossValidationクラスをimportして行います。CrossValidationの初期化には

  • 分類器(SVNLightやLibLinearなどのCMachineクラス)
  • 特徴量(RealFeatuersやDenseFeaturesなどのCFeaturesクラス)
  • ラベル(BinaryLabelsやMultiClassLabelsなどのCLabelsクラス)
  • データ分割の方法(CrossValidationSplittingなどのCSplittngStrategyクラス)
  • 評価基準(ContingencyTableEvaluationなどのCEvaluationクラス)
    を引数として渡します。
from modshogun import LibLinear, BinalyLabels, RealFeatures, CrossValidationSplitting, ContingencyTableEvaluation, CSVFile, ACCURACY

#分類器
classifier = LibLinear(L2R_L2LOSS_SVC)
#特徴量
features = RealFeatures(CSVFile(feature_file_path))
#ラベル
labels = BinalyLabels(CSVFile(label_file_path))


#SplittingStrategyは、データをどのように分割するかを指定できるっぽい。あまり詳しくはわからないです。この例では5分割しています。
splitting_strategy = CrossValidationSplitting(labels, 5)

#評価基準クラス。ACCURACYはただのEContingencyTableMeasureTypeに宣言されている定数です。
evaluation_criterium = ContingencyTableEvaluation(ACCURACY)

#クロスバリデーションクラス。
cross_validation = CrossValidation(classifier, features, labels. splitting_strategy, evaluation_criterium)
cross_validation.set_autolock(False)

#繰り返しの数の設定
cross_validation.set_num_runs(10)

#95%信頼区間の設定?よくわからない
cross_validation.set_conf_int_alpha(0.05)

#返り値はCEvaluationResultクラス
result = cross_validation.evaluate()

#交差検定の結果の平均値を取得できます。
print result.mean

#その他もろもろ全てを出力したい場合はこちら
print result.print_result()

###6.グリッドサーチ


ここまで出来ればグリッドサーチはかなり簡単にできます。
CModelSelectionクラスのGridSearchModelSelectionを

  • CrossValidation
  • 変化させるパラメータのセット(ModelSelectionParameters)

を渡して初期化すればもうグリッドサーチができます。

---LibLinearでCrossValidationクラスを初期化するところまでは省略---

from modshogun import ModelSelectionParameters, R_EXP
from modsghoun import GridSearchModelSelection

#変化させるパラメータを格納するオブジェクト
param_tree_root = ModelSelectionParameters()

#パラメータC1
C1 = ModelSelectionParameters("C1")
param_tree_root.append_child(c1)

build_values()で最小値、最大値、ステップ(パラメータの増加量)を設定する。R_EXP(指数),R_LOG(対数),R_LINEAR(線形)の3種類あるが、詳細は不明。
c1.build_values(-1.0, 0.0, R_EXP)

c2 = ModelSelectionParameters("C2")
param_tree_root.append_child(c2)
c2.build_values(-1.0, 0.0, R_EXP)

#ここでprint_tree()を実行すると、param_tree_rootは木構造になっていることがわかる。
param_tree_root.print_tree()
→root with
	 with values: vector=[0.5,1]
	 with values: vector=[0.5,1]

#グリッドサーチクラスの生成
model_selection = GridSearchModelSelection(cross_validation, param_tree_root)

#これで自動的に最適なパラメータを決定し、CParameterCombinationクラスのオブジェクトを返します。また、Trueを引数として渡すと、各パラメータの組み合わせと、その結果も出力してくれます。
best_parameters = model_selection.select_model()

#返ってきた最良のパラメータを分類器やモデルのパラメータとして適用することも可能です。
best_parameters.apply_to_machine(classifier)
result = cross_validation.evaluate()

###7.作成したモデルの保存、読み込み


ほとんど全てのクラスのもとになっているCSGObjectが持っている関数save_serializable()とload_serializable()を利用して、オブジェクトの保存、読み込みが可能です。

from modshogun import SerializableAsciiFile
from modshogun import MulticlassLabels
from numpy import array

save_labels = MulticlassLabels(array([1.0, 2, 3]))

#ファイル名の設定 csvやascに対応
save_file = SerializableAsciiFile("foo.csv", "w")
#ファイルの保存
save_labels.save_serializable(save_file)

load_file = SerializableAsciiFile("foo.csv", "r")
load_labels = MulticlassLabels()
load_labels.load_serializable(load_file)
→[ 1.  2.  3.]

###8.ログを吐く


各オブジェクトに関して、ログを吐くことができます。デバッグログはMSG_DEBUG、エラーログのみの場合はMSG_ERRORを引数に渡す。 EMessageTypeで宣言されています。

from modshogun import MSG_DEBUG, MSG_ERROR
from modshogun import Chi2Kernel
from modshogun import LibSVM

kernel = Chi2Kernel()
svm = LibSVM()

kernel.io.set_loglevel(MSG_DEBUG)
svm.io.set_loglevel(MSG_ERROR)

###おわりに


いろいろと雑な感じになってしまってるので、何か要望がありましたらコメントください。

6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?