LoginSignup
3
5

More than 5 years have passed since last update.

ChainerでcPickle.UnpicklingError

Last updated at Posted at 2016-08-12

Chainerを使うために幾つかサイトを参考にしていました.

ところが,自前の画像を学習するためにtrain_imagenet.pyを実行すると,以下のようなエラーが発生しました.

エラー
cPickle.UnpicklingError: invalid load key, 

該当箇所は↓のコードのpickle.loadという関数による非Pickle化処理

train_imagenet.py
# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
mean_image = pickle.load(open(args.mean, 'rb'))

引数のargs.meanの値はmean.npyというファイルなので,このファイルの出処を探してみると...

compute_mean.py
#!/usr/bin/env python
import argparse
import os
import sys

import numpy
from PIL import Image
import six.moves.cPickle as pickle


parser = argparse.ArgumentParser(description='Compute images mean array')
parser.add_argument('dataset', help='Path to training image-label list file')
parser.add_argument('--root', '-r', default='.',
                    help='Root directory path of image files')
parser.add_argument('--output', '-o', default='mean.npy',
                    help='path to output mean array')
args = parser.parse_args()

sum_image = None
count = 0
for line in open(args.dataset):
    filepath = os.path.join(args.root, line.strip().split()[0])
    image = numpy.asarray(Image.open(filepath)).transpose(2, 0, 1)
    if sum_image is None:
        sum_image = numpy.ndarray(image.shape, dtype=numpy.float32)
        sum_image[:] = image
    else:
        sum_image += image
    count += 1
    sys.stderr.write('\r{}'.format(count))
    sys.stderr.flush()

sys.stderr.write('\n')

mean = sum_image / count
pickle.dump(mean, open(args.output, 'wb'), -1)

numpy.ndarray関数で生成したオブジェクトを,pickle.dump関数でmean.npyというファイルに出力しているようです.
つまり,mean.npyの実体はNumPy配列のバイトストリームのようです.

なので,train_imagenet.pyでmean.npyを非Pickle化して読み込むのではなく,NumPy配列として読み込むように修正しました.

train_imagenet.py
# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
# mean_image = pickle.load(open(args.mean, 'rb')) ←非Pickle化して読み込むとcPickle.UnpicklingError
mean_image = np.load(args.mean) # NumPy配列として読み込む

するとなんとか読み込めました.

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5