こんにちは。
今回はテンソルコンパイラのIREEについて、自己学習を兼ねて書いてみます。IREEはGoogleが開発したOSSのテンソルコンパイラです。LLVMを基盤としており、中間言語としてMLIRを活用していることで知られています。
IREEの概要
IREE1は機械学習モデルの実行用のコンパイラおよびランタイムフレームワークです。
機械学習の推論は、大規模なデプロイメントから低電力デバイスまで様々なターゲットがユースケースとして存在し、モデルを扱うフレームワークも多様化しています。
その過程で、モデルを複数段階の中間表現に分けて、特定のハードウェア向けに最適化するコンパイラベースのアプローチが出てきました。IREEはその一つであり、MLIRインフラストラクチャに基づく、エンドツーエンドのコンパイラ・ランタイムフレームワークを持っています。(他の例として、Apatch TVM2は複数の中間表現で高レベルのグラフ再構成と低レベルの命令最適化を行うようです。)
MLIR3は、テンソルコンパイラや言語コンパイラで着目されている新しい中間表現構造です。IREEはこれを活用することで、中間表現(IR)を多段階(Multi-Level)で表現し、機械学習モデルを段階的にハードウェアレベルの言語(アセンブリ)に変換していくことができます。これを実現するための MLIR の主な手段はDialect(方言)です。Dialectはプログラムを表現できる操作と型のコレクションです。MLIRではこのDialectを使用して、異なる抽象化レベルで中間表現を補完できます。
IREEで扱っているMLIR Dialectには例えば以下があります。
- TOSA/MHLO: 機械学習プログラムをテンソル演算の表現
- Linalg: ネストされたループ計算の表現
- Vector: 仮想ベクトル操作などの表現
- LLVM: 他のDialectをLLVM IRの対応する操作と型を紐付けるための表現
コンパイルとランタイム実行
PyTorchのサンプルプログラムがあったので、これを使ってIREEでAOTコンパイルを見てみたいと思います。ここでのターゲットはx86-64 CPUです。
- PyTorchとIREEの準備
$ pip install torch --index-url https://download.pytorch.org/whl/test/cpu
$ pip install iree-turbine
- サンプルプログラム4
forwardでa @ b + c
の演算をしているプログラムです
import torch
import iree.turbine.aot as aot
torch.manual_seed(0)
# Define the `nn.Module` to export.
class LinearModule(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
self.bias = torch.nn.Parameter(torch.randn(out_features))
def forward(self, input):
return (input @ self.weight) + self.bias
linear_module = LinearModule(4, 3)
# Export the program using the simple API.
example_arg = torch.randn(4)
export_output = aot.export(linear_module, example_arg)
# Output MLIR then continue from native tools
mlir_file_path = "./linear_module_pytorch.mlirbc"
vmfb_file_path = "./linear_module_pytorch.llvmcpu.vmfb"
print("Exported .mlir:")
export_output.print_readable()
export_output.save_mlir(mlir_file_path)
print("compiling and running...")
!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host --iree-llvmcpu-target-triple=x86_64-pc-linux-elf {mlir_file_path} -o {vmfb_file_path}
!iree-run-module --module={vmfb_file_path} --device=local-task --input="4xf32=[1.0, 2.0, 3.0, 4.0]"
- PyTorchのコンパイル、ランタイム実行のフロー5
上の図に沿ってコードの変化を見てみます。
まず、PyTorchのモデルはiree-turbine
というフロントエンドのコマンドでMLIRコードに変換されます。このMLIRコードはtorch dialectで書かれています。以下がサンプルプログラムのモジュール部分です。
module @module {
func.func @main(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> attributes {torch.assume_strict_symbolic_shapes} {
%0 = torch.vtensor.literal(dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>) : !torch.vtensor<[4,3],f32>
%1 = torch.aten.matmul %arg0, %0 : !torch.vtensor<[4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3],f32>
%2 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>
%int1 = torch.constant.int 1
%3 = torch.aten.add.Tensor %1, %2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
return %3 : !torch.vtensor<[3],f32>
}
}
torch dialectで書かれたMLIRはiree-compile
に読み込ませることができます。
iree-compile
はiree-run-module
で実行可能なコードを生成します。デフォルトではIREEのVMのフォーマット(VMFB)になっています。IREEではVMによる実行がメインのようです6。VMFBのアセンブリをC言語に変換して出力することもできます。
最後のiree-run-module
はVM上でVMFBのバイナリを動作させています。
iree-compile
のパイプラインではMLIRが活用されているのでもう少し見てみます。
入力のtorch dialectはiree-compile
のインポーターでlinalg dialectなどに変換されます。
以下のコードはインポーターの後のMLIRコードを取り出したものですが、torch.aten.matmul
, torch.aten.add
がlinalg.vecmat
とlinalg.generic
に変換されています。因みに、linalg.generic
のテンソル演算は要素単位の演算に比べて命令を融合しやすいので、最適化の機会が増えるそうです。
1 module @module {
2 util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
3 %cst = arith.constant dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>
4 %cst_0 = arith.constant dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>
5 %cst_1 = arith.constant 0.000000e+00 : f32
6 %0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<4xf32>
7 %1 = tensor.empty() : tensor<3xf32>
8 %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<3xf32>) -> tensor<3xf32>
9 %3 = linalg.vecmat ins(%0, %cst_0 : tensor<4xf32>, tensor<4x3xf32>) outs(%2 : tensor<3xf32>) -> tensor<3xf32>
10 %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%3, %cst : tensor<3xf32>, tensor<3xf32>) outs(%1 : tensor<3xf32>) {
11 ^bb0(%in: f32, %in_2: f32, %out: f32):
12 %7 = arith.addf %in, %in_2 : f32
13 linalg.yield %7 : f32
14 } -> tensor<3xf32>
15 %5 = hal.tensor.barrier join(%4 : tensor<3xf32>) => %arg2 : !hal.fence
16 %6 = hal.tensor.export %5 : tensor<3xf32> -> !hal.buffer_view
17 util.return %6 : !hal.buffer_view
18 }
:
また、iree-compile
はホストのデータフローとテンソル計算を分けるためのflow dialectやワークロードのディスパッチを定義してスケジューリングするためのstream dialect、非同期実行の命令を発行するためのハードウェア抽象化表現のhal dialectという方言を内部で扱います。このhal dialectではターゲットのアーキ情報を保持しているようです(下のコード2行目,26行目)。
以下はiree-compile
の内の後段のMLIRコードです。1~24行目と25~51行目でhal.executable
によりディスパッチ領域が区別されており、dispatch_1ではlinalg.mmt4d
、dispatch_2ではlinalg.generic
があるので、これらの命令が非同期実行される別々のワークグループに入ったことが分かります。これらはflow dialect, stream dialect、そしてhal dialectが表現している情報(セマンティクス)からコンパイラがデータ依存関係やスケジューリングを順序立てて判断したものになっています。IREEではこれらのDialectが最適化に重要な役割を果たしているようです。
1 hal.executable private @main$async_dispatch_0 {
2 hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "skylake", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,-amx-fp8,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,+sse4.2,-tsxldtrk,-sm3,-ptwrite,-widekl,-movrs,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,+rtm,+adx,+avx2,-hreset,-movdiri,-serialize,-vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-avx10.2-256,-gfni,-avxvnniint16,-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,-nf,-amx-tf32,-amx-avx512,+fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,-sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-amx-transpose,-avx10.2-512,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,-clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,-amx-movrs,-rdpid,-fma4,-avx512vbmi,-shstk,-vaes,-waitpkg,+sgx,+fxsr,-avx512dq,-sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
3 hal.executable.export public @main$async_dispatch_0_mmt4d_1x1x4x1x8x1_f32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
4 ^bb0(%arg0: !hal.device):
5 %x, %y, %z = flow.dispatch.workgroup_count_from_slice
6 hal.return %x, %y, %z : index, index, index
7 }
8 builtin.module {
9 func.func @main$async_dispatch_0_mmt4d_1x1x4x1x8x1_f32() {
10 %cst = arith.constant dense<[[[[1.54099607], [-0.293428898], [-2.17878938], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]], [[0.568431258], [-1.08452237], [-1.39859545], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]], [[0.403346837], [0.838026344], [-0.719257593], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]], [[-0.403343529], [-0.596635341], [0.182036489], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]]]]> : tensor<1x4x8x1xf32>
11 %cst_0 = arith.constant 0.000000e+00 : f32
12 %c0 = arith.constant 0 : index
13 %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x4x1x1xf32>>
14 %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x1x1x8xf32>>
15 %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 4, 1, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4x1x1xf32>> -> tensor<1x4x1x1xf32>
16 %3 = tensor.empty() : tensor<1x1x1x8xf32>
17 %4 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1x1x1x8xf32>) -> tensor<1x1x1x8xf32>
18 %5 = linalg.mmt4d ins(%2, %cst : tensor<1x4x1x1xf32>, tensor<1x4x8x1xf32>) outs(%4 : tensor<1x1x1x8xf32>) -> tensor<1x1x1x8xf32>
19 flow.dispatch.tensor.store %5, %1, offsets = [0, 0, 0, 0], sizes = [1, 1, 1, 8], strides = [1, 1, 1, 1] : tensor<1x1x1x8xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x1x1x8xf32>>
20 return
21 }
22 }
23 }
24 }
25 hal.executable private @main$async_dispatch_1 {
26 hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "skylake", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,-amx-fp8,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,+sse4.2,-tsxldtrk,-sm3,-ptwrite,-widekl,-movrs,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,+rtm,+adx,+avx2,-hreset,-movdiri,-serialize,-vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-avx10.2-256,-gfni,-avxvnniint16,-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,-nf,-amx-tf32,-amx-avx512,+fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,-sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-amx-transpose,-avx10.2-512,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,-clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,-amx-movrs,-rdpid,-fma4,-avx512vbmi,-shstk,-vaes,-waitpkg,+sgx,+fxsr,-avx512dq,-sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
27 hal.executable.export public @main$async_dispatch_1_unpack_elementwise_3_f32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
28 ^bb0(%arg0: !hal.device):
29 %x, %y, %z = flow.dispatch.workgroup_count_from_slice
30 hal.return %x, %y, %z : index, index, index
31 }
32 builtin.module {
33 func.func @main$async_dispatch_1_unpack_elementwise_3_f32() {
34 %cst = arith.constant dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>
35 %c0 = arith.constant 0 : index
36 %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x8xf32>>
37 %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<3xf32>>
38 %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x8xf32>> -> tensor<1x8xf32>
39 %3 = tensor.empty() : tensor<3xf32>
40 %unpack = tensor.unpack %2 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [8] into %3 : tensor<1x8xf32> -> tensor<3xf32>
41 %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%unpack, %cst : tensor<3xf32>, tensor<3xf32>) outs(%3 : tensor<3xf32>) {
42 ^bb0(%in: f32, %in_0: f32, %out: f32):
43 %5 = arith.addf %in, %in_0 : f32
44 linalg.yield %5 : f32
45 } -> tensor<3xf32>
46 flow.dispatch.tensor.store %4, %1, offsets = [0], sizes = [3], strides = [1] : tensor<3xf32> -> !flow.dispatch.tensor<writeonly:tensor<3xf32>>
47 return
48 }
49 }
50 }
51 }
この記事は全体的にIREEの論文7などを参考に書きました。
読んでいただきありがとうございます。勉強中なので詳しいことはあまり書けませんでしたが、参考になれば幸いです。良い年末をお過ごしください。
-
https://github.com/iree-org/iree
IREE: Intermediate Representation Execution Environment,「イーリー」と発音。同音の"eerie"は英語で「不気味な」という意味で、ロゴアイコンはゴースト👻になっている ↩ -
https://iree.dev/guides/ml-frameworks/pytorch/#ahead-of-time-aot-export ↩
-
https://iree.dev/developers/design-docs/vm/#lowering-from-the-vm-to-c ↩