エラーメッセージが、plot_image
の10行目を指していて、'配列をここに入れないでください'と書いてありますね。
コードを追っていくとplot_image
のif文の右辺が配列になっています。
def plot_image(i, predictions_array, true_label, img):
predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.imshow(img, cmap=plt.cm.binary)
predicted_label = np.argmax(predictions_array)
if predicted_label == true_label:
color = 'blue'
else:
color = 'red'
plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
100*np.max(predictions_array),
class_names[true_label]),
color=color)
このメソッドの引数の true_label
について考えると、このメソッドが渡してほしい配列は正答データのラベルの配列ではないでしょうか。(例えば認識したいクラス数が4で、テストデータの正答データが 1,2,0,3 であれば true_labelに [1, 2, 0, 3]
が入っていてほしい)
しかし実際には to_categorical されたあとのone_hotの配列になっているため、plot_imageに渡されているデータは [[0,1,0,0],[0,0,1,0],[1,0,0,0],[0,0,0,1]]
のような形になってしまっています。
問題を解決するには、選択肢は2つあってどちらかを選べば良いと思います。
1. to_categorical で test_label を上書きしない(one_hot エンコーディングされる前の形の配列と、されたあとの形の配列を、それぞれ別の変数で保持してplot_imageにはエンコーディングされる前の配列を渡す)
2. plot_imageのif文の前で、true_labelについてもnp.argmaxで正答のラベルを復元する
Like!