LoginSignup
3
4

More than 5 years have passed since last update.

ディープラーニングを実装から学ぶ~ (まとめ3)中では何が?

Posted at

学習中に中で何が起きているかちょっと覗いてみましょう。データは、引き続きMNISTです。

中間層-なし

まずは、中間層なしの場合を見てみましょう。
以下のネットワーク図になります。(バイアスは省略しています。)

network_[].png

実行

前回のプログラムで以下の設定にて実行します。

# 学習回数
epoch = 50
# ノード数
mds = [[]]
# 学習率
lrs = [0.1]
# バッチサイズ
batch_sizes = [100]

50エポック後、学習データに対する正解率は、約93%、テストデータに対する正解率は、約92.5%です。

重さ

50エポック後の重さを図にしてみましょう。
重さは、全部で$ 784 \times 10 $個あります。$ w_{1 i}^{(1)} $、$ w_{2 i}^{(1)} $の順に$ w_{10 i}^{(1)} $まで表示します。$ i $は、784個、$ 28 \times 28 $にreshapeし表示します。
正を赤、負を青で表します。色が濃いほど大きい値です。

import matplotlib.pyplot as plt

plt.figure(figsize=(20,6))
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.title(str(i))
    plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -1, vmax = 1)
    plt.colorbar()
plt.show()

imshow_[]_W[50].png

何か見えてきますか?
0は、真ん中の青が目立ちます。真ん中に文字があれば負の要因となります。
1は、逆に真ん中が赤です。真ん中以外は青が多く負の要因となります。
2は、左側が青で、左側がつながっていないことが要因になります。
3以降も何をもとに数字を判断しているかわかりますね。

各エポックの重みを保存しておき、エポックごとにどのように変化していくか見てみましょう。
0の場合です。

imshow_[]_W[0-49].png

最初は、学習前の重さでランダムの値です。1エポック目から輪郭が見えています。

u

uの値を表示するため、u,zを取得できるように予測関数を変更します。

def predict2(x, W, b):
    # 層数
    layer = len(W) 
    # 順伝播
    u = {}
    z = {}
    u[1] = affine(x, W[1], b[1])
    for i in range(1, layer):
        z[i] = relu(u[i])
        u[i+1] = affine(z[i], W[i+1], b[i+1])
    y  = softmax(u[layer])
    return y, u, z

呼び出し元も変更します。

            y_train, u_train, z_train = predict2(nx_train, W, b)
            y_test, u_test, z_test = predict2(nx_test, W, b)

重さとデータを掛けた図を示します。テストデータの5番目まで表示してみます。

import matplotlib.pyplot as plt

for j in range(5):
    print(j)
    plt.figure(figsize=(20,7))
    for i in range(10):
        plt.subplot(2,5,i+1)
        plt.title("u1[{0:d}]={1:.2f}".format(i, u_test[1][j,i]))
        plt.imshow((W[1][:,i]*nx_test[j]).reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
        plt.colorbar()
    plt.show()

結果を順に見てみましょう。

  • 7

imshow_[]_u1_0.png

u1[7]の画像が一番多く赤く反応しています。

  • 2

imshow_[]_u1_1.png

u1[2]、u1[3]とも上部は同じように反応していますが、下部で差が出ているようです。

  • 1

imshow_[]_u1_2.png

u1[1]の中央が大きく反応しています。

  • 0

imshow_[]_u1_3.png

u1[0]が大きく反応しています。

  • 4

imshow_[]_u1_4.png

u1[4]、u1[9]が反応しています。ただし、4の方が大きく反応しています。

どのように判断しているかよくわかりました。

中間層数-1

次は、中間層が1の場合です。

network_[100].png

バイアスは、省略しています。

実行

パラメータを以下とします。

# 学習回数
epoch = 50
# ノード数
mds = [[100]]
# 学習率
lrs = [0.5]
# バッチサイズ
batch_sizes = [100]

50エポック後、学習データに対する正解率は、約100%、テストデータに対する正解率は、約98%です。

重さ

50エポック後の重さを図にしてみましょう。
1層目は、$ 784 \times 100 $個、2層目は、$ 100 \times 10 $個です。
1層目の100個を$ 28 \times 28$でreshapeします。
2層目の10個を$ 10 \times 10$でreshapeします。

1階層目です。

import matplotlib.pyplot as plt

plt.figure(figsize=(20,21))
for i in range(100):
    plt.subplot(10,10,i+1)
    plt.title("W[1][" + str(i) + "]")
    plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100]_W[1].png

いろいろな形で特徴を抽出しようとしているのでしょうか?見た目だと具体的に何を抽出しようとしているかわかりません。

2階層目です。

import matplotlib.pyplot as plt

plt.figure(figsize=(20,5))
for i in range(10):
    plt.subplot(1,10,i+1)
    plt.title("W[2][" + str(i) + "]")
    plt.imshow(W[2][:,i].reshape(10,10), "bwr", vmin = -2, vmax = 2)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100]_W[2].png

こちらは、何なのかまったくわかりませんね。

どの数字がどれに影響しているか見てみます。2階層目の重みの大きい順に10個と小さい順に10個を表示してみます。

for i in range(10):
    Wsorta = np.argsort(W[2][:,i])
    Wsortd = Wsorta[::-1]
    print(i)
    # top 10
    plt.figure(figsize=(20,21))
    for j in range(10):
        plt.subplot(1,10,j+1)
        plt.title("W1[{0:2d}] {1:.2f}".format(Wsortd[j],W[2][Wsortd[j],i]))
        plt.imshow(W[1][:,Wsortd[j]].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
    plt.show()
    # lower 10
    plt.figure(figsize=(20,21))
    for j in range(10):
        plt.subplot(1,10,j+1)
        plt.title("W1[{0:2d}] {1:.2f}".format(Wsorta[j],W[2][Wsorta[j],i]))
        plt.imshow(W[1][:,Wsorta[j]].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
    plt.show()
  • 0

影響上位10位

imshow_[100]_w1_0_top10.png

これに当てはまると0の可能性が高い。

影響下位10位

imshow_[100]_w1_0_lower10.png

これに当てはまると0の可能性が低い。

  • 1

影響上位10位

imshow_[100]_w1_1_top10.png

これに当てはまると1の可能性が高い。

影響下位10位

imshow_[100]_w1_1_lower10.png

これに当てはまると1の可能性が低い。

やはり、ぱっと見よくわからないですね。

u,z

テストデータの先頭5データについて、u1,z1を表示してみます。

u1

import matplotlib.pyplot as plt

plt.figure(figsize=(10,11))
for j in range(5):
    plt.subplot(1,5,j+1)
    plt.title(str(j) + "-u[1]")
    plt.imshow(u_test[1][j,:].reshape(10,10), "bwr", vmin = -5, vmax = 5)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100]_u1_0-4.png

正解は、7,2,1,0,4の順です。
これだけ見てもわからないですね。

z1

import matplotlib.pyplot as plt

plt.figure(figsize=(10,11))
for j in range(5):
    plt.subplot(1,5,j+1)
    plt.title(str(j) + "-z[1]")
    plt.imshow(z_test[1][j,:].reshape(10,10), "bwr", vmin = -5, vmax = 5)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100]_z1_0-4.png

活性化関数として、ReLUを利用していますので、負(青色)が0になります。

1階層目の重さとデータを掛けた図を示します。テストデータの5番目まで表示してみます。

import matplotlib.pyplot as plt

for j in range(5):
    print(j)
    plt.figure(figsize=(20,21))
    for i in range(100):
        plt.subplot(10,10,i+1)
        plt.title("u[1][{0:2d}]({1:.2f})".format(i, u_test[1][j,i]))
        plt.imshow((W[1][:,i]*nx_test[j]).reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
        plt.xticks([])
        plt.yticks([])
        #plt.colorbar()
    plt.show()

テストデータの1データ目です。

  • 7

imshow_[100]_u1_0.png

u2

import matplotlib.pyplot as plt

for j in range(5):
    print(j)
    plt.figure(figsize=(20,21))
    for i in range(10):
        plt.subplot(1,10,i+1)
        plt.title("u[2][{0:d}]({1:.2f})".format(i, u_test[2][j,i]))
        plt.imshow((W[2][:,i]*z_test[1][j]).reshape(10,10), "bwr", vmin = -2, vmax = 2)
        plt.xticks([])
        plt.yticks([])
        #plt.colorbar()
    plt.show()

テストデータの1データ目です。

  • 7

imshow_[100]_u2_0.png

7が赤が多く強く反応しています。

  • 2

imshow_[100]_u2_1.png

他の数字も同様です。

中間層 784ノード

中間層を784ノードとして実行してみます。

# 学習回数
epoch = 50
# ノード数
mds = [[784]]
# 学習率
lrs = [0.9]
# バッチサイズ
batch_sizes = [100]

50エポック後には、学習データに対する正解率 100%、テストデータに対する正解率 約98.4%です。

重さの図です。

  • 1階層目

スペースの関係からタイトルを省略しています。

import matplotlib.pyplot as plt

plt.figure(figsize=(20,21))
for i in range(784):
    plt.subplot(28,28,i+1)
    #plt.title("W[1][" + str(i) + "]")
    plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -0.25, vmax = 0.25)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[784]_W[1].png

  • 2階層目
import matplotlib.pyplot as plt

plt.figure(figsize=(20,5))
for i in range(10):
    plt.subplot(1,10,i+1)
    plt.title("W[2][" + str(i) + "]")
    plt.imshow(W[2][:,i].reshape(28,28), "bwr", vmin = -0.5, vmax = 0.5)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[784]_W[2].png

何を抽出しているのか、やはりわかりませんね。

中間層数-2

中間層を100,100ノードとして実行してみます。

# 学習回数
epoch = 50
# ノード数
mds = [[100,100]]
# 学習率
lrs = [0.5]
# バッチサイズ
batch_sizes = [100]

50エポック後には、学習データに対する正解率 100%、テストデータに対する正解率 約98.2%です。

重さの図です。

  • 1階層目
import matplotlib.pyplot as plt

plt.figure(figsize=(20,21))
for i in range(100):
    plt.subplot(10,10,i+1)
    plt.title("W[1][" + str(i) + "]")
    plt.imshow(W[1][:,i].reshape(28,28), "bwr", vmin = -0.25, vmax = 0.25)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100,100]_W[1].png

  • 2階層目
import matplotlib.pyplot as plt

plt.figure(figsize=(20,21))
for i in range(100):
    plt.subplot(10,10,i+1)
    plt.title("W[2][" + str(i) + "]")
    plt.imshow(W[2][:,i].reshape(10,10), "bwr", vmin = -0.5, vmax = 0.5)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100,100]_W[2].png

  • 3階層目
import matplotlib.pyplot as plt

plt.figure(figsize=(20,5))
for i in range(10):
    plt.subplot(1,10,i+1)
    plt.title("W[3][" + str(i) + "]")
    plt.imshow(W[3][:,i].reshape(10,10), "bwr", vmin = -1, vmax = 1)
    plt.xticks([])
    plt.yticks([])
    #plt.colorbar()
plt.show()

imshow_[100,100]_W[3].png

複雑になるばかりです。

学習の中身を覗いてみましたが、いかがでしたでしょうか?
やはり、説明はできないが、分類ができるということでしょうか、

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4