15
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

備忘録@Python ORセミナー: scikit-learn

Last updated at Posted at 2014-09-05

pythonでも指折りのライブラリだと思ってる.
機械学習用のライブラリ.入っている学習は以下のとおり(一例).

  • 教師あり学習
    • 最近傍法,一般化線形モデル,線形判断分析,SVM,決定木,ランダムフォレスト,ナイーブベイズなど
  • 教師なし学習
    • 混合ガウスモデル,主成分分析,因子分析,独立成分分析,クラスタリング,隠れマルコフモデルなど
  • その他
    • クロスバリデーション,グリッドサーチ,Accuracyなど

※チートシートはこっち

教師あり学習

Support Vector Machine

回帰分析

# まずは学習データの作成
>>> import numpy
>>> np.seed(0) # 乱数のシード固定
>>> x = numpy.sort(5 * numpy.random.rand(40, 1), axis=0)
>>> y = numpy.sin(x).ravel()
>>> y[::5] += 3 * (0.5 - numpy.random.rand(8))
>>> plot(x, y, 'o')

回帰分析の時はSVRを使う(Support Vector Regressionの略かと).
引数のオプションについて

C (default = 1.0)

  • 罰則項のパラメータ
  • 大きいとマージンを許容しない(ハードマージン)、小さいと許容する

kernel (default = rbf)

  • カーネル関数のタイプ
  • 線形: linear, 多項式: poly, RBF(ガウス): rbf, シグモイド: sigmoid, プレコンピューテッド: precomputed

gamma (default = 0.0)

  • RBF、多項式カーネルの係数

degree (default = 2)

  • RBF、多項式、シグモイドカーネル関数の次数
>>> from sklearn.svm import SVR
# 学習器の作成
>>> svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)
>>> svr_lin = SVR(kernel='linear', C=1e3)
>>> svr_poly = SVR(kernel='poly', C=1e3, degree=2)
# fitで学習,predictで予測
>>> y_rbf = svr_rbf.fit(x, y).predict(x)
>>> y_lin = svr_lin.fit(x, y).predict(x)
>>> y_poly = svr_poly.fit(x, y).predict(x)

分類

いわゆるSVMはこっちだと思う.
scikit-learnではSVC(Support Vector Classifierの略)を使う.

# 学習データ作成
>>> numpy.random.seed(0)
>>> X = numpy.random.randn(300, 2)
>>> Y = numpy.logical_xor(X[:,0]>0, X[:,1]>0)
from sklearn.svm import SVC
# 分類器の作成
>>> clf = SVC(kernel='rbf', C=1e3, gamma=0.1)
# 学習
>>> clf.fit(X, Y)
# 決定関数までの距離を計算
>>> xx, yy = np.meshgrid(linspace(-3, 3, 500), linspace(-3, 3, 500))
>>> Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
>>> Z = Z.reshape(xx.shape)
#グラフ化
>>> imshow(Z, interpolation='nearest', extent=[xx.min(),
...                                            xx.max(),
...                                            yy.min(),
...                                            yy.max()],
...                                            aspect='auto',
...                                            origin='lower',
...                                            cmap=cm.PuOr_r)
>>> ctr = contour(xx, yy, Z, levels=[0], linetypes='--')
>>> scatter(X[:, 0], X[:, 1], c=Y, cmap=cm.Paired)
>>> axis([xx.min(), xx.max(), yy.min(), yy.max()])
>>> show()

imshow(): 配列をグラフ化する
引数オプション:
interpolation

  • グラフ処理の際の補完
    • 'nearest' -

extent

  • 範囲を指定
    • [水平方向の最小値, 水平方向の最大値, 鉛直方向の最小値, 鉛直方向の最大値]

aspect

  • アスペクト比の調整

origin

  • 基準点を設定
    • 'lower' - Z[0,0]を左下のコーナーに合わせる

cmap

  • カラーマップを指定

教師なし学習

k-means法

一例でk-means法を示す.

>>> import sklearn.datasets, sklearn.cluster
>>> #  IRISデータの読込
>>> d = sklearn.datasets.load_iris()
>>> km = sklearn.cluster.KMeans(3)
>>> km.fit(d.data)
>>> for i, e in enumerate(d.data):
...    scatter(e[0], e[2], c='rgb'[km.labels_[i]])

#その他

残りは公式ドキュメントで.

15
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?