『ゼロから作るDeep Learning – Pythonで学ぶディープラーニングの理論と実装』 の MNIST を扱うための サンプルコードを Haskell で実装し直してみました。
実装
module Mnist
( loadMnist
) where
import Control.Monad

import qualified Codec.Compression.GZip as GZ (compress, decompress)
import Data.Binary (decode, encode)
import qualified Data.ByteString.Lazy as BL
import Network.HTTP.Simple (getResponseBody, getResponseStatusCode, httpLBS, parseRequest)
import Numeric.LinearAlgebra
import System.Directory (createDirectoryIfMissing, doesFileExist)
-- | A pair of real-valued matrices. Within this module the first component
-- holds image data and the second component holds the matching labels.
type DataSet = (Matrix R, Matrix R)
-- | Base URL that every MNIST archive name is appended to when downloading.
baseUrl :: String
baseUrl = "http://yann.lecun.com/exdb/mnist"

-- | Association list from a logical key to the archive file name.
-- The entries are ordered: training images, training labels, test images,
-- test labels.
keyFiles :: [(String, String)]
keyFiles =
  [ ("train_img", "train-images-idx3-ubyte.gz")
  , ("train_label", "train-labels-idx1-ubyte.gz")
  , ("test_img", "t10k-images-idx3-ubyte.gz")
  , ("test_label", "t10k-labels-idx1-ubyte.gz")
  ]

-- | Directory (relative to the working directory) where downloaded archives
-- and the serialized matrix file are stored.
assetsDir :: String
assetsDir = "assets"

-- | File name of the serialized data sets inside 'assetsDir'.
pickleFile :: String
pickleFile = "mnist.dat"

-- | Number of columns used for the image matrices (784 = 28 * 28).
imgSize :: Int
imgSize = 784
-- | Build the path of a file that lives inside the assets directory by
-- joining the directory name and the given file name with a slash.
generatePath :: String -> String
generatePath p = concat [assetsDir, "/", p]
-- | Download a single archive from 'baseUrl' into the assets directory.
--
-- Nothing happens when the target file already exists. Otherwise the file is
-- fetched over HTTP and written to disk. An 'IOError' is raised when the
-- server answers with a non-2xx status, so that an error page is never saved
-- under the archive name (it would be treated as "already downloaded" on the
-- next run).
download :: String -> IO ()
download fileName = do
  let savePath = generatePath fileName
  alreadyThere <- doesFileExist savePath
  unless alreadyThere $ do
    putStrLn $ "Downloading " ++ fileName ++ " ..."
    -- Writing the file does not create missing parent directories, so make
    -- sure the assets directory exists before the first download.
    createDirectoryIfMissing True assetsDir
    res <- httpLBS =<< parseRequest (baseUrl ++ "/" ++ fileName)
    let status = getResponseStatusCode res
    unless (status >= 200 && status < 300) $
      ioError . userError $
        "Failed to download " ++ fileName ++ " (HTTP status " ++ show status ++ ")"
    BL.writeFile savePath (getResponseBody res)
    putStrLn "Done"
-- | Download every archive named in the given (key, file name) list, in
-- list order. Only the file-name component of each pair is used.
downloadMnist :: [(String, String)] -> IO ()
downloadMnist = mapM_ (download . snd)
-- | Convert every byte of a lazy 'BL.ByteString' into a 'Double' with the
-- same numeric value (0 to 255), preserving order.
--
-- The bytes are converted numerically instead of being rendered to a
-- 'String' and parsed back, which avoids the partial 'read' and a large
-- amount of per-byte allocation.
toDoubleList :: BL.ByteString -> [Double]
toDoubleList = map fromIntegral . BL.unpack
-- | Read a gzip-compressed label file from the assets directory and return
-- its contents as a matrix with a single column.
--
-- The first 8 bytes of the decompressed data (the IDX file header) are
-- skipped; each remaining byte becomes one matrix entry.
loadLabel :: String -> IO (Matrix R)
loadLabel fileName = do
  raw <- BL.readFile (generatePath fileName)
  let payload = BL.drop 8 (GZ.decompress raw)
  return (matrix 1 (toDoubleList payload))
