2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pythonの実装を100倍効率化する :ChatGPTによる型アノテーション機能の活用

Posted at

はじめに

この記事は何?

pythonリーディングのお助けアイデアを提供します。

皆さーん、python、読みやすいですか????
簡潔に書ける一方で、複雑な処理になるとクラスや関数の挙動が予測しづらい、、、なーんてことはあるあるではないでしょうか。

分かりにくいコードはLLMを利用して型のヒント(型アノテーション)をもらおう というのが今回のTopicです。

型アノテーションって何?

変数や関数、クラスなどの要素に対して、その型情報を明示的に付与する手法のことです。
百聞は一見に如かず、実際に例を見てみましょう。

型アノテーションの基本

型アノテーションってどう書くの?

まずは型アノテーションのないコードを見てみましょう

# 変数
x = 5
y = 3.14
name = "John"

# 関数
def add(a, b):
    return a + b

# クラス
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age

では次に型アノテーションを付与したバージョンを見てみましょう。

型アノテーションの例
# 変数の型アノテーションの書き方
x: int = 5
y: float = 3.14
name: str = "John"

# 関数の型アノテーションの書き方
def add(a: int, b: int) -> int:
    return a + b

# クラスの型アノテーションの書き方
class Person:
    def __init__(self, name: str, age: int):
        self.name = name
        self.age = age

それぞれのオブジェクトがどんな型を持っているのかが明示されました。
これによって変数名で推定しなくても、変数の型、関数・クラスの入出力が分かりやすくなりました。

pythonの型アノテーションは型を記入するだけで強制するわけではありません。(型が間違っていても動作します。)

この記事では分かりやすさのために、厳密な説明を省略します。
さらに詳細は下記のパブリックドキュメントを参照ください。

ChatGPTによる型アノテーション機能の実践

ChatGPTを使った型アノテーションの具体例

では実際によくわからない複雑そうなコードを持ってきてコードリーディングを行ってみましょう。
今回は、画像分類のモデルにおいて2024年2月現在ImageNetでSOTA(state-of-the-art : 最高性能)であるモデルを実際に理解しながら実装したいとします。

今回は実装が複雑になりがちな深層学習を例に持ってきますが、
基本的にやり方はすべて同じです。

概要を読むと、途中で「ノイズ」を加えてVision transformerを実装していると書いてあるのですが、どのような実装になっているか想像ができないですね。
では実際に中身を見てみましょう。

def quality_matrix(k, alpha=0.3):
    """r
    Quality matrix Q. Described in the eq (17) so that eps = QX, where X is the input. 
    Alpha is 0.3, as mentioned in Appendix D.
    """
    identity = torch.diag(torch.ones(k))
    shift_identity = torch.zeros(k, k)
    for i in range(k):
        shift_identity[(i+1)%k, i] = 1
    opt = -alpha * identity + alpha * shift_identity
    return opt

def optimal_quality_matrix(k):
    """r
    Optimal Quality matrix Q. Described in the eq (19) so that eps = QX, where X is the input. 
    Suppose 1_(kxk) is torch.ones
    """
    return torch.diag(torch.ones(k)) * -k/(k+1) + torch.ones(k, k) / (k+1)

class NoisyViT(VisionTransformer):
    """r
    Args:
        optimal: Determine the linear transform noise is produced by the quality matrix or the optimal quality matrix.
        res: Inference resolution. Ensure the aspect ratio = 1
    
    """
    def __init__(self, optimal: bool, res: int, **kwargs):
        self.stage3_res = res // 16
        if optimal:
            linear_transform_noise = optimal_quality_matrix(self.stage3_res)
        else:
            linear_transform_noise = quality_matrix(self.stage3_res)
        super().__init__(**kwargs)
        self.linear_transform_noise = torch.nn.Parameter(linear_transform_noise, requires_grad=False)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            return super().forward_features(x)
        
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        # Add noise when training/testing
        # See https://openreview.net/forum?id=Ce0dDt9tUT for more detail
        x = self.blocks[:-1](x)
        # Suppose the token dim = 1
        token = x[:, 0, :].unsqueeze(1)
        x = x[:, 1:, :].permute(0, 2, 1)
        B, C, L = x.shape
        x = x.reshape(B, C, self.stage3_res, self.stage3_res)
        x = self.linear_transform_noise@x + x
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = torch.cat([token, x], dim=1)
        x = self.blocks[-1](x)

        x = self.norm(x)
        return x

う~~~ん、分からない!

では一つずつ関数に型アノテーションを付与して、中身を見ていきましょう。

ChatGPTへのプロンプト
下記コードに型アノテーションを付与してください。
コメントアウトは日本語訳してください

~~~~~~~~~~~~
実際のコード
~~~~~~~~~~~~
型アノテーションの例
def quality_matrix(k: int, alpha: float = 0.3) -> torch.Tensor:
    """品質行列 Q. 入力 X に対して eps = QX となる。Appendix D での alpha は 0.3。"""
    identity: torch.Tensor = torch.diag(torch.ones(k)) #torch.diag():対角行列
    shift_identity: torch.Tensor = torch.zeros(k, k)
    for i in range(k):
        shift_identity[(i+1)%k, i] = 1
    opt: torch.Tensor = -alpha * identity + alpha * shift_identity
    return opt

