謎だった場所
なぜself.paramsにWを代入する際にW(numpy配列)をリストにしてから、わざわざ「W,=」とまたリストの中のself.paramの要素(numpy配列)を取り出すという一見無駄なことをしているのか?
class Embedding:
def __init__(self, W):
self.params = [W]
self.grads = [np.zeros_like(W)]
self.idx = None
def forward(self, idx):
W, = self.params
self.idx = idx
out = W[idx]
return out
【参考】
関数名にカンマを使用する意味についての内容が「W,=」の箇所の理解にとても役立ちました(ありがとうございます!)
その理由は、実は以下のように最初から最後までWをnumpy配列で処理をしても「Embeddingの処理だけなら」動作するが、その後の処理でリストの足し算などが行われるところでエラーとなり全体的なプログラムとしては動作しなくなるためだと理解しました。
# 以下でも「Embeddingの処理だけなら」動作する
class Embedding:
def __init__(self, W):
self.params = W
self.grads = [np.zeros_like(W)]
self.idx = None
def forward(self, idx):
W = self.params
self.idx = idx
out = W[idx]
return out
つまり、以下のコードでも
import numpy as np
class Embedding:
def __init__(self, W):
self.params = [W] #オリジナル
self.grads = [np.zeros_like(W)]
self.idx = None
def forward(self, idx):
W, = self.params #オリジナル
self.idx = idx
out = W[idx]
return out
W_in = 0.01 * np.random.randn(7, 3).astype('f')
print(f"W_in =\n{W_in}\n")
layer = Embedding(W_in)
print(f"3番目をEmbeddingしたもの:\n{layer.forward(3)}\n")
params0 = W_in
print(f"W_inをリストにしない場合:\n{params0}\n")
print(f"{type(params0)}\n")
params = [W_in]
print(f"W_inをリストにした場合:\n{params}\n")
print(f"{type(params)}\n")
W = params
print(f"カンマつけた渡し方をしない場合(W = params)のW:\n{W}")
print(type(W))
print("\n")
W, = params
print(f"カンマつけた渡し方をした場合(W, = params)のW:\n{W}")
print(type(W))
以下のコードでも
import numpy as np
class Embedding:
def __init__(self, W):
self.params = W #リスト化しない
self.grads = [np.zeros_like(W)]
self.idx = None
def forward(self, idx):
W = self.params #numpy配列のまま処理
self.idx = idx
out = W[idx]
return out
W_in = 0.01 * np.random.randn(7, 3).astype('f')
print(f"W_in =\n{W_in}\n")
layer = Embedding(W_in)
print(f"3番目をEmbeddingしたもの:\n{layer.forward(3)}\n")
params0 = W_in
print(f"W_inをリストにしない場合:\n{params0}\n")
print(f"{type(params0)}\n")
params = [W_in]
print(f"W_inをリストにした場合:\n{params}\n")
print(f"{type(params)}\n")
W = params
print(f"カンマつけた渡し方をしない場合(W = params)のW:\n{W}")
print(type(W))
print("\n")
W, = params
print(f"カンマつけた渡し方をした場合(W, = params)のW:\n{W}")
print(type(W))
結果は同じく以下のようになる。
W_in =
[[-0.00414995 -0.00505246 -0.02271379]
[-0.00385737 -0.01022162 -0.00621947]
[-0.01317972 0.00763595 0.00437246]
[ 0.01119065 -0.01144209 0.02539131]
[-0.00316145 -0.01609291 -0.00868459]
[ 0.00361989 0.01507116 -0.00318975]
[ 0.00530743 0.00881439 -0.01096747]]
3番目をEmbeddingしたもの:
[ 0.01119065 -0.01144209 0.02539131]
W_inをリストにしない場合:
[[-0.00414995 -0.00505246 -0.02271379]
[-0.00385737 -0.01022162 -0.00621947]
[-0.01317972 0.00763595 0.00437246]
[ 0.01119065 -0.01144209 0.02539131]
[-0.00316145 -0.01609291 -0.00868459]
[ 0.00361989 0.01507116 -0.00318975]
[ 0.00530743 0.00881439 -0.01096747]]
<class 'numpy.ndarray'>
W_inをリストにした場合:
[array([[-0.00414995, -0.00505246, -0.02271379],
[-0.00385737, -0.01022162, -0.00621947],
[-0.01317972, 0.00763595, 0.00437246],
[ 0.01119065, -0.01144209, 0.02539131],
[-0.00316145, -0.01609291, -0.00868459],
[ 0.00361989, 0.01507116, -0.00318975],
[ 0.00530743, 0.00881439, -0.01096747]], dtype=float32)]
<class 'list'>
カンマつけた渡し方をしない場合(W = params)のW:
[array([[-0.00414995, -0.00505246, -0.02271379],
[-0.00385737, -0.01022162, -0.00621947],
[-0.01317972, 0.00763595, 0.00437246],
[ 0.01119065, -0.01144209, 0.02539131],
[-0.00316145, -0.01609291, -0.00868459],
[ 0.00361989, 0.01507116, -0.00318975],
[ 0.00530743, 0.00881439, -0.01096747]], dtype=float32)]
<class 'list'>
カンマつけた渡し方をした場合(W, = params)のW:
[[-0.00414995 -0.00505246 -0.02271379]
[-0.00385737 -0.01022162 -0.00621947]
[-0.01317972 0.00763595 0.00437246]
[ 0.01119065 -0.01144209 0.02539131]
[-0.00316145 -0.01609291 -0.00868459]
[ 0.00361989 0.01507116 -0.00318975]
[ 0.00530743 0.00881439 -0.01096747]]
<class 'numpy.ndarray'>
しかし後者だとtrain.pyなどを動かしているとエラーが出て止まってしまう。それはおそらく処理の中で複数要素を一つのリストに放り込んでいく処理のところで実施されるリストの足し算が出来ない(numpy配列とリストの足し算が出来ない)ためである。
イメージ的には、以下でprint(a+c)が実行できないのと同じ。
a = np.array([[1,2,3,4,5],[6,7,8,9,0]])
b = np.arange(10).reshape(2,5)
print(a)
print(b)
print(a+b) #numpy配列の要素の数が合っているので計算できる
lista = [a]
listb = [b]
print(lista+listb)
c = np.arange(30).reshape(6,5)
print(a)
print(c)
# print(a+c) #numpy配列の要素の数が合っていないので計算できない
lista = [a]
listc = [c]
print(lista+listc) #リスト同士にするとリストの0番目にa、1番目にbが入るので、足し算が出来る
一応出力を書いておくとこんな感じ
# print(a)
[[1 2 3 4 5]
[6 7 8 9 0]]
# print(b)
[[0 1 2 3 4] #print(b)
[5 6 7 8 9]]
# print(a+b)
[[ 1 3 5 7 9]
[11 13 15 17 9]]
# print(lista+listb)
[array([[1, 2, 3, 4, 5],
[6, 7, 8, 9, 0]]), array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])]
# print(a)
[[1 2 3 4 5]
[6 7 8 9 0]]
# print(c)
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]
[25 26 27 28 29]]
# print(lista+listc)
[array([[1, 2, 3, 4, 5],
[6, 7, 8, 9, 0]]), array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]])]