はじめに
TensorFlowでCNN(今回は画像分類)やろうと調べると先駆者の方々はだいたい画像読み込みにOpenCVを使っているのですが、Python3だとcv2が動かない的な記事がたくさんヒットするわけです(cv3なら動いた疑惑……)。
ということでOpenCV使わずに画像読み込もうとしたらハマったので備忘録(これやってたのが随分前のことで記憶が曖昧ですがご了承ください)。
適当な解説
まずは画像ファイルパスを正解ラベルが記されたCSVを読み込みます。
# CSV読み込み
trains = np.loadtxt(csv_pass, delimiter=",", dtype=str)
CSVの中身はこんな感じ。
train/hoge1.jpg,0
train/hoge2.jpg,2
train/hoge3.jpg,1
以下、画像と正解ラベルを読み込む関数です。
np.loadtxtで読み込んだデータとセッションオブジェクトを渡してあげます。
def get_image_and_label(sess, csv_lines):
images = []
labels = []
for train in csv_lines:
#string型なのにバイト記号bが入っているので文字列削除(なぜ?)
delete = "¥'b¥'"
train[0]=train[0].strip(delete)
#画像読み込み
val = tf.cast(train[0], dtype=tf.string)
jpeg_r = tf.read_file(val)
image = tf.image.decode_jpeg(jpeg_r, channels=channel_num) #tf.image.decode_imageはshape型返さないのでダメらしい
image = tf.image.resize_images(image, [IMAGE_SIZE, IMAGE_SIZE])
image = tf.reshape(image, [-1]) #-1を指定した場合には次元が削減されflatten
image = tf.cast(image,dtype=np.float32)
image = image/255.0 #正規化
image_val = sess.run(image)
images.append(image_val)
#ラベル読み込み
tmp = np.zeros(NUM_CLASSES)
delete = "¥'b¥'" #同じく"b"を削除
lab=train[1].strip(delete)
index=int(lab) #正解ラベルのインデックス
tmp[index] = 1
labels.append(tmp)
return np.asarray(images, dtype=np.float32), np.asarray(labels, dtype=np.float32)
なぜかCSVから読み込んだデータがstring型なのにbyte列を示す記号が付いていたりしてハマりました。自分は雑魚で理由が分からなかったので無理やり該当文字列を削除して何とか画像を読み込みました(原因ご存知の方いましたら教えてください)。
結論は普通にKerasとか使った方が良い。
参考
ほぼコピペでエラー吐いてた箇所を修正して使わせていただきました。
https://qiita.com/kurimoto/items/0da3c900fcf790964619