入力は整数と浮動小数点を受け取り、出力はテンソルのようですね。下記の部分で明示されます。

(k: int, alpha: float = 0.3) -> torch.Tensor:

kやalphaといった汎用的な引数がどのような型なのか分かるのは便利ですね。

LLMで推論しているだけなので型は下記の3つの要素から決定されています。

  • 実装から自明である。
  • 変数から推定される。
  • 類似したコード事例が多数ある。

そのため、デタラメな実装だと正しい型アノテーションは困難なのですが、リファレンスにしたいようなコードで低品質というのは考えづらい状況なので心配無用でしょう。

引数kは整数型で、これは行列の大きさを決めていて、alphaは浮動小数点型でこれは行列の和の時に係数のように使っていることが分かります。
一つ目のテンソルは対角行列、
二つ目のテンソルは対角成分を右に1ずらした行列ですね。
この二つのテンソルにalpha(浮動小数点)をかけて足していることが分かります。

入出力と、中身の変数の型が分かるだけで内部処理が見やすくなりました。

次の関数も短いですが見てみましょう。

型アノテーションの例
def optimal_quality_matrix(k: int) -> torch.Tensor:
    """最適な品質行列 Q. 入力 X に対して eps = QX となる。1_(kxk) が torch.ones であると仮定。"""
    return torch.diag(torch.ones(k)) * -k/(k+1) + torch.ones(k, k) / (k+1)

対角行列成分が-k/(k+1)となる行列をつくって、全ての要素がk(k+1)となる正方行列を足しました。

最後に順伝搬のコードを見てみましょう。

型アノテーションの例
class NoisyViT(VisionTransformer):
    """ノイズ付き ViT(ビジョン Transformer)。

    Args:
        optimal: 線形変換ノイズが品質行列または最適な品質行列によって生成されるかを決定します。
        res: 推論解像度。アスペクト比 = 1 を確保します。
    """
    def __init__(self, optimal: bool, res: int, **kwargs):
        self.stage3_res: int = res // 16
        if optimal:
            linear_transform_noise: torch.Tensor = optimal_quality_matrix(self.stage3_res)
        else:
            linear_transform_noise: torch.Tensor = quality_matrix(self.stage3_res)
        super().__init__(**kwargs)
        self.linear_transform_noise: torch.nn.Parameter = torch.nn.Parameter(linear_transform_noise, requires_grad=False)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            return super().forward_features(x)
        
        x: torch.Tensor = self.patch_embed(x)
        x: torch.Tensor = self._pos_embed(x)
        x: torch.Tensor = self.patch_drop(x)
        x: torch.Tensor = self.norm_pre(x)
        # Add noise when training/testing
        # 詳細については https://openreview.net/forum?id=Ce0dDt9tUT を参照してください
        x: torch.Tensor = self.blocks[:-1](x)
        # トークンの次元が 1 であると仮定します
        token: torch.Tensor = x[:, 0, :].unsqueeze(1)
        x: torch.Tensor = x[:, 1:, :].permute(0, 2, 1)
        B, C, L = x.shape
        x: torch.Tensor = x.reshape(B, C, self.stage3_res, self.stage3_res)
        x: torch.Tensor = self.linear_transform_noise @ x + x
        x: torch.Tensor = x.flatten(2).transpose(1, 2).contiguous()
        x: torch.Tensor = torch.cat([token, x], dim=1)
        x: torch.Tensor = self.blocks[-1](x)

        x: torch.Tensor = self.norm(x)
        return x

量が多いですが慌てることはありません。
クラスの引数のoptionalはブーリアン型で先ほど作った行列のどちらを採用するか決めています。
出てくる変数のほとんどはpytorchのテンソル型ですね。
そして分からなかった「ノイズ」とはさっき作った行列(テンソル)のことで、途中の線形変換でこのノイズを嚙ましているとわかりました。
下記の部分です。(正確にはスキップコネクションという機構を使っています。)

x: torch.Tensor = self.linear_transform_noise @ x + x

このように一見難しいコードでも、
型が分かれば処理内容が推察しやすくなります。

たしかにpytorchのモジュールをオーバーライドしているようなクラスではどんな実装が行われるか予想がつくケースもありますが、
データの前処理など書き手によって処理が様々である箇所は型アノテーションをしてざっと眺めるというのが早いです。

まとめ

型アノテーション結果のまとめ

型アノテーションを通して深層学習のコードリーディングを行いました。
低コストで可読性をあげてくれるため、所見のコードを見るのに役立ちます。

日々の実装ではコードリーディングし、型アノテーションをした箇所はコメントを付け加えて下記のブログにまとめています。
markdownで内容を整理することで階層的に情報を整理できるのと、
過去に実装したことの復習になるのでお勧めです。

宣伝

また、先日twitterも始めました。
(帰社後に家で寝る前に技術記事書いてる人間、"信用"できませんか?)
楽しい楽しい機械学習やpythonの話題に穏便に触れるとても平和なアカウントにするので覗いてやってください。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?