LoginSignup
1
0

macOSでtch-rsを使ってみた

Last updated at Posted at 2024-02-25

PyTorch の Rust バインディングである tch-rs を M2 の Mac で利用してみました。

はじめに

tch-rs は libtorch が必要となっており、基本的には以下の 2通りの使い方があるようです。

今回は両方試してみました。

使い方
(a) libtorch を手動インストール
(b) PyTorch の libtorch 利用

サンプルコード

ここでは、tch-rs を用いた以下のサンプルコードをビルド・実行する事で動作確認します。

sample.safetensors ファイルから重みとバイアスをロードして単純な計算を実施するようになっています。
また、実行時引数で CPU と MPS(GPU)を選べるようにしています。

src/main.rs
use tch::{
    nn::{self, Module},
    Device, Tensor,
};

use std::env;

#[derive(Debug)]
struct Net {
    fc1: nn::Linear,
    fc2: nn::Linear,
}

impl Net {
    fn new(vs: &nn::Path) -> Self {
        let fc1 = nn::linear(vs / "fc1", 2, 3, Default::default());
        let fc2 = nn::linear(vs / "fc2", 3, 1, Default::default());

        Self { fc1, fc2 }
    }
}

impl nn::Module for Net {
    fn forward(&self, xs: &Tensor) -> Tensor {
        xs.apply(&self.fc1).relu().apply(&self.fc2)
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = if let Some(_) = env::args().skip(1).next() {
        Device::Mps
    } else {
        Device::Cpu
    };

    let mut vs = nn::VarStore::new(device);
    let net = Net::new(&vs.root());

    vs.load("sample.safetensors")?;

    let x = Tensor::from_slice(&[1.0f32, 2.0]).to(device).view((1, 2));

    let y = net.forward(&x);

    y.print();

    Ok(())
}
Cargo.toml
[package]
name = "sample"
version = "0.1.0"
edition = "2021"

[dependencies]
tch = "0.15"

なお、sample.safetensors ファイルは次の Python スクリプトで作成しました。

sample.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

net = Net()

net.fc1.weight = nn.Parameter(torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
net.fc1.bias = nn.Parameter(torch.tensor([0.7, 0.8, 0.9]))

net.fc2.weight = nn.Parameter(torch.tensor([[0.4, 0.3, 0.2]]))
net.fc2.bias = nn.Parameter(torch.tensor([0.1]))

print(net.state_dict())

x = torch.tensor([[1.0, 2.0]])
y = net(x)

print(y)

save_file(net.state_dict(), 'sample.safetensors')
sample.py 実行結果
$ python sample.py
OrderedDict({'fc1.weight': tensor([[0.1000, 0.2000],
        [0.3000, 0.4000],
        [0.5000, 0.6000]]), 'fc1.bias': tensor([0.7000, 0.8000, 0.9000]), 'fc2.weight': tensor([[0.4000, 0.3000, 0.2000]]), 'fc2.bias': tensor([0.1000])})
tensor([[1.6700]], grad_fn=<AddmmBackward0>)

(a) libtorch を手動インストール

https://pytorch.org/get-started/locally/ から libtorch-macos-arm64 をダウンロードして解凍するだけでした。

libtorch インストール
$ curl -O https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.2.1.zip
$ unzip libtorch-macos-arm64-2.2.1.zip

ただし、libomp を参照できないと実行に失敗するため Homebrew 等でインストールしておきます。

libomp インストール
$ brew install libomp

あとは環境変数を設定するだけです。

なお、今回使用した tch-rs 0.15.0 は libtorch 2.2.1 に対応しておらず、tch-rs 処理内のバージョンチェックに失敗してそのままではビルドできませんでした。
バージョンチェックを回避して正常にビルドするには LIBTORCH_BYPASS_VERSION_CHECK 環境変数を有効化する必要がありました。

環境変数の設定
$ export LIBTORCH_BYPASS_VERSION_CHECK=1

$ export LIBTORCH=$(pwd)/libtorch
$ export DYLD_LIBRARY_PATH=$LIBTORCH/lib:/opt/homebrew/opt/libomp/lib:$DYLD_LIBRARY_PATH

これで問題なくビルド・実行できました。

実行1
$ cargo run
...
 1.6700
[ CPUFloatType{1,1} ]
実行2
$ cargo run 1
...
 1.6700
[ MPSFloatType{1,1} ]

(b) PyTorch の libtorch 利用

ここでは、Homebrew でインストールした Anaconda で環境(tch1)を作成して PyTorch をインストールする事にします。

Anaconda 環境の作成
$ conda create --name tch1 python=3.12

PyTorch はとりあえず pip でインストールしました。
インストール先は Anaconda の envs/tch1/lib/python3.12/site-packages/torch になります。

PyTorchインストール
$ conda run -n tch1 --no-capture-output pip install torch

あとは環境変数を設定するだけです。

PyTorch に含まれる libtorch を tch-rs から使うには LIBTORCH_USE_PYTORCH 環境変数を有効化する必要があります。

また、python コマンドを呼び出せるようにする必要があるため、PATH 環境変数に Anaconda の envs/tch1/bin を追加しています。

環境変数の設定
$ export LIBTORCH_BYPASS_VERSION_CHECK=1

$ export LIBTORCH_USE_PYTORCH=1
$ export DYLD_LIBRARY_PATH=/opt/homebrew/anaconda3/envs/tch1/lib/python3.12/site-packages/torch/lib:$DYLD_LIBRARY_PATH
$ export PATH=/opt/homebrew/anaconda3/envs/tch1/bin:$PATH

これでビルド・実行できました。

実行1
$ cargo run
...
 1.6700
[ CPUFloatType{1,1} ]
実行2
$ cargo run 1
...
 1.6700
[ MPSFloatType{1,1} ]
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0