学習中に中で何が起きているかちょっと覗いてみましょう。データは、引き続きMNISTです。
中間層-なし
まずは、中間層なしの場合を見てみましょう。
以下のネットワーク図になります。(バイアスは省略しています。)
実行
前回のプログラムで以下の設定にて実行します。
# 学習回数
epoch = 50
# ノード数
mds = [[]]
# 学習率
lrs = [0.1]
# バッチサイズ
batch_sizes = [100]
50エポック後、学習データに対する正解率は、約93%、テストデータに対する正解率は、約92.5%です。
重さ
50エポック後の重さを図にしてみましょう。
重さは、全部で$ 784 \times 10 $個あります。$ w_{1 i}^{(1)} $、$ w_{2 i}^{(1)} $の順に$ w_{10 i}^{(1)} $まで表示します。$ i $は、784個、$ 28 \times 28 $にreshapeし表示します。
正を赤、負を青で表します。色が濃いほど大きい値です。
import matplotlib.pyplot as plt
plt.figure(figsize=(20,6))
for i in range(10):
plt.subplot(2,5,i+1)
plt.title(str(i))
plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -1, vmax = 1)
plt.colorbar()
plt.show()
何か見えてきますか?
0は、真ん中の青が目立ちます。真ん中に文字があれば負の要因となります。
1は、逆に真ん中が赤です。真ん中以外は青が多く負の要因となります。
2は、左側が青で、左側がつながっていないことが要因になります。
3以降も何をもとに数字を判断しているかわかりますね。
各エポックの重みを保存しておき、エポックごとにどのように変化していくか見てみましょう。
0の場合です。
最初は、学習前の重さでランダムの値です。1エポック目から輪郭が見えています。
u
uの値を表示するため、u,zを取得できるように予測関数を変更します。
def predict2(x, W, b):
# 層数
layer = len(W)
# 順伝播
u = {}
z = {}
u[1] = affine(x, W[1], b[1])
for i in range(1, layer):
z[i] = relu(u[i])
u[i+1] = affine(z[i], W[i+1], b[i+1])
y = softmax(u[layer])
return y, u, z
呼び出し元も変更します。
y_train, u_train, z_train = predict2(nx_train, W, b)
y_test, u_test, z_test = predict2(nx_test, W, b)
重さとデータを掛けた図を示します。テストデータの5番目まで表示してみます。
import matplotlib.pyplot as plt
for j in range(5):
print(j)
plt.figure(figsize=(20,7))
for i in range(10):
plt.subplot(2,5,i+1)
plt.title("u1[{0:d}]={1:.2f}".format(i, u_test[1][j,i]))
plt.imshow((W[1][:,i]*nx_test[j]).reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
plt.colorbar()
plt.show()
結果を順に見てみましょう。
- 7
u1[7]の画像が一番多く赤く反応しています。
- 2
u1[2]、u1[3]とも上部は同じように反応していますが、下部で差が出ているようです。
- 1
u1[1]の中央が大きく反応しています。
- 0
u1[0]が大きく反応しています。
- 4
u1[4]、u1[9]が反応しています。ただし、4の方が大きく反応しています。
どのように判断しているかよくわかりました。
中間層数-1
次は、中間層が1の場合です。
バイアスは、省略しています。
実行
パラメータを以下とします。
# 学習回数
epoch = 50
# ノード数
mds = [[100]]
# 学習率
lrs = [0.5]
# バッチサイズ
batch_sizes = [100]
50エポック後、学習データに対する正解率は、約100%、テストデータに対する正解率は、約98%です。
重さ
50エポック後の重さを図にしてみましょう。
1層目は、$ 784 \times 100 $個、2層目は、$ 100 \times 10 $個です。
1層目の100個を$ 28 \times 28$でreshapeします。
2層目の10個を$ 10 \times 10$でreshapeします。
1階層目です。
import matplotlib.pyplot as plt
plt.figure(figsize=(20,21))
for i in range(100):
plt.subplot(10,10,i+1)
plt.title("W[1][" + str(i) + "]")
plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
いろいろな形で特徴を抽出しようとしているのでしょうか?見た目だと具体的に何を抽出しようとしているかわかりません。
2階層目です。
import matplotlib.pyplot as plt
plt.figure(figsize=(20,5))
for i in range(10):
plt.subplot(1,10,i+1)
plt.title("W[2][" + str(i) + "]")
plt.imshow(W[2][:,i].reshape(10,10), "bwr", vmin = -2, vmax = 2)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
こちらは、何なのかまったくわかりませんね。
どの数字がどれに影響しているか見てみます。2階層目の重みの大きい順に10個と小さい順に10個を表示してみます。
for i in range(10):
Wsorta = np.argsort(W[2][:,i])
Wsortd = Wsorta[::-1]
print(i)
# top 10
plt.figure(figsize=(20,21))
for j in range(10):
plt.subplot(1,10,j+1)
plt.title("W1[{0:2d}] {1:.2f}".format(Wsortd[j],W[2][Wsortd[j],i]))
plt.imshow(W[1][:,Wsortd[j]].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
plt.show()
# lower 10
plt.figure(figsize=(20,21))
for j in range(10):
plt.subplot(1,10,j+1)
plt.title("W1[{0:2d}] {1:.2f}".format(Wsorta[j],W[2][Wsorta[j],i]))
plt.imshow(W[1][:,Wsorta[j]].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
plt.show()
- 0
影響上位10位
これに当てはまると0の可能性が高い。
影響下位10位
これに当てはまると0の可能性が低い。
- 1
影響上位10位
これに当てはまると1の可能性が高い。
影響下位10位
これに当てはまると1の可能性が低い。
やはり、ぱっと見よくわからないですね。
u,z
テストデータの先頭5データについて、u1,z1を表示してみます。
u1
import matplotlib.pyplot as plt
plt.figure(figsize=(10,11))
for j in range(5):
plt.subplot(1,5,j+1)
plt.title(str(j) + "-u[1]")
plt.imshow(u_test[1][j,:].reshape(10,10), "bwr", vmin = -5, vmax = 5)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
正解は、7,2,1,0,4の順です。
これだけ見てもわからないですね。
z1
import matplotlib.pyplot as plt
plt.figure(figsize=(10,11))
for j in range(5):
plt.subplot(1,5,j+1)
plt.title(str(j) + "-z[1]")
plt.imshow(z_test[1][j,:].reshape(10,10), "bwr", vmin = -5, vmax = 5)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
活性化関数として、ReLUを利用していますので、負(青色)が0になります。
1階層目の重さとデータを掛けた図を示します。テストデータの5番目まで表示してみます。
import matplotlib.pyplot as plt
for j in range(5):
print(j)
plt.figure(figsize=(20,21))
for i in range(100):
plt.subplot(10,10,i+1)
plt.title("u[1][{0:2d}]({1:.2f})".format(i, u_test[1][j,i]))
plt.imshow((W[1][:,i]*nx_test[j]).reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
テストデータの1データ目です。
- 7
u2
import matplotlib.pyplot as plt
for j in range(5):
print(j)
plt.figure(figsize=(20,21))
for i in range(10):
plt.subplot(1,10,i+1)
plt.title("u[2][{0:d}]({1:.2f})".format(i, u_test[2][j,i]))
plt.imshow((W[2][:,i]*z_test[1][j]).reshape(10,10), "bwr", vmin = -2, vmax = 2)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
テストデータの1データ目です。
- 7
7が赤が多く強く反応しています。
- 2
他の数字も同様です。
中間層 784ノード
中間層を784ノードとして実行してみます。
# 学習回数
epoch = 50
# ノード数
mds = [[784]]
# 学習率
lrs = [0.9]
# バッチサイズ
batch_sizes = [100]
50エポック後には、学習データに対する正解率 100%、テストデータに対する正解率 約98.4%です。
重さの図です。
- 1階層目
スペースの関係からタイトルを省略しています。
import matplotlib.pyplot as plt
plt.figure(figsize=(20,21))
for i in range(784):
plt.subplot(28,28,i+1)
#plt.title("W[1][" + str(i) + "]")
plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -0.25, vmax = 0.25)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
- 2階層目
import matplotlib.pyplot as plt
plt.figure(figsize=(20,5))
for i in range(10):
plt.subplot(1,10,i+1)
plt.title("W[2][" + str(i) + "]")
plt.imshow(W[2][:,i].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
何を抽出しているのか、やはりわかりませんね。
中間層数-2
中間層を100,100ノードとして実行してみます。
# 学習回数
epoch = 50
# ノード数
mds = [[100,100]]
# 学習率
lrs = [0.5]
# バッチサイズ
batch_sizes = [100]
50エポック後には、学習データに対する正解率 100%、テストデータに対する正解率 約98.2%です。
重さの図です。
- 1階層目
import matplotlib.pyplot as plt
plt.figure(figsize=(20,21))
for i in range(100):
plt.subplot(10,10,i+1)
plt.title("W[1][" + str(i) + "]")
plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -0.25, vmax = 0.25)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
- 2階層目
import matplotlib.pyplot as plt
plt.figure(figsize=(20,21))
for i in range(100):
plt.subplot(10,10,i+1)
plt.title("W[2][" + str(i) + "]")
plt.imshow(W[2][:,i].reshape(10,10), "bwr", vmin = -0.5, vmax = 0.5)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
- 3階層目
import matplotlib.pyplot as plt
plt.figure(figsize=(20,5))
for i in range(10):
plt.subplot(1,10,i+1)
plt.title("W[3][" + str(i) + "]")
plt.imshow(W[3][:,i].reshape(10,10), "bwr", vmin = -1, vmax = 1)
plt.xticks([])
plt.yticks([])
#plt.colorbar()
plt.show()
複雑になるばかりです。
学習の中身を覗いてみましたが、いかがでしたでしょうか?
やはり、説明はできないが、分類ができるということでしょうか、