1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

はじめに

近年の AI コンパイラでは、行列積 (MatMul) は最も重要な演算の一つです。
MLIR では MatMul を linalg.matmul として表現できますが、最終的には SIMD 命令や Tensor Core を利用できる形へ変換する必要があります。
本記事では、linalg.matmul -> vector の変換を調べてみます。

Linalg Dialect での MatMul の表現

MatMul とは

行列積 (MatMul) は以下の演算です。

\begin{align}
C_{ij} = \sum_k A_{ik}B_{kj}
\end{align}

例えば、A: 4 x 8, B: 8 x 16, C: 4 x 16 の場合、C = A x B となります。

Linalg Dialect での表現

まずは、MatMul を Linalg Dialect で記述します。

module {
  func.func @matmul(
      %A: memref<4x8xf32>,
      %B: memref<8x16xf32>,
      %C: memref<4x16xf32>) {

    linalg.matmul
      ins(%A, %B : memref<4x8xf32>, memref<8x16xf32>)
      outs(%C : memref<4x16xf32>)

    return
  }
}

Linalg では、iterator_types と indexing_map により演算の意味を表現しています。
iterator_types とは、ループネストの反復処理がどのような性質を持つかを定義する属性です。

タイプ 性質 用途
parallel 並列ループ 各反復を独立して(任意の順序で)実行できるループ
reduction リダクション 出力テンソル/バッファに対して値を累積・集約する
reduction_init 初期化付きリダクション 累積のベースとなる初期値を持つリダクション

MatMul の場合は以下となります。

  • m : parallel
  • n : parallel
  • k : reduction

変換

LLVM では、vectorizeAsLinalgContractionlinalg.matmulvector.contract へ変換するために使われます。

この関数で行なっていることは以下です。

  1. linalgOp が contraction named op か確認する
  2. reduction の combiner を調べる
  3. 入力・出力を vector.transfer_read する
  4. linalg の iterator_types を vector の iterator_types に変換する
  5. linalg の indexing_maps を使って vector.contract を生成する
  6. 結果を vector.transfer_write する

mlir-opt (Homebrew LLVM version 22.1.0) で変換をしてみます。

  • 入力(MLIR)
matmul.mlir
module {
  func.func @matmul(
      %A: tensor<4x8xf32>,
      %B: tensor<8x16xf32>,
      %C: tensor<4x16xf32>)
      -> tensor<4x16xf32> {
    %0 = linalg.matmul
      ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
      outs(%C : tensor<4x16xf32>)
      -> tensor<4x16xf32>
    return %0 : tensor<4x16xf32>
  }
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match
      ops{["linalg.matmul"]} in %root
      : (!transform.any_op) -> !transform.any_op

    transform.structured.vectorize %matmul
      vector_sizes [4, 16, 8]
      : !transform.any_op

    transform.yield
  }
}
  • 変換コマンド
$ mlir-opt matmul.mlir  --pass-pipeline='builtin.module(transform-interpreter)'
  • 出力(MLIR)
#map = affine_map<(d0, d1) -> (d0, 0, d1)>
#map1 = affine_map<(d0, d1) -> (0, d1, d0)>
module {
  module {
    func.func @matmul(%arg0: tensor<4x8xf32>, %arg1: tensor<8x16xf32>, %arg2: tensor<4x16xf32>) -> tensor<4x16xf32> {
      %c4 = arith.constant 4 : index
      %c16 = arith.constant 16 : index
      %c8 = arith.constant 8 : index
      %c0 = arith.constant 0 : index
      %0 = ub.poison : f32
      %1 = vector.transfer_read %arg0[%c0, %c0], %0 {permutation_map = #map} : tensor<4x8xf32>, vector<4x16x8xf32>
      %2 = ub.poison : f32
      %3 = vector.transfer_read %arg1[%c0, %c0], %2 {permutation_map = #map1} : tensor<8x16xf32>, vector<4x16x8xf32>
      %4 = ub.poison : f32
      %5 = vector.transfer_read %arg2[%c0, %c0], %4 : tensor<4x16xf32>, vector<4x16xf32>
      %6 = arith.mulf %1, %3 : vector<4x16x8xf32>
      %7 = vector.multi_reduction <add>, %6, %5 [2] : vector<4x16x8xf32> to vector<4x16xf32>
      %c0_0 = arith.constant 0 : index
      %8 = vector.transfer_write %7, %arg2[%c0_0, %c0_0] : vector<4x16xf32>, tensor<4x16xf32>
      return %8 : tensor<4x16xf32>
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.structured.vectorize %0 vector_sizes [4, 16, 8] : !transform.any_op
      transform.yield 
    }
  }
}

LLVM 22.1.0 の transform.structured.vectorize では、linalg.matmulvector.contract ではなく、broadcast された 3D vector 上の arith.mulfvector.multi_reduction に変換されました。
これは数式としては

\begin{align}
C_{ij} += A_{ik}B_{kj}
\end{align}

と等価であり、reduction 次元 k が vector.multi_reduction として表現されています。

結果的には、以下のように変換されています。

linalg.matmul
  ↓
vector.transfer_read
arith.mulf
vector.multi_reduction <add>
vector.transfer_write
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?