HaskellでMNISTを使えるようにする

  • 6
    Like
  • 0
    Comment

『ゼロから作るDeep Learning – Pythonで学ぶディープラーニングの理論と実装』 の MNIST を扱うための サンプルコードを Haskell で実装し直してみました。

実装

module Mnist
    ( loadMnist
    ) where

import Control.Monad
import Numeric.LinearAlgebra
import Network.HTTP.Simple (parseRequest, httpLBS, getResponseBody)
import System.Directory (doesFileExist)
import qualified Data.ByteString.Lazy as BL
import qualified Codec.Compression.GZip as GZ (compress, decompress)
import Data.Binary (encode, decode)

type DataSet = (Matrix R, Matrix R)

baseUrl = "http://yann.lecun.com/exdb/mnist"
keyFiles = [
            ("train_img", "train-images-idx3-ubyte.gz"),
            ("train_label", "train-labels-idx1-ubyte.gz"),
            ("test_img", "t10k-images-idx3-ubyte.gz"),
            ("test_label", "t10k-labels-idx1-ubyte.gz")
          ]

assetsDir = "assets"
pickleFile = "mnist.dat"
imgSize = 784

generatePath :: String -> String
generatePath p = assetsDir ++ "/" ++ p

download :: String -> IO ()
download fileName = do
    let savePath = generatePath fileName

    e <- doesFileExist savePath
    unless e $ do
      putStrLn $ "Downloading " ++ fileName ++ " ..."
      res <- httpLBS =<< parseRequest (baseUrl ++ "/" ++ fileName)
      BL.writeFile savePath (getResponseBody res)
      putStrLn "Done"

downloadMnist :: [(String, String)] -> IO ()
downloadMnist [] = return ()
downloadMnist (x:xs) = do
    download $ snd x
    downloadMnist xs

toDoubleList :: BL.ByteString -> [Double]
toDoubleList = map (read . show . fromEnum) . BL.unpack

loadLabel :: String -> IO (Matrix R)
loadLabel fileName = do
    contents <- fmap GZ.decompress (BL.readFile $ generatePath fileName)
    return . matrix 1 . toDoubleList $ BL.drop 8 contents

loadImg :: String -> IO (Matrix R)
loadImg fileName = do
    contents <- fmap GZ.decompress (BL.readFile $ generatePath fileName)
    return . matrix imgSize . toDoubleList $ BL.drop 16 contents

toMatrix :: IO [DataSet]
toMatrix = do
    trainImg <- loadImg . snd $ keyFiles !! 0
    trainLabel <- loadLabel . snd $ keyFiles !! 1
    testImg <- loadImg . snd $ keyFiles !! 2
    testLabel <- loadLabel . snd $ keyFiles !! 3

    return [(trainImg, trainLabel), (testImg, testLabel)]

createPickle :: String -> [DataSet] -> IO ()
createPickle p ds = BL.writeFile p $ (GZ.compress . encode) ds

loadPickle :: String -> IO [DataSet]
loadPickle p = do
    encodeDs <- BL.readFile p
    return $ (decode . GZ.decompress) encodeDs

initMnist :: IO ()
initMnist = do
    downloadMnist keyFiles
    putStrLn "Creating binary Matrix file ..."
    createPickle (generatePath pickleFile) =<< toMatrix
    putStrLn "Done"

normalizeImg :: Bool -> [DataSet] -> IO [DataSet]
normalizeImg f ds@[train, test]
    | f         = return [ ((/255) $ fst train, snd train), ((/255) $ fst test, snd test) ]
    | otherwise = return ds

loadMnist :: Bool -> IO [DataSet]
loadMnist normalize = do
    let loadPath = generatePath pickleFile
    e <- doesFileExist loadPath
    unless e initMnist

    loadPickle loadPath >>= normalizeImg normalize

実行結果

stack ghci で実行してみます。

*Main> :l Mnist.hs
*Mnist> ds <- loadMnist False
*Mnist> size . fst . (!!0) $ ds
(60000,784)
*Mnist> size . snd . (!!0) $ ds
(60000,1)
*Mnist> size . fst . (!!1) $ ds
(10000,784)
*Mnist> size . snd . (!!1) $ ds
(10000,1)

できてそうです。

所感

行列への変換より、落としてきた MNIST をどうやって保存・復元しようかにすごい時間をかけていた気がします。( hmatrix の Vector や Matrix が Binary のインスタンスになってて本当によかった。これを自前で実装するのは流石につらい...)

あとで実装の詳しい解説も書こうと思います。
(書きました→ http://ku00.hatenablog.com/entry/2017/04/29/201230 )

実装したコードはこちらに置いてます。

参考文献