Chainerを使うために幾つかサイトを参考にしていました.
ところが,自前の画像を学習するためにtrain_imagenet.pyを実行すると,以下のようなエラーが発生しました.
エラー
cPickle.UnpicklingError: invalid load key,
該当箇所は↓のコードのpickle.loadという関数による非Pickle化処理
train_imagenet.py
# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
mean_image = pickle.load(open(args.mean, 'rb'))
引数のargs.meanの値はmean.npyというファイルなので,このファイルの出処を探してみると...
compute_mean.py
#!/usr/bin/env python
import argparse
import os
import sys
import numpy
from PIL import Image
import six.moves.cPickle as pickle
parser = argparse.ArgumentParser(description='Compute images mean array')
parser.add_argument('dataset', help='Path to training image-label list file')
parser.add_argument('--root', '-r', default='.',
help='Root directory path of image files')
parser.add_argument('--output', '-o', default='mean.npy',
help='path to output mean array')
args = parser.parse_args()
sum_image = None
count = 0
for line in open(args.dataset):
filepath = os.path.join(args.root, line.strip().split()[0])
image = numpy.asarray(Image.open(filepath)).transpose(2, 0, 1)
if sum_image is None:
sum_image = numpy.ndarray(image.shape, dtype=numpy.float32)
sum_image[:] = image
else:
sum_image += image
count += 1
sys.stderr.write('\r{}'.format(count))
sys.stderr.flush()
sys.stderr.write('\n')
mean = sum_image / count
pickle.dump(mean, open(args.output, 'wb'), -1)
numpy.ndarray関数で生成したオブジェクトを,pickle.dump関数でmean.npyというファイルに出力しているようです.
つまり,mean.npyの実体はNumPy配列のバイトストリームのようです.
なので,train_imagenet.pyでmean.npyを非Pickle化して読み込むのではなく,NumPy配列として読み込むように修正しました.
train_imagenet.py
# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
# mean_image = pickle.load(open(args.mean, 'rb')) ←非Pickle化して読み込むとcPickle.UnpicklingError
mean_image = np.load(args.mean) # NumPy配列として読み込む
するとなんとか読み込めました.