はじめに
この記事は何?
pythonリーディングのお助けアイデアを提供します。
皆さーん、python、読みやすいですか????
簡潔に書ける一方で、複雑な処理になるとクラスや関数の挙動が予測しづらい、、、なーんてことはあるあるではないでしょうか。
分かりにくいコードはLLMを利用して型のヒント(型アノテーション)をもらおう というのが今回のTopicです。
型アノテーションって何?
変数や関数、クラスなどの要素に対して、その型情報を明示的に付与する手法のことです。
百聞は一見に如かず、実際に例を見てみましょう。
型アノテーションの基本
型アノテーションってどう書くの?
まずは型アノテーションのないコードを見てみましょう
# 変数
x = 5
y = 3.14
name = "John"
# 関数
def add(a, b):
return a + b
# クラス
class Person:
def __init__(self, name, age):
self.name = name
self.age = age
では次に型アノテーションを付与したバージョンを見てみましょう。
# 変数の型アノテーションの書き方
x: int = 5
y: float = 3.14
name: str = "John"
# 関数の型アノテーションの書き方
def add(a: int, b: int) -> int:
return a + b
# クラスの型アノテーションの書き方
class Person:
def __init__(self, name: str, age: int):
self.name = name
self.age = age
それぞれのオブジェクトがどんな型を持っているのかが明示されました。
これによって変数名で推定しなくても、変数の型、関数・クラスの入出力が分かりやすくなりました。
pythonの型アノテーションは型を記入するだけで強制するわけではありません。(型が間違っていても動作します。)
この記事では分かりやすさのために、厳密な説明を省略します。
さらに詳細は下記のパブリックドキュメントを参照ください。
ChatGPTによる型アノテーション機能の実践
ChatGPTを使った型アノテーションの具体例
では実際によくわからない複雑そうなコードを持ってきてコードリーディングを行ってみましょう。
今回は、画像分類のモデルにおいて2024年2月現在ImageNetでSOTA(state-of-the-art : 最高性能)であるモデルを実際に理解しながら実装したいとします。
今回は実装が複雑になりがちな深層学習を例に持ってきますが、
基本的にやり方はすべて同じです。
概要を読むと、途中で「ノイズ」を加えてVision transformerを実装していると書いてあるのですが、どのような実装になっているか想像ができないですね。
では実際に中身を見てみましょう。
def quality_matrix(k, alpha=0.3):
"""r
Quality matrix Q. Described in the eq (17) so that eps = QX, where X is the input.
Alpha is 0.3, as mentioned in Appendix D.
"""
identity = torch.diag(torch.ones(k))
shift_identity = torch.zeros(k, k)
for i in range(k):
shift_identity[(i+1)%k, i] = 1
opt = -alpha * identity + alpha * shift_identity
return opt
def optimal_quality_matrix(k):
"""r
Optimal Quality matrix Q. Described in the eq (19) so that eps = QX, where X is the input.
Suppose 1_(kxk) is torch.ones
"""
return torch.diag(torch.ones(k)) * -k/(k+1) + torch.ones(k, k) / (k+1)
class NoisyViT(VisionTransformer):
"""r
Args:
optimal: Determine the linear transform noise is produced by the quality matrix or the optimal quality matrix.
res: Inference resolution. Ensure the aspect ratio = 1
"""
def __init__(self, optimal: bool, res: int, **kwargs):
self.stage3_res = res // 16
if optimal:
linear_transform_noise = optimal_quality_matrix(self.stage3_res)
else:
linear_transform_noise = quality_matrix(self.stage3_res)
super().__init__(**kwargs)
self.linear_transform_noise = torch.nn.Parameter(linear_transform_noise, requires_grad=False)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
if self.grad_checkpointing and not torch.jit.is_scripting():
return super().forward_features(x)
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
# Add noise when training/testing
# See https://openreview.net/forum?id=Ce0dDt9tUT for more detail
x = self.blocks[:-1](x)
# Suppose the token dim = 1
token = x[:, 0, :].unsqueeze(1)
x = x[:, 1:, :].permute(0, 2, 1)
B, C, L = x.shape
x = x.reshape(B, C, self.stage3_res, self.stage3_res)
x = self.linear_transform_noise@x + x
x = x.flatten(2).transpose(1, 2).contiguous()
x = torch.cat([token, x], dim=1)
x = self.blocks[-1](x)
x = self.norm(x)
return x
う~~~ん、分からない!
では一つずつ関数に型アノテーションを付与して、中身を見ていきましょう。
下記コードに型アノテーションを付与してください。
コメントアウトは日本語訳してください
~~~~~~~~~~~~
実際のコード
~~~~~~~~~~~~
def quality_matrix(k: int, alpha: float = 0.3) -> torch.Tensor:
"""品質行列 Q. 入力 X に対して eps = QX となる。Appendix D での alpha は 0.3。"""
identity: torch.Tensor = torch.diag(torch.ones(k)) #torch.diag():対角行列
shift_identity: torch.Tensor = torch.zeros(k, k)
for i in range(k):
shift_identity[(i+1)%k, i] = 1
opt: torch.Tensor = -alpha * identity + alpha * shift_identity
return opt
入力は整数と浮動小数点を受け取り、出力はテンソルのようですね。下記の部分で明示されます。
(k: int, alpha: float = 0.3) -> torch.Tensor:
kやalphaといった汎用的な引数がどのような型なのか分かるのは便利ですね。
LLMで推論しているだけなので型は下記の3つの要素から決定されています。
- 実装から自明である。
- 変数から推定される。
- 類似したコード事例が多数ある。
そのため、デタラメな実装だと正しい型アノテーションは困難なのですが、リファレンスにしたいようなコードで低品質というのは考えづらい状況なので心配無用でしょう。
引数kは整数型で、これは行列の大きさを決めていて、alphaは浮動小数点型でこれは行列の和の時に係数のように使っていることが分かります。
一つ目のテンソルは対角行列、
二つ目のテンソルは対角成分を右に1ずらした行列ですね。
この二つのテンソルにalpha(浮動小数点)をかけて足していることが分かります。
入出力と、中身の変数の型が分かるだけで内部処理が見やすくなりました。
次の関数も短いですが見てみましょう。
def optimal_quality_matrix(k: int) -> torch.Tensor:
"""最適な品質行列 Q. 入力 X に対して eps = QX となる。1_(kxk) が torch.ones であると仮定。"""
return torch.diag(torch.ones(k)) * -k/(k+1) + torch.ones(k, k) / (k+1)
対角行列成分が-k/(k+1)となる行列をつくって、全ての要素がk(k+1)となる正方行列を足しました。
最後に順伝搬のコードを見てみましょう。
class NoisyViT(VisionTransformer):
"""ノイズ付き ViT(ビジョン Transformer)。
Args:
optimal: 線形変換ノイズが品質行列または最適な品質行列によって生成されるかを決定します。
res: 推論解像度。アスペクト比 = 1 を確保します。
"""
def __init__(self, optimal: bool, res: int, **kwargs):
self.stage3_res: int = res // 16
if optimal:
linear_transform_noise: torch.Tensor = optimal_quality_matrix(self.stage3_res)
else:
linear_transform_noise: torch.Tensor = quality_matrix(self.stage3_res)
super().__init__(**kwargs)
self.linear_transform_noise: torch.nn.Parameter = torch.nn.Parameter(linear_transform_noise, requires_grad=False)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
if self.grad_checkpointing and not torch.jit.is_scripting():
return super().forward_features(x)
x: torch.Tensor = self.patch_embed(x)
x: torch.Tensor = self._pos_embed(x)
x: torch.Tensor = self.patch_drop(x)
x: torch.Tensor = self.norm_pre(x)
# Add noise when training/testing
# 詳細については https://openreview.net/forum?id=Ce0dDt9tUT を参照してください
x: torch.Tensor = self.blocks[:-1](x)
# トークンの次元が 1 であると仮定します
token: torch.Tensor = x[:, 0, :].unsqueeze(1)
x: torch.Tensor = x[:, 1:, :].permute(0, 2, 1)
B, C, L = x.shape
x: torch.Tensor = x.reshape(B, C, self.stage3_res, self.stage3_res)
x: torch.Tensor = self.linear_transform_noise @ x + x
x: torch.Tensor = x.flatten(2).transpose(1, 2).contiguous()
x: torch.Tensor = torch.cat([token, x], dim=1)
x: torch.Tensor = self.blocks[-1](x)
x: torch.Tensor = self.norm(x)
return x
量が多いですが慌てることはありません。
クラスの引数のoptionalはブーリアン型で先ほど作った行列のどちらを採用するか決めています。
出てくる変数のほとんどはpytorchのテンソル型ですね。
そして分からなかった「ノイズ」とはさっき作った行列(テンソル)のことで、途中の線形変換でこのノイズを嚙ましているとわかりました。
下記の部分です。(正確にはスキップコネクションという機構を使っています。)
x: torch.Tensor = self.linear_transform_noise @ x + x
このように一見難しいコードでも、
型が分かれば処理内容が推察しやすくなります。
たしかにpytorchのモジュールをオーバーライドしているようなクラスではどんな実装が行われるか予想がつくケースもありますが、
データの前処理など書き手によって処理が様々である箇所は型アノテーションをしてざっと眺めるというのが早いです。
まとめ
型アノテーション結果のまとめ
型アノテーションを通して深層学習のコードリーディングを行いました。
低コストで可読性をあげてくれるため、所見のコードを見るのに役立ちます。
日々の実装ではコードリーディングし、型アノテーションをした箇所はコメントを付け加えて下記のブログにまとめています。
markdownで内容を整理することで階層的に情報を整理できるのと、
過去に実装したことの復習になるのでお勧めです。
宣伝
また、先日twitterも始めました。
(帰社後に家で寝る前に技術記事書いてる人間、"信用"できませんか?)
楽しい楽しい機械学習やpythonの話題に穏便に触れるとても平和なアカウントにするので覗いてやってください。
【自己紹介】
— 狸里 狸丸 (@tanuki_boosting) January 28, 2024
機械学習を専門とするエンジニア。
Computer visionやMaterials Informatics 関連の業務をしています。
主に数学やプログラミング、機械学習のことについて呟きます。#python #機械学習