この記事は京都大学人工知能研究会KaiRA Advent Calendar 5日目の記事です。
conv2d
のカスタム関数を作成している時にbackward処理でconv_transpose2d
をどう使えば良いか困ったので、conv_transpose2d
の仕様を備忘録としてまとめます。
公式ドキュメントはこちら。
動作イメージ
ConvTranspose2dのドキュメントにも記載の通り、conv_transpose2d
はconv2d
の勾配計算に使えます。すなわち、y = conv2d(x, w)
とした時にx
の勾配はconv_transpose2d(y.grad, w)
で計算できます。
y = conv_transpose2d(x, w)
とした時の動作は以下のようになります。左側がx
、真ん中がw
、右側がy
です。x
の赤ピクセルにw
の青ピクセルをかけた値がy
の緑ピクセルに足されていきます。
![](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F629111%2F95282a09-eda2-b104-a537-2b0df1302545.gif?ixlib=rb-4.0.0&auto=format&gif-q=60&q=75&s=dff9dfaa5389e2f7bc2e3457bb84abb2)
以下設定パラメータについてまとめます。
stride
出力先の座標をstride
の数ずつスライドさせます。デフォルトでは1
です。
y = conv_transpose2d(x, w, stride=2)
とした時の動作は以下のようになります。
![](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F629111%2F55664220-185a-181f-8df1-f0323f585376.gif?ixlib=rb-4.0.0&auto=format&gif-q=60&q=75&s=934ccfbda5ce8a7d7a98d754af876ecc)
padding
計算結果から、padding
分だけ切り取られます。デフォルトでは0
です。
y = conv_transpose2d(x, w, stride=2, padding=1)
とした時の動作は以下のようになります。出力は中央の4ピクセルだけになります。
![](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F629111%2F1691c7f5-afcb-2fc2-becc-6c113fd16c37.png?ixlib=rb-4.0.0&auto=format&gif-q=60&q=75&s=39100a6f5df0fe00925442dcd9dd738b)
output_padding
計算結果に対して、右側と下側にzero paddingを行います。デフォルトでは0
です。
y = conv_transpose2d(x, w, stride=2, output_padding=1)
とした時の動作は以下のようになります。padding
も1以上に指定した場合は、output_paddingが先に適用されます。
![](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F629111%2Fd7cf8db9-0a1a-a206-ccdc-fb75b0f0031c.png?ixlib=rb-4.0.0&auto=format&gif-q=60&q=75&s=701ac615b0b4b1319c1642165c28d3fc)
dilation
出力先の座標がdilation
飛ばしになります。デフォルトでは1
です。
y = conv_transpose2d(x, w, dilation=2)
とした時の動作は以下のようになります。
![](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F629111%2F1ce8f521-6e47-93f4-1d88-8dbb1fb21739.gif?ixlib=rb-4.0.0&auto=format&gif-q=60&q=75&s=023d215c12046dbe3ec84e20e093401b)
conv2dとの対応
(ドキュメントにも書いていることですが…)
冒頭にも述べた通り、y = conv2d(x, w)
とした時にconv_transpose2d(y.grad, w)
としてx
の勾配を計算できます。この時、stride
、padding
、dilation
が両方で同じ値に設定されている必要があります。また、conv2d
を計算する時にstride
が1より大きい場合には端ピクセルが切り捨てられる場合があるため、そうした場合に勾配のサイズが合うようにoutput_padding
が設定されます。