4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Lua版 ゼロから作るDeep Learning その5.5[pklファイルをLua Torchで使えるようにする]

Last updated at Posted at 2017-06-21

過去記事

Lua版 ゼロから作るDeep Learning その1[パーセプトロンの実装]
Lua版 ゼロから作るDeep Learning その2[活性化関数]
Lua版 ゼロから作るDeep Learning その3[3層ニューラルネットワークの実装]
[Lua版 ゼロから作るDeep Learning その4[ソフトマックス関数の実装]]
(http://qiita.com/Kazuki-Nakamae/items/20e53a02a8b759583d31)
Lua版 ゼロから作るDeep Learning その5[MNIST画像の表示]

はじめに

 今回は本筋とは関係がないですが、pkl ファイルに格納されたデータをLua で使えるようにするやり方をご紹介します。
 
STEP1 : pkl ファイルをnpz ファイルに変換する

 以下のスクリプトで行うことができます。

pkl2npz.py
#!/usr/local/bin/python3
# coding: utf-8

"""
pklファイルの内容をnpzファイルへ出力する。
"""
__author__ = "Kazuki Nakamae <kazukinakamae@gmail.com>"
__version__ = "0.00"
__date__ = "22 Jun 2017"

import sys
import numpy as np
import pickle

def pkl2npz(infn, outfn):
    """
    @function   pkl2npz();
    pklファイルの内容をnpzファイルへ出力する。
    @param  {string} infn :   入力ファイル名
    @param  {string} outfn :   出力ファイル名
    """

    with open(infn, 'rb') as f:
            ndarr = pickle.load(f)
            np.savez(outfn, W1=ndarr['W1'],W2=ndarr['W2'],W3=ndarr['W3'],b1=ndarr['b1'],b2=ndarr['b2'],b3=ndarr['b3'])

if __name__ == '__main__':
    argvs = sys.argv
    argc = len(argvs)

    if (argc != 3):   # Checking input
        print("USAGE : python3 pkl2npz.py <INPUT_PKLFILE> <OUTPUT_NPZFILE>")
        quit()

    pkl2npz(str(argvs[1]),str(argvs[2]))
quit()
pkl2npz.pyの実行
$ python3 pkl2npz.py sample_weight.pkl sample_weight.npz

 
 一つ注意点ですが、要素に関してはあらかじめ把握しとかなくてはなりません。今回の場合は3層NN の重み(W1, W2, W3) とバイアス(b1, b2, b3) を格納したpkl ファイルとなります。
 
STEP2 : npz ファイルを読み込む

 では作成した sample_weight.npz をLua 上で読み込みましょう。何も手間はかかりません。以下の方がそのためのパッケージ(npy4th)を作成してくれています。
htwaijry/npy4th

npy4thのインストール
$ git clone https://github.com/htwaijry/npy4th.git
$ cd npy4th
$ luarocks make

 使い方は簡単でloadnpz([ファイル名])を入れるだけで読み込めます。

loadnpz()の使い方
npy4th = require 'npy4th'

-- read a .npz file into a table
tbl = npy4th.loadnpz('sample_weight.npz')

print(tbl["W1"])
出力結果
Columns 1 to 6
-7.4125e-03 -7.9044e-03 -1.3075e-02  1.8526e-02 -1.5346e-03 -8.7649e-03
-1.0297e-02 -1.6167e-02 -1.2284e-02 -1.7926e-02  3.3988e-03 -7.0708e-02
-1.3092e-02 -2.4475e-03 -1.7722e-02 -2.4240e-02 -2.2041e-02 -5.0149e-03
-1.0008e-02  1.9586e-02 -5.6170e-03  3.8307e-02 -5.2507e-02 -2.3568e-02
(略)
 1.1210e-02  1.0272e-02
-1.2299e-02  2.4070e-02
 7.4309e-03 -4.0211e-02
[torch.FloatTensor of size 784x50]

 ちゃんとTensor 型に変換されていますね。今回は紹介しませんでしたが、npy ファイルでも同じような読み込みが可能となっています。
 しかし読み込みはしてくれますが、そのまま使えるmatrix の形になっているかは保証されないのに注意してください。どんな形かはnumpy 側でどんな形式だったかによります。必要とあらばresize() などで整形する作業がこのあと必要になってくるでしょう。
 
おわりに

 日本では人気のないTorch ですが、このようにnumpy の資源が使えると利用しやすいのではないかと思います。
 以上です。ありがとうございました。

 

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?