15
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのモデルをiOSで利用する - LibTorchをiOSプロジェクトに組み込む手順

Posted at

PyTorchで作成した**.ptモデルをiOSで直接(Core MLモデルに変換せずに)使う**方法。

MetalやNeural Engineに最適化されることが期待されるので基本的にはCore MLに変換してから使ったほうが良いのだが、

  • PyTorchモデルをCore ML Toolsで変換するにはいったんONNXフォーマットに変換するといった煩雑さがある1
  • PyTorchモデルをAndroidと共通で使いたい

こういった場合にそのまま直接組み込むという線も出てくる。

PyTorchモデルを扱うC++ライブラリがCocoaPods対応してるので、自分のアプリへの導入はめちゃくちゃ簡単。

以下その手順。

1. LibTorchのインストール

Podfileに以下を追記して、

pod 'LibTorch', '~>1.5.0'

pod installを実行。

2. プロジェクト設定の変更

3. PyTorchモデルを追加する

.ptファイルをXcodeプロジェクトに追加する。

4. ブリッジ実装を書く

LibTorchとモデルを使って推論処理を行うラッパーをObjective-C++で実装する。ここはモデルによって実装が変わってくる。

たとえば公式サンプルのHelloWorldに入っているTorchModule.h/mmの実装はこんな感じ:

TorchModule.h
#import <Foundation/Foundation.h>

NS_ASSUME_NONNULL_BEGIN

@interface TorchModule : NSObject

- (nullable instancetype)initWithFileAtPath:(NSString*)filePath
    NS_SWIFT_NAME(init(fileAtPath:))NS_DESIGNATED_INITIALIZER;
+ (instancetype)new NS_UNAVAILABLE;
- (instancetype)init NS_UNAVAILABLE;
- (nullable NSArray<NSNumber*>*)predictImage:(void*)imageBuffer NS_SWIFT_NAME(predict(image:));

@end

NS_ASSUME_NONNULL_END
TorchModule.mm
#import "TorchModule.h"
#import <LibTorch/LibTorch.h>

@implementation TorchModule {
 @protected
  torch::jit::script::Module _impl;
}

- (nullable instancetype)initWithFileAtPath:(NSString*)filePath {
  self = [super init];
  if (self) {
    try {
      _impl = torch::jit::load(filePath.UTF8String);
      _impl.eval();
    } catch (const std::exception& exception) {
      NSLog(@"%s", exception.what());
      return nil;
    }
  }
  return self;
}

- (NSArray<NSNumber*>*)predictImage:(void*)imageBuffer {
  try {
    at::Tensor tensor = torch::from_blob(imageBuffer, {1, 3, 224, 224}, at::kFloat);
    torch::autograd::AutoGradMode guard(false);
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    auto outputTensor = _impl.forward({tensor}).toTensor();
    float* floatBuffer = outputTensor.data_ptr<float>();
    if (!floatBuffer) {
      return nil;
    }
    NSMutableArray* results = [[NSMutableArray alloc] init];
    for (int i = 0; i < 1000; i++) {
      [results addObject:@(floatBuffer[i])];
    }
    return [results copy];
  } catch (const std::exception& exception) {
    NSLog(@"%s", exception.what());
  }
  return nil;
}

@end

5. Swiftから呼ぶ

4で実装したクラスをSwiftから使って推論処理を行う。

たとえばHelloWorldサンプルでは次のようにモデルファイル(model.pt)のパスを渡してTorchModuleクラスを初期化している:

private lazy var module: TorchModule = {
    if let filePath = Bundle.main.path(forResource: "model", ofType: "pt"),
        let module = TorchModule(fileAtPath: filePath) {
        return module
    } else {
        fatalError("Can't find the model file!")
    }
}()

推論処理の実行:

let resizedImage = image.resized(to: CGSize(width: 224, height: 224))
guard var pixelBuffer = resizedImage.normalized() else {
    return
}
guard let outputs = module.predict(image: UnsafeMutableRawPointer(&pixelBuffer)) else { return }

ちなみに・・・このサンプルの実装でいうとリサイズやノーマライズといったピクセルデータにアクセスする(=GPU向き)前処理をCPUで行っていて、やっぱり基本的には(PyTorch Mobile/LibTorchを使うのではなく)Core MLを利用して前処理〜推論処理まで一貫してGPU(Metal)およびNeural Engineで行うようにしたほうが良いように思う。

Neural Engineについては以下の記事を参照:

  1. coremltools 4.0から直接Core MLモデルに変換できるようになったが、まだベータなのと、ちょっと使ってみた感じでは生成されるモデルがiOS 14以上でしか使用できない

15
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?