0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Burnのソースコードで学ぶ2D適応型平均プーリング層

Posted at

本記事ではRust製Burnフレームワークの内部実装を紐解きながら、畳み込みニューラルネットワーク(CNN)の重要な構成要素である「2D適応型平均プーリング層(Adaptive Average Pooling 2D)」について学びます。

プーリングとは何か

ニューラルネットワーク、特に畳み込みニューラルネットワーク(CNN)において、 プーリング(Pooling) は重要な操作の一つです。プーリングを理解するために、まず画像認識におけるCNNの処理の流れを簡単に見てみましょう。

プーリングの基本的な役割

CNNでは、入力画像に対して畳み込み層(Convolutional Layer)が特徴を抽出します。畳み込み層は、画像内のエッジ、テクスチャ、パターンなどの局所的な特徴を検出します。この処理によって「特徴マップ(Feature Map)」と呼ばれるデータが生成されます。

ここでプーリングの出番です。プーリングは特徴マップのサイズを縮小する処理で、以下のような役割があります:

  1. 次元削減: 特徴マップのサイズを小さくすることで、後続の層での計算量とパラメータ数を減らします
  2. 位置不変性の向上: 特徴の厳密な位置よりも、特徴の存在自体を重視するようになります
  3. 過学習の抑制: パラメータ数を減らすことで、モデルの汎化性能を向上させます

プーリングの種類

代表的なプーリング手法には以下のものがあります:

  • 最大プーリング(Max Pooling): 領域内の最大値を選択します。最も一般的に使われる手法で、最も強い特徴(活性化)を保持するのに適しています。
  • 平均プーリング(Average Pooling): 領域内の平均値を計算します。全体的な特徴の強さを表現するのに適しています。
  • グローバルプーリング(Global Pooling): 特徴マップ全体に対して一つの値を計算します。全結合層の前に使用されることが多いです。

プーリングの仕組み

プーリングは通常、以下のような手順で行われます:

  1. 特徴マップを小さな領域(通常2×2や3×3のウィンドウ)に分割します
  2. 各領域に対して、最大値や平均値などの集約操作を適用します
  3. 集約した値を出力の特徴マップの対応する位置に配置します

例えば、2×2の最大プーリングを適用する場合:

入力特徴マップ:       最大プーリング後:
| 1  3  5  4 |        | 7  8 |
| 7  2  6  8 |   →    | 5  9 |
| 5  1  4  2 |
| 3  5  9  1 |

4×4の特徴マップが2×2に縮小されます。各2×2の領域から最大値が選択されています。

適応型平均プーリング層とは

通常のプーリング層では、カーネルサイズ(ウィンドウの大きさ)とストライド(ウィンドウの移動量)を固定値として指定します。例えば、2×2のカーネルサイズとストライド2の最大プーリングを適用すると、特徴マップの幅と高さは半分になります。

このアプローチでは入力サイズが異なると、プーリング後の出力サイズも変わってしまいます。これは、ネットワークの後段(特に全結合層)が固定サイズの入力を期待する場合に問題となります。

そこで登場するのが 適応型平均プーリング層(Adaptive Average Pooling) です。この層の特徴は:

  1. 出力サイズを固定: 入力サイズに関わらず、常に指定したサイズ(例:7×7)の出力を生成します
  2. 自動調整: カーネルサイズとストライドを入力サイズに応じて自動的に調整します
  3. 平均値計算: 各領域内の値の平均を計算します

例えば、28×28の入力画像と32×32の入力画像のどちらに対しても、適応型平均プーリングを使えば7×7の同じ出力サイズを得ることができます。

このように、適応型平均プーリング層は、異なるサイズの入力に対しても一貫した出力を保証し、ネットワークの後段(全結合層など)との連携をスムーズにする重要な役割を果たします。

Burnにおける実装の詳細

それでは、Burnコードベースから抽出したAdaptiveAvgPool2dの実装を見ていきましょう。

Config 構造体の定義

#[derive(Config)]
pub struct AdaptiveAvgPool2dConfig {
    /// The size of the output.
    pub output_size: [usize; 2],
}

適応型平均プーリング層の設定は非常にシンプルで、出力サイズ(高さと幅)を指定するだけです。[usize; 2]は長さ2の配列で、要素の型がusize(符号なし整数)であることを示しています。

また、#[derive(Config)]アトリビュートは、Burnフレームワーク独自のマクロで、この構造体が設定として使用されることを示し、自動的に必要なメソッドが実装されます。

Module 構造体の定義

#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AdaptiveAvgPool2d {
    /// The size of the output.
    pub output_size: [usize; 2],
}

ここでは実際のモジュール(層)の構造体を定義しています。#[derive(Module, Clone, Debug)]アトリビュートによって、この構造体にBurnのModuleトレイトが自動実装されるほか、Rustの標準的なCloneとDebugトレイトも実装されています。

また、#[module(custom_display)]は、このモジュールがカスタムの表示形式を持つことを指定しています。

順伝播処理の実装

impl AdaptiveAvgPool2d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [adaptive_avg_pool2d](crate::tensor::module::adaptive_avg_pool2d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, channels, height_in, width_in]`
    /// - output: `[batch_size, channels, height_out, width_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        adaptive_avg_pool2d(input, self.output_size)
    }
}

ここが実際にモジュールの核心部分です。forwardメソッドは、入力テンソルを受け取り、適応型平均プーリング処理を適用して出力テンソルを返します。

特筆すべき点は以下の通りです:

  1. ジェネリクスを使用して、異なるバックエンド(GPU、CPU、特定のハードウェアアクセラレータなど)に対応できるようにしています(<B: Backend>)。
  2. 入力と出力の両方が4次元テンソル(Tensor<B, 4>)であることが型で明示されています。
  3. 実際の処理は、adaptive_avg_pool2d関数に委譲されています。
  4. ドキュメンテーションコメントで、入力と出力のテンソル形状が明確に記載されています。

テンソル操作の実装

実際のプーリング処理を行うadaptive_avg_pool2d関数の実装も見てみます:

/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
        x.primitive.tensor(),
        output_size,
    )))
}

この関数は、入力テンソルxと目標出力サイズoutput_sizeを受け取り、以下の処理を行います:

  1. 入力テンソルから基本データ(primitive tensor)を取得します(x.primitive.tensor())。
  2. バックエンドの実装(B::adaptive_avg_pool2d)を呼び出して実際の計算を行います。
  3. 計算結果を新しいテンソルとして返します(Tensor::new(TensorPrimitive::Float(...)))。

この実装から、Burnがどのように抽象化を行っているかが見て取れます。ユーザーに公開される高レベルのAPIは、実際のハードウェア(CPU/GPU/TPUなど)に依存する計算部分と明確に分離されています。そしてBackendトレイトを実装するバックエンドごとに、adaptive_avg_pool2dメソッドの最適な実装が提供されるわけです。

バックエンド実装の詳細

ここまで、適応型平均プーリング層の高レベルな実装を見てきましたが、実際にはどのようにハードウェア上で計算が行われるのでしょうか?Burnフレームワークでは、異なるハードウェアバックエンド(CPU、GPU、TPUなど)ごとに最適化された実装が提供されています。

CUDAバックエンド(GPU向け)の実装例を見てみましょう。これは、Burnフレームワークがどのように効率的な計算を行っているかを理解する助けになります:

use crate::{
    CubeRuntime,
    element::CubeElement,
    kernel::into_contiguous,
    ops::{max_line_size, numeric::empty_device, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
    tensor::CubeTensor,
};
use burn_tensor::Shape;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

#[cube(launch)]
fn adaptive_avg_pool2d_direct<E: Numeric>(input: &Tensor<Line<E>>, output: &mut Tensor<Line<E>>) {
    if ABSOLUTE_POS >= output.len() {
        terminate!();
    }
    let (out_h, out_w, channels) = (output.shape(1), output.shape(2), output.shape(3));
    let channel_lines = channels / output.line_size();
    let (in_stride_b, in_stride_h, in_stride_w, in_stride_c) = (
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
    );
    let (in_h, in_w) = (input.shape(1), input.shape(2));
    let c = (ABSOLUTE_POS % channel_lines) * input.line_size();
    let pos = ABSOLUTE_POS / channel_lines;
    let ow = pos % out_w;
    let pos = pos / out_w;
    let oh = pos % out_h;
    let b = pos / out_h;
    let ih_start = start_index(oh, out_h, in_h);
    let ih_end = end_index(oh, out_h, in_h);
    let iw_start = start_index(ow, out_w, in_w);
    let iw_end = end_index(ow, out_w, in_w);
    let mut sum = Line::empty(input.line_size()).fill(E::from_int(0));
    let index_input_0 = b * in_stride_b;
    let index_input_1 = c * in_stride_c;
    for ih in ih_start..ih_end {
        let index_input_2 = ih * in_stride_h;
        for iw in iw_start..iw_end {
            let index_input_3 = iw * in_stride_w;
            let index_input = index_input_0 + index_input_1 + index_input_2 + index_input_3;
            sum += input[index_input / input.line_size()];
        }
    }
    let num_ih = ih_end - ih_start;
    let num_iw = iw_end - iw_start;
    output[ABSOLUTE_POS] = sum / Line::cast_from(num_ih * num_iw);
}

#[cube]
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    (output_size_index * input_size) / output_size
}

#[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
#[allow(clippy::manual_div_ceil)]
#[cube]
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    let index = (output_size_index + 1) * input_size;
    let index = (index + output_size - 1) / output_size;
    if input_size < index {
        input_size
    } else {
        index
    }
}

pub(crate) fn adaptive_avg_pool2d<R: CubeRuntime, E: CubeElement>(
    input: CubeTensor<R>,
    output_size: [usize; 2],
) -> CubeTensor<R> {
    let [batch_size, channels, *, *] = input.shape.dims();
    let input = into_contiguous(permute_nchw_to_nhwc(input));
    let line_size = max_line_size(&input);
    let output_shape = Shape::new([batch_size, output_size[0], output_size[1], channels]);
    let num_elems: usize = output_shape.num_elements();
    let output = empty_device::<R, E>(input.client.clone(), input.device.clone(), output_shape);
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
    adaptive_avg_pool2d_direct::launch::<E, R>(
        &input.client,
        cube_count,
        cube_dim,
        input.as_tensor_arg::<E>(line_size),
        output.as_tensor_arg::<E>(line_size),
    );
    permute_nhwc_to_nchw(output)
}

このコードは、GPUで効率的に適応型平均プーリングを実行するためのCUDAカーネル実装です。コードの主要部分を解説します:

メインの関数(adaptive_avg_pool2d)

pub(crate) fn adaptive_avg_pool2d<R: CubeRuntime, E: CubeElement>(
    input: CubeTensor<R>,
    output_size: [usize; 2],
) -> CubeTensor<R> {
    // ...略...
}

この関数は、入力テンソルと目標出力サイズを受け取り、適応型平均プーリングを適用したテンソルを返します。処理の流れは以下の通りです:

  1. 入力テンソルのレイアウトを変換(NCHW形式からNHWC形式へ)
  2. 出力テンソルの形状を計算
  3. GPUで計算するためのパラメータを設定
  4. GPUカーネルを起動
  5. 結果を元の形式(NCHW)に戻して返す

GPUカーネル(adaptive_avg_pool2d_direct)

#[cube(launch)]
fn adaptive_avg_pool2d_direct<E: Numeric>(input: &Tensor<Line<E>>, output: &mut Tensor<Line<E>>) {
    // ...略...
}

このカーネルはGPU上で並列実行され、各スレッドが出力テンソルの一部の計算を担当します。:

  1. 出力位置(ABSOLUTE_POS)から対応する入力領域のインデックスを計算
  2. start_indexend_index関数を使って、各出力要素に対応する入力領域を決定
  3. 入力領域内の全ての値を合計し、領域のサイズで割って平均値を計算
  4. 計算結果を出力テンソルに格納

インデックス計算関数

#[cube]
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    (output_size_index * input_size) / output_size
}

#[cube]
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    // ...略...
}

これらの関数は、出力テンソルの各位置に対応する入力テンソルの領域(開始インデックスと終了インデックス)を計算します。これが、適応型平均プーリングの核心部分です。

この実装から、適応型平均プーリングがどのように計算されるかが具体的に分かります:

  1. 出力テンソルの各位置(h, w)に対して、入力テンソルの対応する領域を計算
  2. その領域内の値の平均を計算
  3. 平均値を出力テンソルの対応する位置に格納

Burnフレームワークは低レベルのGPU計算も効率的に行えるよう設計されており、Rustの安全性とGPUの高速性を両立させています。

適応型平均プーリングの活用シーン

適応型平均プーリングは、以下のようなシナリオで用いられます:

  1. 異なるサイズの入力画像を扱うモデル:画像分類や物体検出などのタスクで、様々なサイズの画像を処理する場合
  2. 転移学習:事前学習済みモデルの入力サイズ要件に合わせる必要がある場合
  3. グローバル特徴表現:画像全体を1つの特徴ベクトルに集約する場合(出力サイズを[1, 1]に設定)

まとめ

本記事では、Burnフレームワークにおける適応型平均プーリング層(AdaptiveAvgPool2d)について、基本概念から内部実装まで詳しく解説しました。学んだポイントをまとめると:

  1. プーリングの基本: プーリングとは特徴マップのサイズを縮小し、位置不変性を高める操作です
  2. 適応型プーリングの特徴: 入力サイズに関わらず出力サイズを固定できる便利な機能です
  3. Burnでの実装: Rustの型システムとジェネリクスを活用した安全で効率的な実装がなされています
  4. バックエンド抽象化: Backendトレイトを用いて、様々なハードウェア向けの最適化を可能にしています
  5. 実用的なメリット: 異なるサイズの入力画像を処理するモデルの開発が容易になります

Burnフレームワークは、Rustならではの安全性と性能を深層学習に取り入れる意欲的なプロジェクトですが、このように低レベルの実装まで容易にソースコードを辿ることができ、深層学習フレームワークの内部動作に関する洞察を得られます。

さらにRustやBurnフレームワークについて学びたい方は、Burnの公式ドキュメントRustの公式ドキュメントを参照することをお勧めします!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?