-- | Read a gzip-compressed image file from the assets directory and return
-- its contents as a matrix with 'imgSize' columns.
--
-- The first 16 bytes of the decompressed data (the IDX file header) are
-- skipped; each remaining byte becomes one matrix entry.
loadImg :: String -> IO (Matrix R)
loadImg fileName = do
  raw <- BL.readFile (generatePath fileName)
  let payload = BL.drop 16 (GZ.decompress raw)
  return (matrix imgSize (toDoubleList payload))
-- | Load the four downloaded archives and pair them up as
-- @[(training images, training labels), (test images, test labels)]@.
--
-- Files are located through their keys in 'keyFiles' rather than by list
-- position, so reordering or extending 'keyFiles' cannot silently swap the
-- data sets. A missing key raises an 'IOError' naming the key.
toMatrix :: IO [DataSet]
toMatrix = do
  trainImg <- loadImg =<< fileFor "train_img"
  trainLabel <- loadLabel =<< fileFor "train_label"
  testImg <- loadImg =<< fileFor "test_img"
  testLabel <- loadLabel =<< fileFor "test_label"
  return [(trainImg, trainLabel), (testImg, testLabel)]
  where
    -- Resolve a logical key to its archive file name.
    fileFor :: String -> IO String
    fileFor key =
      maybe
        (ioError . userError $ "toMatrix: no file registered for key " ++ show key)
        return
        (lookup key keyFiles)
-- | Serialize the data sets with 'encode', gzip-compress the result and
-- write it to the given path.
createPickle :: String -> [DataSet] -> IO ()
createPickle p ds = BL.writeFile p compressed
  where
    compressed = GZ.compress (encode ds)
-- | Read the file at the given path, gunzip it and 'decode' the data sets
-- that 'createPickle' stored there.
loadPickle :: String -> IO [DataSet]
loadPickle p = decode . GZ.decompress <$> BL.readFile p
-- | Prepare the local cache: download all archives listed in 'keyFiles',
-- convert them to matrices and store the result as the serialized file in
-- the assets directory.
initMnist :: IO ()
initMnist = do
  downloadMnist keyFiles
  putStrLn "Creating binary Matrix file ..."
  datasets <- toMatrix
  createPickle (generatePath pickleFile) datasets
  putStrLn "Done"
-- | Optionally scale the image matrices by dividing every entry by 255.
--
-- When the flag is 'True', the image component of every data set is divided
-- by 255 and the label component is left unchanged. When it is 'False', the
-- input is returned as is. Lists of any length are accepted, so the function
-- no longer fails on inputs that do not have exactly two elements.
normalizeImg :: Bool -> [DataSet] -> IO [DataSet]
normalizeImg f ds
  | f = return [(img / 255, label) | (img, label) <- ds]
  | otherwise = return ds
-- | Load the MNIST data sets, building the local cache first if needed.
--
-- If the serialized file is not present in the assets directory it is
-- created via 'initMnist'. The data sets are then read back from that file
-- and passed through 'normalizeImg' with the given flag.
loadMnist :: Bool -> IO [DataSet]
loadMnist normalize = do
  let cachePath = generatePath pickleFile
  cached <- doesFileExist cachePath
  unless cached initMnist
  datasets <- loadPickle cachePath
  normalizeImg normalize datasets
実行結果
stack ghci
で実行してみます。
*Main> :l Mnist.hs
*Mnist> ds <- loadMnist False
*Mnist> size . fst . (!!0) $ ds
(60000,784)
*Mnist> size . snd . (!!0) $ ds
(60000,1)
*Mnist> size . fst . (!!1) $ ds
(10000,784)
*Mnist> size . snd . (!!1) $ ds
(10000,1)
できてそうです。
所感
行列への変換よりも、ダウンロードした MNIST をどうやって保存・復元するかに、すごく時間をかけていた気がします。(hmatrix の Vector や Matrix が Binary のインスタンスになっていて本当によかった。これを自前で実装するのは流石につらい...)
あとで実装の詳しい解説も書こうと思います。
(書きました→ http://ku00.hatenablog.com/entry/2017/04/29/201230 )
実装したコードはこちらに置いてます。