@maru401
Revisions
Report this question
Subscribe question
Help us understand the problem. What is going on with this question?
Q&A

mnistデータの可視化

解決したいこと

python初心者です。
pythonを使ってmnistデータセットの学習をしています。

発生している問題・エラー

ValueError                                Traceback (most recent call last)
<ipython-input-44-4a873dc32b85> in <module>()
      2 plt.figure(figsize=(6,3))
      3 plt.subplot(1,2,1)
----> 4 plot_image(i, predict, test_label, test_image)
      5 plt.subplot(1,2,2)
      6 plot_value_array(i, predict,  test_label)

<ipython-input-42-6b66f0ca2cb5> in plot_image(i, predictions_array, true_label, img)
      8 
      9     predicted_label = np.argmax(predictions_array)
---> 10     if predicted_label == true_label:
     11         color = 'blue'
     12     else:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

または、問題・エラーが起きている画像をここにドラッグアンドドロップ。

該当するソースコード

←入力するとソースコードにシンタックスハイライトが付きます
import tensorflow as tf
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import pandas as pd

from tensorflow import keras

from keras.layers import Dense, LSTM,Flatten

from keras.datasets import mnist

(train_image,train_label),(test_image,test_label) = mnist.load_data()



"""データの確認"""

train_image.shape#28*28のデータ

train_label.shape#60000のラベル

plt.figure()
plt.imshow(test_image[0])
plt.colorbar()
plt.grid(False)
plt.show()

type(train_image)

df = pd.DataFrame(train_label)

df.describe()

df.info()

from keras import models
from keras.layers import Dense

model = models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(
        optimizer = "adam",
        loss = "categorical_crossentropy",
        metrics=["accuracy"]
    )
model.summary()

train_image = train_image.astype("float32")/255
test_image = test_image.astype("float32")/255

test_image = test_image.reshape(10000,28,28)
train_image = train_image.reshape(60000,28,28)

from keras.utils import to_categorical

train_label = to_categorical(train_label)
test_label = to_categorical(test_label)

train_label.shape

history=model.fit(train_image,train_label,epochs=5)

test_loss, test_acc = model.evaluate(test_image,  test_label, verbose=2)

print('\nTest accuracy:', test_acc)

predict = model.predict(test_image)

np.argmax(predict[0])

np.argmax(test_label[0])

def plot_image(i, predictions_array, true_label, img):
    predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                    100*np.max(predictions_array),
                                    class_names[true_label]),
                                    color=color)

def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1]) 
    predicted_label = np.argmax(predictions_array)

    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

predict

i = 0
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predict, test_label, test_image)
plt.subplot(1,2,2)
plot_value_array(i, predict,  test_label)
plt.show()

エラーの箇所

plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predict, test_label, test_image)
plt.subplot(1,2,2)
plot_value_array(i, predict,  test_label)
plt.show()

解決してほしいこと

本来正解のラベルに棒グラフが作成されるはずなのですがエラーのせいで表示されません。
エラーの原因がわからないので、誰か解決してほしいです。よろしくお願いします。

0
1
Answer

エラーメッセージが、plot_imageの10行目を指していて、'配列をここに入れないでください'と書いてありますね。
コードを追っていくとplot_imageのif文の右辺が配列になっています。

def plot_image(i, predictions_array, true_label, img):
    predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                    100*np.max(predictions_array),
                                    class_names[true_label]),
                                    color=color)

このメソッドの引数の true_label について考えると、このメソッドが渡してほしい配列は正答データのラベルの配列ではないでしょうか。(例えば認識したいクラス数が4で、テストデータの正答データが 1,2,0,3 であれば true_labelに [1, 2, 0, 3]が入っていてほしい)
しかし実際には to_categorical されたあとのone_hotの配列になっているため、plot_imageに渡されているデータは [[0,1,0,0],[0,0,1,0],[1,0,0,0],[0,0,0,1]]のような形になってしまっています。

問題を解決するには、選択肢は2つあってどちらかを選べば良いと思います。
1. to_categorical で test_label を上書きしない(one_hot エンコーディングされる前の形の配列と、されたあとの形の配列を、それぞれ別の変数で保持してplot_imageにはエンコーディングされる前の配列を渡す)
2. plot_imageのif文の前で、true_labelについてもnp.argmaxで正答のラベルを復元する

0
Help us understand the problem. What is going on with this answer?
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした