0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ゼロから作るDeep Learning② word2vecコードで謎だった場所

Last updated at Posted at 2021-07-22

謎だった場所

なぜself.paramsにWを代入する際にW(numpy配列)をリストにしてから、わざわざ「W,=」とまたリストの中のself.paramの要素(numpy配列)を取り出すという一見無駄なことをしているのか?

class Embedding:
    def __init__(self, W):
        self.params = [W]
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W, = self.params
        self.idx = idx
        out = W[idx]
        return out

【参考】
関数名にカンマを使用する意味についての内容が「W,=」の箇所の理解にとても役立ちました(ありがとうございます!)

その理由は、実は以下のように最初から最後までWをnumpy配列で処理をしても「Embeddingの処理だけなら」動作するが、その後の処理でリストの足し算などが行われるところでエラーとなり全体的なプログラムとしては動作しなくなるためだと理解しました。

# 以下でも「Embeddingの処理だけなら」動作する
class Embedding:
    def __init__(self, W):
        self.params = W
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W = self.params
        self.idx = idx
        out = W[idx]
        return out

つまり、以下のコードでも

import numpy as np

class Embedding:
    def __init__(self, W):
        self.params = [W]   #オリジナル
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W, = self.params   #オリジナル
        self.idx = idx
        out = W[idx]
        return out


W_in = 0.01 * np.random.randn(7, 3).astype('f')
print(f"W_in =\n{W_in}\n")

layer = Embedding(W_in)
print(f"3番目をEmbeddingしたもの:\n{layer.forward(3)}\n")

params0 = W_in
print(f"W_inをリストにしない場合:\n{params0}\n")
print(f"{type(params0)}\n")

params = [W_in]
print(f"W_inをリストにした場合:\n{params}\n")
print(f"{type(params)}\n")

W = params
print(f"カンマつけた渡し方をしない場合(W = params)のW:\n{W}")
print(type(W))
print("\n")

W, = params
print(f"カンマつけた渡し方をした場合(W, = params)のW:\n{W}")
print(type(W))

以下のコードでも

import numpy as np

class Embedding:
    def __init__(self, W):
        self.params = W   #リスト化しない
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W = self.params   #numpy配列のまま処理
        self.idx = idx
        out = W[idx]
        return out


W_in = 0.01 * np.random.randn(7, 3).astype('f')
print(f"W_in =\n{W_in}\n")

layer = Embedding(W_in)
print(f"3番目をEmbeddingしたもの:\n{layer.forward(3)}\n")

params0 = W_in
print(f"W_inをリストにしない場合:\n{params0}\n")
print(f"{type(params0)}\n")

params = [W_in]
print(f"W_inをリストにした場合:\n{params}\n")
print(f"{type(params)}\n")

W = params
print(f"カンマつけた渡し方をしない場合(W = params)のW:\n{W}")
print(type(W))
print("\n")

W, = params
print(f"カンマつけた渡し方をした場合(W, = params)のW:\n{W}")
print(type(W))

結果は同じく以下のようになる。

W_in =
[[-0.00414995 -0.00505246 -0.02271379]
 [-0.00385737 -0.01022162 -0.00621947]
 [-0.01317972  0.00763595  0.00437246]
 [ 0.01119065 -0.01144209  0.02539131]
 [-0.00316145 -0.01609291 -0.00868459]
 [ 0.00361989  0.01507116 -0.00318975]
 [ 0.00530743  0.00881439 -0.01096747]]

3番目をEmbeddingしたもの
[ 0.01119065 -0.01144209  0.02539131]

W_inをリストにしない場合:
[[-0.00414995 -0.00505246 -0.02271379]
 [-0.00385737 -0.01022162 -0.00621947]
 [-0.01317972  0.00763595  0.00437246]
 [ 0.01119065 -0.01144209  0.02539131]
 [-0.00316145 -0.01609291 -0.00868459]
 [ 0.00361989  0.01507116 -0.00318975]
 [ 0.00530743  0.00881439 -0.01096747]]
<class 'numpy.ndarray'>


W_inをリストにした場合:
[array([[-0.00414995, -0.00505246, -0.02271379],
       [-0.00385737, -0.01022162, -0.00621947],
       [-0.01317972,  0.00763595,  0.00437246],
       [ 0.01119065, -0.01144209,  0.02539131],
       [-0.00316145, -0.01609291, -0.00868459],
       [ 0.00361989,  0.01507116, -0.00318975],
       [ 0.00530743,  0.00881439, -0.01096747]], dtype=float32)]
<class 'list'>


カンマつけた渡し方をしない場合(W = params)のW:
[array([[-0.00414995, -0.00505246, -0.02271379],
       [-0.00385737, -0.01022162, -0.00621947],
       [-0.01317972,  0.00763595,  0.00437246],
       [ 0.01119065, -0.01144209,  0.02539131],
       [-0.00316145, -0.01609291, -0.00868459],
       [ 0.00361989,  0.01507116, -0.00318975],
       [ 0.00530743,  0.00881439, -0.01096747]], dtype=float32)]
<class 'list'>


カンマつけた渡し方をした場合(W, = params)のW:
[[-0.00414995 -0.00505246 -0.02271379]
 [-0.00385737 -0.01022162 -0.00621947]
 [-0.01317972  0.00763595  0.00437246]
 [ 0.01119065 -0.01144209  0.02539131]
 [-0.00316145 -0.01609291 -0.00868459]
 [ 0.00361989  0.01507116 -0.00318975]
 [ 0.00530743  0.00881439 -0.01096747]]
<class 'numpy.ndarray'>

しかし後者だとtrain.pyなどを動かしているとエラーが出て止まってしまう。それはおそらく処理の中で複数要素を一つのリストに放り込んでいく処理のところで実施されるリストの足し算が出来ない(numpy配列とリストの足し算が出来ない)ためである。
イメージ的には、以下でprint(a+c)が実行できないのと同じ。

a = np.array([[1,2,3,4,5],[6,7,8,9,0]])
b = np.arange(10).reshape(2,5)

print(a)
print(b)

print(a+b) #numpy配列の要素の数が合っているので計算できる

lista = [a]
listb = [b]

print(lista+listb)

c = np.arange(30).reshape(6,5)

print(a)
print(c)

# print(a+c) #numpy配列の要素の数が合っていないので計算できない

lista = [a]
listc = [c]

print(lista+listc) #リスト同士にするとリストの0番目にa、1番目にbが入るので、足し算が出来る

一応出力を書いておくとこんな感じ

# print(a)
[[1 2 3 4 5]   
 [6 7 8 9 0]]

# print(b)
[[0 1 2 3 4]   #print(b)
 [5 6 7 8 9]]

# print(a+b) 
[[ 1  3  5  7  9]   
 [11 13 15 17  9]]

# print(lista+listb)
 [array([[1, 2, 3, 4, 5],
       [6, 7, 8, 9, 0]]), array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])]

# print(a)
[[1 2 3 4 5]
 [6 7 8 9 0]]

# print(c)
 [[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]
 [25 26 27 28 29]]

# print(lista+listc)
 [array([[1, 2, 3, 4, 5],
       [6, 7, 8, 9, 0]]), array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]])]
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?