Help us understand the problem. What is going on with this article?

Lua版 ゼロから作るDeep Learning その5.5[pklファイルをLua Torchで使えるようにする]

More than 3 years have passed since last update.

過去記事

Lua版 ゼロから作るDeep Learning その1[パーセプトロンの実装]
Lua版 ゼロから作るDeep Learning その2[活性化関数]
Lua版 ゼロから作るDeep Learning その3[3層ニューラルネットワークの実装]
Lua版 ゼロから作るDeep Learning その4[ソフトマックス関数の実装]
Lua版 ゼロから作るDeep Learning その5[MNIST画像の表示]

はじめに

 今回は本筋とは関係がないですが、pkl ファイルに格納されたデータをLua で使えるようにするやり方をご紹介します。
 

STEP1 : pkl ファイルをnpz ファイルに変換する

 以下のスクリプトで行うことができます。

pkl2npz.py
#!/usr/local/bin/python3
# coding: utf-8

"""
pklファイルの内容をnpzファイルへ出力する。
"""
__author__ = "Kazuki Nakamae <kazukinakamae@gmail.com>"
__version__ = "0.00"
__date__ = "22 Jun 2017"

import sys
import numpy as np
import pickle

def pkl2npz(infn, outfn):
    """
    @function   pkl2npz();
    pklファイルの内容をnpzファイルへ出力する。
    @param  {string} infn :   入力ファイル名
    @param  {string} outfn :   出力ファイル名
    """

    with open(infn, 'rb') as f:
            ndarr = pickle.load(f)
            np.savez(outfn, W1=ndarr['W1'],W2=ndarr['W2'],W3=ndarr['W3'],b1=ndarr['b1'],b2=ndarr['b2'],b3=ndarr['b3'])

if __name__ == '__main__':
    argvs = sys.argv
    argc = len(argvs)

    if (argc != 3):   # Checking input
        print("USAGE : python3 pkl2npz.py <INPUT_PKLFILE> <OUTPUT_NPZFILE>")
        quit()

    pkl2npz(str(argvs[1]),str(argvs[2]))
quit()
pkl2npz.pyの実行
$ python3 pkl2npz.py sample_weight.pkl sample_weight.npz

 
 一つ注意点ですが、要素に関してはあらかじめ把握しとかなくてはなりません。今回の場合は3層NN の重み(W1, W2, W3) とバイアス(b1, b2, b3) を格納したpkl ファイルとなります。
 

STEP2 : npz ファイルを読み込む

 では作成した sample_weight.npz をLua 上で読み込みましょう。何も手間はかかりません。以下の方がそのためのパッケージ(npy4th)を作成してくれています。
htwaijry/npy4th

npy4thのインストール
$ git clone https://github.com/htwaijry/npy4th.git
$ cd npy4th
$ luarocks make

 使い方は簡単でloadnpz([ファイル名])を入れるだけで読み込めます。

loadnpz()の使い方
npy4th = require 'npy4th'

-- read a .npz file into a table
tbl = npy4th.loadnpz('sample_weight.npz')

print(tbl["W1"])
出力結果
Columns 1 to 6
-7.4125e-03 -7.9044e-03 -1.3075e-02  1.8526e-02 -1.5346e-03 -8.7649e-03
-1.0297e-02 -1.6167e-02 -1.2284e-02 -1.7926e-02  3.3988e-03 -7.0708e-02
-1.3092e-02 -2.4475e-03 -1.7722e-02 -2.4240e-02 -2.2041e-02 -5.0149e-03
-1.0008e-02  1.9586e-02 -5.6170e-03  3.8307e-02 -5.2507e-02 -2.3568e-02
(略)
 1.1210e-02  1.0272e-02
-1.2299e-02  2.4070e-02
 7.4309e-03 -4.0211e-02
[torch.FloatTensor of size 784x50]

 ちゃんとTensor 型に変換されていますね。今回は紹介しませんでしたが、npy ファイルでも同じような読み込みが可能となっています。
 しかし読み込みはしてくれますが、そのまま使えるmatrix の形になっているかは保証されないのに注意してください。どんな形かはnumpy 側でどんな形式だったかによります。必要とあらばresize() などで整形する作業がこのあと必要になってくるでしょう。
 

おわりに

 日本では人気のないTorch ですが、このようにnumpy の資源が使えると利用しやすいのではないかと思います。
 以上です。ありがとうございました。

 

Kazuki-Nakamae
実験生物学系の研究室に所属しています。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away