前置き
Pythonで機械学習とか深層学習をやってます。Python歴は3年目くらい。まだまだ日々勉強です。
普段開発したりするときPytorchとかフレームワークのコードをよく眺めている訳なんですが、たまに気になっている事がありました。
なんでデフォルトの引数をNoneにして、その後何かしらの値を代入するのか?初めから何かしらの値をデフォルト値にすればいいじゃないかと。
例えば、Pytorch
のResnet
実装に関する以下のコードとか。
class BasicBlock(nn.Module):
def __init__(
self,
...(その他引数省略)
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
...(実装続く)
norm_layer
をデフォルト値の引数ではNone
で受けて、__init__
本体の中で、norm_layer = nn.BatchNorm2d
としています。
これ引数のデフォルト値をnn.BatchNorm2d
にすれば良くない?と言うのが私の気になっていたポイントでした。
たまたまEffective Python
を読んでいたら、そのヒントが合ったのでシェアしたいと思います!
結論
Q. 動的な引数のデフォルト値をなぜNoneにした方が良いのか?
A. デフォルト値は実行時、モジュールがロードされた時ただ1度だけ評価されるため、動的な値に奇妙な振る舞いをもたらす可能性があるため
(参照: Effective Python ― Pythonプログラムを改良する59項目 p.48 一部改変)
これだけだとよく分からん!って感じなので以下に例で解説します。
例
data
を受け取り、json形式であればそれを辞書型に変換して、jsonでなければ空の辞書を返す関数を考えてみます。
ダメな例
default
のデフォルト値に{}
を指定しています。
def bad_decode(data, default={}):
try:
return json.loads(data)
except ValueError:
return default
foo = bad_decode('bad data')
foo['staff'] = 5
bar = bad_decode('also bad data')
bar['meep'] = 1
print("foo", foo)
print("bar", bar)
以下のような出力になると思いますよね?
foo {'staff': 5}
bar {'meep': 1}
でもこれの結果は以下のような予想外な結果になります。ええぇ...
foo {'staff': 5, 'meep': 1}
bar {'staff': 5, 'meep': 1}
良い例
最初に挙げたPytorchの例と同じように、本体の中でdefault = {}
としています。
def good_decode(data, default=None):
if default is None:
default = {}
try:
return json.loads(data)
except ValueError:
return default
foo = good_decode('bad data')
foo['staff'] = 5
bar = good_decode('also bad data')
bar['meep'] = 1
print("foo", foo)
print("bar", bar)
こうする事で期待通りの出力になります。
foo {'staff': 5}
bar {'meep': 1}
番外編
最初に挙げたPytorchの例はclassだったので、classでの挙動も見てみます。
結論は、やはりデフォルト値はNoneにしないと予期せぬ挙動が起こります。
class BadDecode:
def __init__(self, data, default={}):
try:
self.data = json.loads(data)
except ValueError:
self.data = default
class GoodDecode:
def __init__(self, data, default=None):
if default is None:
default = {}
try:
self.data = json.loads(data)
except ValueError:
self.data = default
bad_foo = BadDecode('bad data')
bad_foo.data['staff'] = 5
bad_bar = BadDecode('also bad data')
bad_bar.data['meep'] = 1
good_foo = GoodDecode('bad data')
good_foo.data['staff'] = 5
good_bar = GoodDecode('also bad data')
good_bar.data['meep'] = 1
print(bad_foo.data)
print(bad_bar.data)
print(good_foo.data)
print(good_bar.data)
{'staff': 5, 'meep': 1}
{'staff': 5, 'meep': 1}
{'staff': 5}
{'meep': 1}
最後に
たまにはEffective Python
のような書籍を読み返してみるのも新たな発見があって面白いですね。