Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

HaskellからBLASを叩く

Haskellで(数値的な、でかい)行列を扱うライブラリーとしては hmatrix がありますが、諸事情から hmatrix を避けたい場合があります。

諸事情の例:

  • インターフェースに癖がある(<> がPreludeと競合する。これはhmatrixの歴史が長いのである程度は仕方がない)
  • Numeric.LinearAlgebra.StaticDoubleComplex Double しか扱えない(Float を扱いたい)
  • MacPorts導入済みのMacではデフォルトでICU絡みのリンクエラーが出る(disable-default-paths オプションを指定すれば回避できる)

そこでHaskellで最強の線形代数ライブラリーを作りたくなるわけですが、ベクトルの和や行列の積は手書きではなく既存のライブラリー、BLASを使った方が有利です(hmatrixも内部的にはBLASを使っています)。というわけで、この記事ではHaskellからBLASを触る方法を見てみます。

Hackageを漁る

Hackageを BLAS で検索するといくつかそれらしいの(低レベルなバインディング)が出てきます。

筆者の独断と偏見で評価すると、

  • blas: Deprecated in favor of hblas らしい
  • hblas
  • blas-ffi
    • pkg-configで blas が見える必要がある。「パッケージオプションでOpenBLASを選択する」というようなことはできない。
  • blas-hs
    • パッケージオプションでNetlib, MKL, OpenBLAS, Accelerate.frameworkを選択できる。
    • gemmなど一部の関数のドキュメントには関数のパラメーターの説明がついている。地味に嬉しい。
  • linear-algebra-cblas
    • autoconfを使っている。筆者の環境では Setup.lhs でコンパイルエラーが出た。

なので、 blas-hs が一番良さそうに思えます。

ちなみに、macOS Big SurでAccelerate.frameworkを使う場合は、以下の問題によりGHC 8.8.4 / GHC 8.10.2(およびそれ以前のバージョン)では can't load framework: Accelerate (not found) みたいなエラーが出ます。GHC 8.10.3を使いましょう。

blas-hsを使う

テスト用のコードは https://github.com/minoki/blas-test に置きました。

blas-hsはそれぞれの関数・型クラスについて Safe 版と Unsafe 版を用意しています。これらは foreign import の種類に対応し、ざっくりとした説明は次の通りです:

  • safe: 外部呼び出し中に他のHaskellスレッド(?)を実行できるが、呼び出しコストがかかる。
  • unsafe: 外部呼び出し中は当該Haskellスレッド(?)の実行はブロックされる。呼び出しコストが低い。
    • ちなみに、GHC 8.4以降ではunsafe呼び出し中はGCが実行されないことが保証されるので、pinnedではない通常の ByteArray# を受け渡しできる……という話もあるがblas-hsは Ptr 経由のインターフェースしか用意していないので関係ない。

blas-hsは、オーバーロードされた Blas.Generic.{Safe,Unsafe} と単相な Blas.Specialized.*.{Safe,Unsafe} を用意しています。

blas-hsは Ptr 経由のインターフェースしか提供していないので、コンテナーを使うなら Storable 系を使いましょう。ここでは Data.Vector.Storable を使います。

サンプルコード:

import qualified Blas.Generic.Safe as BLAS
import qualified Blas.Primitive.Types as BLAS
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Control.Monad (forM_)

matMulVS :: Int -> Int -> Int -> VS.Vector Double -> VS.Vector Double -> IO (VS.Vector Double)
matMulVS m n k a b
  | VS.length a == m * k && VS.length b == k * n
  = VS.unsafeWith a $ \ptrA -> VS.unsafeWith b $ \ptrB -> do
      c <- VSM.new (m * n)
      VSM.unsafeWith c $ \ptrC ->
        BLAS.gemm BLAS.RowMajor BLAS.NoTrans BLAS.NoTrans m n k 1.0 ptrA k ptrB n 0.0 ptrC n
      VS.unsafeFreeze c
  | otherwise = error "size mismatch"

printMat :: Int -> Int -> VS.Vector Double -> IO ()
printMat m n v
  | VS.length v == m * n
  = forM_ [0..m-1] $ \i -> do
      print $ VS.slice (n * i) n v
  | otherwise = error "size mismatch"

main :: IO ()
main = do
  let a :: VS.Vector Double
      a = VS.fromList [ 1.0, 2.0, 3.0
                      , 2.0, 3.0, 4.0
                      , 4.0, 5.0, 6.0
                      ]

      b :: VS.Vector Double
      b = VS.fromList [ 1.0, 2.0, 3.0, 2.0
                      , 2.0, -2.0, 4.0, 0.0
                      , 1.0, 5.0, 6.0, -1.0
                      ]
  c <- matMulVS 3 4 3 a b
  printMat 3 4 c

実行結果:

$ stack exec blas-test-blashs
[8.0,13.0,29.0,-1.0]
[12.0,18.0,42.0,0.0]
[20.0,28.0,68.0,2.0]

cublasを使う

最近はGPUに行列計算させている方も多いでしょう。NVIDIAのGPUであればcuBLASというやつが使えますが、cuBLASにもHaskellバインディングがあります。せっかくなので使ってみましょう。

どうやらcuBLASはrow-majorを指定できないようです。ここでは、row-majorな行列のオペランドの順番をひっくり返してゴニョゴニョしてみました。

import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Control.Monad (forM_)
import Foreign.CUDA.Runtime.Marshal as CUDA
import Foreign.CUDA.BLAS as CuBLAS
import Foreign.CUDA.BLAS.Level3 as CuBLAS
import Foreign.Marshal (with)

printMat :: Int -> Int -> VU.Vector Double -> IO ()
printMat m n v
  | VU.length v == m * n
  = forM_ [0..m-1] $ \i -> do
      print $ VU.slice (n * i) n v
  | otherwise = error "size mismatch"

main :: IO ()
main = do
  handle <- CuBLAS.create
  ptrA <- CUDA.newListArray [ 1.0, 2.0, 3.0
                            , 2.0, 3.0, 4.0
                            , 4.0, 5.0, 6.0
                            ]
  ptrB <- CUDA.newListArray [ 1.0, 2.0, 3.0, 2.0
                            , 2.0, -2.0, 4.0, 0.0
                            , 1.0, 5.0, 6.0, -1.0
                            ]
  ptrC <- CUDA.mallocArray (3 * 4)
  let m = 3
      n = 4
      k = 3
  with 1.0 $ \ptrAlpha ->
    with 0.0 $ \ptrBeta ->
      -- A: m * k
      -- B: k * n
      CuBLAS.dgemm handle CuBLAS.N CuBLAS.N n m k ptrAlpha ptrB n ptrA k ptrBeta ptrC n
  c <- CUDA.peekListArray (3 * 4) ptrC
  CUDA.free ptrA
  CUDA.free ptrB
  CUDA.free ptrC
  printMat 3 4 (VU.fromList c)
  CuBLAS.destroy handle

実行結果:

$ stack exec blas-test-cuda
[8.0,13.0,29.0,-1.0]
[12.0,18.0,42.0,0.0]
[20.0,28.0,68.0,2.0]

GPUの嬌声は聞こえましたか?

mod_poppo
最近は浮動小数点数オタクをやっています。
https://blog.miz-ar.info/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away