
More than 5 years have passed since last update.

Running NMF with scikit-learn and SPAMS


What I want to do

  • Check how to use NMF
  • Compare the NMF implementations in scikit-learn and SPAMS

NMF

NMF stands for Non-negative Matrix Factorization. It is explained very thoroughly here and in many other articles.
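To make the idea concrete, here is a minimal NumPy sketch of NMF using the classic Lee and Seung multiplicative update rules for the squared Frobenius loss. This is a textbook algorithm chosen for illustration only: scikit-learn's NMF defaults to a coordinate descent solver (solver='cd'; solver='mu' selects multiplicative updates), and SPAMS uses its own method, so the numbers will not match either library. The function nmf_mu and its parameters are hypothetical.

```python
import numpy as np

def nmf_mu(X, k, n_iter=5000, eps=1e-10, seed=0):
    """Hypothetical helper: factorize a nonnegative X (n x m) as W @ H,
    with W (n x k) and H (k x m) both nonnegative, using Lee-Seung
    multiplicative updates for the Frobenius loss ||X - WH||_F^2."""
    rng = np.random.default_rng(seed)
    n, m = X.shape
    # Random strictly positive initialization
    W = rng.random((n, k))
    H = rng.random((k, m))
    for _ in range(n_iter):
        # Each update multiplies by a nonnegative ratio, so W and H stay >= 0
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
    return W, H

# Same toy data as used in this article
X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])
W, H = nmf_mu(X, k=2)
print(W.shape, H.shape)            # (6, 2) (2, 2)
print(np.linalg.norm(X - W @ H))   # Frobenius reconstruction error
```

Because every update rescales the current factors by a nonnegative ratio, W and H never go negative, which is exactly the constraint that distinguishes NMF from an unconstrained low-rank factorization such as a truncated SVD.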

Using scikit-learn

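  • I follow the official sample code.
  • Since $ X \in \mathbb{R}^{6\times2} $ and we ask for two components, $ W \in \mathbb{R}^{6\times2} $ and $ H \in \mathbb{R}^{2\times2} $.
  • n_components is that 2: the number of columns of $ W $ and the number of rows of $ H $.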
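In general, for an $n \times m$ input and n_components $= k$, the shapes and the objective that scikit-learn minimizes with its default settings (Frobenius loss, ignoring the optional regularization terms) are as follows; in this example $n = 6$, $m = 2$, $k = 2$.

```latex
X \approx W H, \qquad
X \in \mathbb{R}_{\ge 0}^{n \times m}, \quad
W \in \mathbb{R}_{\ge 0}^{n \times k}, \quad
H \in \mathbb{R}_{\ge 0}^{k \times m}

\min_{W \ge 0,\; H \ge 0} \; \tfrac{1}{2} \, \lVert X - W H \rVert_F^2
```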
NMF(scikit-learn)
>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])
>>> from sklearn.decomposition import NMF
>>> model = NMF(n_components=2, init='random', random_state=0)
>>> W = model.fit_transform(X)
>>> H = model.components_
>>> # Check the results
>>> X_ = np.dot(W, H)
>>> print(X)  # original X
[[1.  1. ]
 [2.  1. ]
 [3.  1.2]
 [4.  1. ]
 [5.  0.8]
 [6.  1. ]]
>>> print(X_)  # X reconstructed from the NMF factors (W @ H)
[[1.00063558 0.99936347]
 [1.99965977 1.00034074]
 [2.99965485 1.20034566]
 [3.9998681  1.0001321 ]
 [5.00009002 0.79990984]
 [6.00008587 0.999914  ]]
>>> print(model.inverse_transform(W))  # inverse_transform gives the same reconstruction
[[1.00063558 0.99936347]
 [1.99965977 1.00034074]
 [2.99965485 1.20034566]
 [3.9998681  1.0001321 ]
 [5.00009002 0.79990984]
 [6.00008587 0.999914  ]]
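To put a number on how close the reconstruction is, the sketch below re-enters the values printed above (so it runs without scikit-learn) and measures the largest element-wise deviation and the Frobenius error. With a fitted model, scikit-learn exposes the latter directly as model.reconstruction_err_ for the default loss. Because the printed values are rounded, the results are approximate.

```python
import numpy as np

X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])
# Reconstruction W @ H exactly as printed above (rounded to 8 decimals)
X_ = np.array([[1.00063558, 0.99936347],
               [1.99965977, 1.00034074],
               [2.99965485, 1.20034566],
               [3.9998681,  1.0001321 ],
               [5.00009002, 0.79990984],
               [6.00008587, 0.999914  ]])

max_abs_err = np.abs(X - X_).max()   # worst single entry
frob_err = np.linalg.norm(X - X_)    # Frobenius norm of the residual
print(max_abs_err)                   # ≈ 6.4e-04
print(frob_err)                      # ≈ 1.2e-03
```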

Using SPAMS

  • The official site provides sample code.
  • Here I use the same X as in the scikit-learn example.
  • Judging from the source, it seems the same thing can also be done with trainDL and lasso.
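A rough NumPy picture of that two-stage idea, assuming (not confirmed here) that spams.nmf first learns a nonnegative dictionary W with trainDL and then computes nonnegative coefficients H with lasso using a zero penalty, which reduces the second stage to nonnegative least squares. The sketch covers only that second stage, with W held fixed, and uses plain projected gradient descent as a stand-in solver; nnls_projected_gradient is a hypothetical name and this is not SPAMS's own algorithm.

```python
import numpy as np

def nnls_projected_gradient(W, X, n_iter=500):
    """Hypothetical helper: solve min_H 0.5 * ||X - W @ H||_F^2 subject to
    H >= 0 for a fixed W, by projected gradient descent."""
    # Step size 1/L, where L = largest eigenvalue of W^T W
    # (the Lipschitz constant of the gradient)
    L = np.linalg.norm(W, 2) ** 2
    H = np.zeros((W.shape[1], X.shape[1]))
    for _ in range(n_iter):
        grad = W.T @ (W @ H - X)
        # Gradient step, then project back onto the nonnegative orthant
        H = np.maximum(H - grad / L, 0.0)
    return H

# Toy check: build X from a known nonnegative W and H_true, then recover H
W = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0], [3.0, 1.0]])
H_true = np.array([[0.5, 2.0, 0.0], [1.0, 0.0, 3.0]])
X = W @ H_true
H = nnls_projected_gradient(W, X)
print(np.round(H, 4))  # recovers H_true
```

Since this W has full column rank, the nonnegative least-squares solution is unique, so the solver lands on H_true; in real NMF the dictionary step and this coefficient step would alternate or be combined.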
NMF(SPAMS)
>>> import numpy as np
>>> import spams
>>> X = np.asfortranarray([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])  # X must be a Fortran-ordered array
>>> param = {'K': 2, 'numThreads': 4, 'iter': 1000, 'return_lasso': True}
>>> W, H = spams.nmf(X, **param)
>>> H = H.toarray()  # H comes back as a sparse matrix in Compressed Sparse Column (CSC) format
>>> print(np.dot(W, H))
[[1.  1. ]
 [2.  1. ]
 [3.  1.2]
 [4.  1. ]
 [5.  0.8]
 [6.  1. ]]
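Why np.asfortranarray: as noted in the code above, SPAMS expects a Fortran-ordered (column-major) array, whereas np.array creates a C-ordered (row-major) array by default. The NumPy-only check below, which does not call SPAMS, shows that the two hold the same values with a different memory layout.

```python
import numpy as np

data = [[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]]

X_c = np.array(data)            # default: row-major (C order)
X_f = np.asfortranarray(data)   # column-major (Fortran order), as passed to spams.nmf

print(X_c.flags['C_CONTIGUOUS'], X_c.flags['F_CONTIGUOUS'])  # True False
print(X_f.flags['C_CONTIGUOUS'], X_f.flags['F_CONTIGUOUS'])  # False True
print(X_f.dtype)                # float64, because the data contains 1.2
print(np.array_equal(X_c, X_f)) # True: same values, different layout
```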

Speed comparison

scikit-learn
%%timeit
model = NMF(n_components=2, init='random', random_state=0, max_iter=1000)
W = model.fit_transform(X)

729 µs ± 2.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

spams
%%timeit
W, H = spams.nmf(X, **param)

5.47 ms ± 302 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

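Both were configured with 1,000 iterations (max_iter=1000 vs. "iter": 1000), so I think this is a reasonable comparison; on this tiny 6×2 input scikit-learn was about 7.5× faster. Note, however, that scikit-learn can stop before max_iter once its tol convergence criterion is met, so the two calls may not do the same amount of work.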
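%%timeit is an IPython/Jupyter magic. To take a similar measurement in a plain Python script, the standard-library timeit module can be used. The sketch below reports mean ± standard deviation per call in the spirit of %%timeit's output; the helper name bench is hypothetical, and the example times a trivial NumPy product instead of either NMF call so that it runs without scikit-learn or SPAMS.

```python
import statistics
import timeit

import numpy as np

def bench(fn, number=100, repeat=7):
    """Hypothetical helper: time fn() roughly like %%timeit, returning
    (mean, stdev) in seconds per call over `repeat` runs of `number` calls."""
    totals = timeit.repeat(fn, number=number, repeat=repeat)
    per_call = [t / number for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)

X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])
mean, std = bench(lambda: X.T @ X)
print(f"{mean * 1e6:.2f} µs ± {std * 1e6:.2f} µs per loop")
```

To time the two libraries this way, one would pass, for example, lambda: NMF(n_components=2, init='random', random_state=0, max_iter=1000).fit_transform(X) or lambda: spams.nmf(X, **param) as fn.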
