Posted at

テンソル形状のコンパイル時検査+実行時アサーション化を今のHaskellで書いてみる


はじめに

Edward Z. Yang 氏の記事 A compile-time debugger that helps you write tensor shape checks は、テンソルの形状の検査を依存型でコンパイル時に対話的にやりつつ、静的に決まらない部分や静的な検証が難しい部分はアサーションとして実行時検査に遅延するという仕組みの提案。

対話的なアサーションの挿入部分とかはともかく、型付けとアサーションの部分は今のHaskellでも大体出来るなぁと思ったので試しに書いてみた。 (元記事自体Haskellっぽい構文で書かれているし、 Edward Z. Yang 氏自身、ここまでは出来るのは当然知って書いているのだと思うけれど)


結果

元記事での、アサーションを挿入した最終的なプログラムは以下:

main() {

x = load("tensor1.t")
y = load("tensor2.t")
assert_eq_nat!(x.size(), 2)
assert_eq_nat!(y.size(), 2)
assert_eq_nat!(x.size(1), y.size(0))
matmul(x, y)

これを Haskell で書いてみたものが以下。 使っている関数の定義は後述する。 このままだと型検査が通るけれど、途中の3行をコメントアウトすると型エラーが起こるようになる。

main :: IO (DynShape Tensor)

main = do
DynShape x <- load "tensor1.t"
DynShape y <- load "tensor2.t"

-- If we comment out following three lines, it fails to type check.
IsLength2 _ x2 <- assertIsLength2 (shape x)
IsLength2 y1 _ <- assertIsLength2 (shape y)
Refl <- assertEqNat x2 y1

return $ DynShape (matmul x y)


実装

まずは使う言語拡張の有効化とモジュールのインポート。

{-# LANGUAGE DataKinds  #-}

{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
import Data.Proxy
import Data.Type.Equality
import GHC.Natural
import GHC.TypeNats


Shape の定義

型レベルの自然数のリストとして Shape を定義し、 また型レベルのShapeから値レベルのShape(自然数のリスト)を取得するための関数として shapeVal を定義する。

-- kind

type Shape = [Nat]

data SShape (s :: Shape) where
SEmpty :: SShape '[]
SCons :: KnownNat n => Proxy n -> SShape s -> SShape (n : s)

class KnownShape (s :: Shape) where
shapeSing :: SShape s

instance KnownShape '[] where
shapeSing = SEmpty

instance (KnownNat x, KnownShape xs) => KnownShape (x : xs) where
shapeSing = SCons Proxy shapeSing

shapeVal :: forall s proxy. KnownShape s => proxy s -> [Natural]
shapeVal _ = f (shapeSing :: SShape s)
where
f :: forall s2. SShape s2 -> [Natural]
f SEmpty = []
f (SCons x xs) = natVal x : f xs


テンソルの定義

Shape を型パラメータに持つ型として Tensor を定義しする。 今回の目的は型検査の proof of concept なので、 Tensor の値に関しては実装をサボって、 実際の多次元配列を扱わず、単に T という単一の値として扱う。

data Tensor (s :: Shape) = T

deriving (Show, Eq)

shape :: Tensor s -> Proxy s
shape _ = Proxy

次に演算として matmul を定義しておく。 Shapeが整合的であることを型レベルで表現することができる。

matmul :: Tensor [a,b] -> Tensor [b,c] -> Tensor [a,c]

matmul _ _ = T

次にShapeが動的に決まる場合の型を定義する。 元記事では exists (size : List(Nat)). Tensor(size) というような書き方をしていたが、 Haskellには存在型のための構文はないので、Shapeの存在量化のために DynShape という型を定義する。 すると、 exists (size : List(Nat)). Tensor(size)DynShape Tensor に対応する。

data DynShape f where

DynShape :: KnownShape shape => f shape -> DynShape f

テンソルの形状は静的に決まらない例として、テンソルをファイルから読み込む関数 load を定義する。 結果の型は DynShape Tensor になる。 また、ここでも実装をサボって、任意のファイルを読む代わりに、 "tensor1.t" という "tensor2.t" という名前に対して それぞれ 2x3 と 3x2 のテンソルを返すモックにしておく。

load :: String -> IO (DynShape Tensor)

load fname =
case fname of
"tensor1.t" -> return (DynShape (T :: Tensor [2,3]))
"tensor2.t" -> return (DynShape (T :: Tensor [3,2]))
_ -> error "not found"


アサーションの定義

最後にアサーションを定義する。

まずは、Shapeの長さが2であるというアサーション。 shapeSing をパターンマッチして長さが2だったら、その証拠として IsLength2 の値を返す。

data IsLength2 (shape :: Shape) where

IsLength2 :: (KnownNat a, KnownNat b) => Proxy a -> Proxy b -> IsLength2 [a,b]

assertIsLength2 :: forall s. KnownShape s => Proxy s -> IO (IsLength2 s)
assertIsLength2 _ =
case shapeSing :: SShape s of
SCons n1 (SCons n2 SEmpty) -> return $ IsLength2 n1 n2
_ -> error ("length " ++ show (shapeVal (Proxy :: Proxy s)) ++ " /= 2")

次に2つの型レベル自然数が等しいというアサーション。 こちらも同様だけど、 GHC.TypeNatssameNat という関数が定義済みなので、それを使う。

assertEqNat :: (KnownNat a, KnownNat b) => Proxy a -> Proxy b -> IO (a :~: b)

assertEqNat x y =
case sameNat x y of
Just refl -> return refl
Nothing -> error $ show (natVal x) ++ " /= " ++ show (natVal y)


再び最初のプログラム

改めて最初のプログラムを掲載する。 実装を読んだので何をやっているかはもはや明らかだろう。

main :: IO (DynShape Tensor)

main = do
DynShape x <- load "tensor1.t"
DynShape y <- load "tensor2.t"

-- If we comment out following three lines, it fails to type check.
IsLength2 _ x2 <- assertIsLength2 (shape x)
IsLength2 y1 _ <- assertIsLength2 (shape y)
Refl <- assertEqNat x2 y1

return $ DynShape (matmul x y)

x, yTensor s1, Tensor s2 のような型(s1, s2 はスコーレム変数)となっている。

ここで、 IsLength2 _ x2 <- assertIsLength2 (shape x) のパターンマッチによって型の等式 s1 ~ [_, x2] が、 IsLength2 y1 _ <- assertIsLength2 (shape y) のパターンマッチによって、型の等式 s2 ~ [y1, _] が導入され、 最後に Refl <- assertEqNat x2 y1 のパターンマッチによって型の等式 x2 ~ y1 が導入される。

結果として、 x :: Tensor [_, z], y :: Tensor [z, _] という形になるので、 matmul x y は型検査が通る。

ここでもしこの3行をコメントアウトすると、以下ような型エラーとなる。

test.hs:36:29: error:

• Couldn't match type ‘shape’ with ‘'[a0, b0]’
‘shape’ is a rigid type variable bound by
a pattern with constructor:
DynShape :: forall (shape :: Shape) (f :: Shape -> *).
KnownShape shape =>
f shape -> DynShape f,
in a pattern binding in
a 'do' block
at test.hs:28:3-12
Expected type: Tensor '[a0, b0]
Actual type: Tensor shape
• In the first argument of ‘matmul’, namely ‘x’
In the first argument of ‘DynShape’, namely ‘(matmul x y)’
In the second argument of ‘($)’, namely ‘DynShape (matmul x y)’
• Relevant bindings include
x :: Tensor shape (bound at test.hs:28:12)
|
36 | return $ DynShape (matmul x y)
| ^

test.hs:36:31: error:
• Couldn't match type ‘shape1’ with ‘'[b0, c0]’
‘shape1’ is a rigid type variable bound by
a pattern with constructor:
DynShape :: forall (shape :: Shape) (f :: Shape -> *).
KnownShape shape =>
f shape -> DynShape f,
in a pattern binding in
a 'do' block
at test.hs:29:3-12
Expected type: Tensor '[b0, c0]
Actual type: Tensor shape1
• In the second argument of ‘matmul’, namely ‘y’
In the first argument of ‘DynShape’, namely ‘(matmul x y)’
In the second argument of ‘($)’, namely ‘DynShape (matmul x y)’
• Relevant bindings include
y :: Tensor shape1 (bound at test.hs:29:12)
|
36 | return $ DynShape (matmul x y)
| ^

また、実行してみるとこの main は普通に実行できるが、 yy = x に書き換えた main' を用意すると、こちらはちゃんと実行時にエラーになる。

main' :: IO (DynShape Tensor)

main' = do
DynShape x <- load "tensor1.t"
let y = x

-- If we comment out following three lines, it fails to type check.
IsLength2 _ x2 <- assertIsLength2 (shape x)
IsLength2 y1 _ <- assertIsLength2 (shape y)
Refl <- assertEqNat x2 y1

return $ DynShape (matmul x y)

*Main> main

*Main> main'
*** Exception: 3 /= 2
CallStack (from HasCallStack):
error, called at test.hs:122:16 in main:Main


感想


  • これだけなら実現できるけど、実際に使おうと思うと、型レベルと項レベルの区別をあまり意識したくないのと、あと対話的なアサーションの挿入がないと普通の人には厳しいか。

  • こういうことをやろうとすると、 パターンマッチしないとスコーレム変数や制約をスコープに導入できないのが、ちょっと面倒くさい。構文糖(syntax sugar)があれば良い話ではあるけれど、その辺り勝手に導入されるような言語があっても面白いかも知れない。


  • Proxy は鬱陶しいので Type Applications in Patterns が早く導入されて欲しい……