LoginSignup
1
1

More than 1 year has passed since last update.

PyTorch Mobileでサポートされていない機能をTorchScriptで実現する

Last updated at Posted at 2023-03-14

はじめに

PyTorch Mobileは推論用途で提供されており、Tensorの計算等の機能は提供されていません
一部機能について、補助的にTorchScriptを使って実現できたので、メモ的に残します。

また、本記事はAndroidを念頭に書いていますが、iOSについても適用できる内容だと思います。

目次

実装

tensor同士の計算

コミュニティで質問されていたTensorの計算機能の例として、2つのTensorを足して返却しています。

add_tensors.py
# module which returns sum of the two input tensors
class AddTensors(torch.nn.Module):
    def __init__(self):
        super(AddTensors, self).__init__()

    def forward(self, tensor1, tensor2):
        return torch.add(tensor1, tensor2)

# save module to ptl file
addTensors = AddTensors()
scripted_module = torch.jit.script(addTensors)
optimized_scripted_module = optimize_for_mobile(scripted_module)
optimized_scripted_module._save_for_lite_interpreter("add_tensors.ptl")

TorchScriptで複数inputがある際の例はあまり見当たらなかったのですが、Python、Javaともにそのままカンマ区切りでinputを渡すと使えます。
Python: addTensors(tensor1, tensor2)
Java: addTensorsModule.forward(tensor1, tensor2);

モデルの中間layerのoutputを取得

Python版PyTorchのregister_module_forward_hookを使うと、モデルの中間layerのoutputを取得することができます。

以下の例では、ResNet18のlayer1の出力を取り出すModuleを作成しています。
layer2以降も同様にModule化し、layer1の出力を渡すことで、ResNet18全体を通した出力を得ることもできます。

intermediate_layers_output.py
from torchvision.models import resnet18

# load pretrained resnet18 model
model = resnet18(pretrained=True, progress=True)
model.eval()

# define layers to use
layers = {
    "conv1":    model.conv1,
    "bn1":    model.bn1,
    "relu":    model.relu,
    "maxpool":    model.maxpool,
    "layer1":    model.layer1
}

# add layers to new torch.nn.Sequential
layer_sequential = torch.nn.Sequential()
for key in layers.keys():
    layer = layers[key]
    layer_sequential.add_module(key, layer)

# save module to ptl file
scripted_module = torch.jit.script(layer_sequential)
optimized_scripted_module = optimize_for_mobile(scripted_module)
optimized_scripted_module._save_for_lite_interpreter('intermediate_layers_output.ptl')

pickleファイルに保存したtensorの読み込み

Androidではpickleファイルに保存したTensorを直接ロードすることはできませんが、Python上であらかじめpickleファイルをロードしておき、Moduleのpropertyとして渡しておくことで、Androidでも読み出せるようになります。

(本項目はpickleからロードしたTensor以外にも適用できますが、過去の自分を含めてpickleファイルのロード方法としての需要がありそうなので、この見出しとしました)

load_pickle.py
# module which returns the value loaded from the pickle file
class LoadTensor(torch.nn.Module):
    loaded_value = None

    def __init__(self, loaded_value):
        super(LoadTensor, self).__init__()
        self.loaded_value = loaded_value

    def forward(self):
        return self.loaded_value

# load tensor from pickle
with open('tensor.pkl', 'rb') as f:
    loaded_tensor = pickle.load(f)

# save module to ptl file
loadTensor = LoadTensor(loaded_tensor)
scripted_module = torch.jit.script(loadTensor)
optimized_scripted_module = optimize_for_mobile(scripted_module)
optimized_scripted_module._save_for_lite_interpreter("load_pickle.ptl")

おわりに

TorchScriptではscipyやnumpyといった外部ライブラリを使えないという制約はありますが、Mobile版で提供されていない機能についてもPython版のコードを元に動作させることができる非常に強力な機能だと思います。
今回は割愛しましたが、TorchScriptを使えば、permuteなどもfor文を回さずに実現することができ、うまく使うことで、モバイルへの機械学習周辺の機能を組み込むことが容易になります。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1