これは何?
本Advent Calendarの8日目の記事です。
本記事では、Haskellで競技プログラミングの問題を解く際によく使われる高階関数である、foldrとfoldlを紹介します。
foldl、foldrはFoldable型クラスのインスタンスである
ghciを使ってfoldlの型を確認してみると型コンストラクタがFoldable型クラスのインスタンスである必要があるというのがわかる。
ghci> :t foldl
foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b
Foldableクラスは要素を一度に1つずつ要約値へと畳み込む操作1を提供するためのの型クラスである。
Foldableは型引数を1つ取る型コンストラクタ(kind が Type -> Type のもの1に対して定義される。
代表的なインスタンスは、Maybe や [] のように、型を1つ受け取って新しい型を作る型コンストラクタである
型引数とはMaybeや[]のように型を1つ受け取って新しい型を作る型コンストラクタに対して渡される値の型が型引数である。
e.g. [Char]の場合は型コンストラクタは[]であり、型引数はChar
foldlとfoldrを実際に使ってみる
main :: IO ()
main = do
print $ foldl (+) 5 [1, 2, 3, 4] -- 15
print $ foldr (+) 5 [1, 2, 3, 4] -- 15
どちらも結果は15になる。
これは自前でfoldrとfoldlを実装してみるとわかりやすい。
foldrとfoldlはそれぞれ要素を右と左から畳み込むという違いがある。
foldr → 右畳み込み → 結果を右側から構築
foldl → 左畳み込み → 結果を左側から構築
foldr' :: (a -> b -> b) -> b -> [a] -> b
foldr' _ e [] = e -- 初期値
foldr' f e (x : xs) = f x (foldr' f e xs) -- 受け取る引数をlist構築式の形で処理している。:はリストを作る演算子である。
-- e.g. foldr' (+) 5 [1, 2, 3, 4]の場合
-- foldr' (+) 5 (1: [2, 3, 4])
-- = 1 + (foldr' 5 [2, 3, 4])
-- = 1 + 2 + (foldr' 5 [3, 4])
-- = 1 + 2 + 3 + (foldr' 5 [4])
-- = 1 + 2 + 3 + 4 + (foldr' 5 [])
-- = 1 + 2 + 3 + 4 + 5
foldl' :: (b -> a -> b) -> b -> [a] -> b
foldl' _ acc [] = acc -- 累積値
foldl' f acc (x : xs) = foldl' f (f acc x) xs
-- e.g. foldl' (+) 5 [1, 2, 3, 4]の場合
-- = foldl' (+) 5 (1: [2, 3, 4])
-- = foldl' (+) ((+) 5 1) [2, 3, 4]
-- = foldl' (+) (6) [2, 3, 4]
-- = foldl' (+) ((+) 6 2) [3, 4]
-- = foldl' (+) (8) [3, 4]
-- = foldl' (+) ((+) 8 3) [4]
-- = foldl' (+) (11) [4]
-- = foldl' (+) ((+) 11 4) []
-- = foldl' (+) (15) []
-- = 15
main :: IO ()
main = do
print $ foldr' (+) 5 [1, 2, 3, 4]
print $ foldl' (+) 5 [1, 2, 3, 4]
使い分け: foldlはメモリ効率重視foldrは無限リスト等途中で打ち切られる処理に使う
Strict left-associative folds are a good fit for space-efficient reduction, while lazy right-associative folds are a good fit for corecursive iteration, or for folds that short-circuit after processing an initial subsequence of the structure's elements.1
foldlのメモリ効率が良い例reverse
試しにreverseというリストを逆順にして返す関数を考えてみる。
module Reverse (reverse', reverse'', reverse''') where
-- 再帰バージョン
reverse' :: [a] -> [a]
reverse' [] = []
reverse' (a : as) = reverse' as ++ [a] -- 配列の先頭を取り出して、末尾に連結する
-- foldlバージョン
reverse'' :: [a] -> [a]
reverse'' as = foldl (flip (:)) [] as
-- NOTE: flipは関数に渡す引数の順序を入れ替えて返す関数である。
-- ghci> (:) 1 [2,3]
-- [1,2,3]
--- ghci> flip (:) [2,3] 1
-- [1,2,3]
-- e.g. foldl (flip (:)) [] [1,2,3]の場合
-- foldl (flip (:)) ((flip (:)) [] 1) [2,3]
-- = foldl (flip (:)) ([1]) [2,3]
-- = foldl (flip (:)) ((flip (:)) [1] 2) [3]
-- = foldl (flip (:)) ([2, 1]) [3]
-- = foldl (flip (:)) ((flip (:)) [2, 1] 3) []
-- = foldl (flip (:)) ([3, 2, 1]) []
-- = [3, 2, 1]
-- foldrバージョン
reverse''' :: [a] -> [a]
reverse''' as = foldr (\x acc -> acc ++ [x]) [] as -- foldrにより末尾から順に値を取り出し、accに結合することで逆順にする。
-- e.g. foldr (\x acc -> acc ++ [x]) [] [1, 2, 3]の場合
-- (\x [] -> [] ++ [x]) 1 (foldr (\x acc -> acc ++ [x]) [] [2, 3])
-- (\x [] -> [] ++ [x]) 1 ((\x [] -> [] ++ [x]) 2 (foldr (\x acc -> acc ++ [x]) [] [3]))
-- = (\x acc -> acc ++ [x]) 1 ((\x acc -> acc ++ [x]) 2 ((\x acc -> acc ++ [x]) 3 (foldr (\x acc -> acc ++ [x]) [] [])))
-- = (\x acc -> acc ++ [x]) 1 ((\x acc -> acc ++ [x]) 2 ((\x acc -> acc ++ [x]) 3 [])
-- = (\x acc -> acc ++ [x]) 1 (\x acc -> acc ++ [x]) 2 [3]
-- = (\x acc -> acc ++ [x]) 1 ([3, 2])
-- = [3, 2, 1]
main :: IO ()
main = do
print $ reverse' [1, 2, 3, 4] -- [4, 3, 2, 1]
print $ reverse' "Hello, World!" -- "!dlroW ,olleH"
print $ reverse' ["MIKE", "KEN"] -- ["KEN","MIKE"]
print $ reverse'' [1, 2, 3, 4] -- [4, 3, 2, 1]
print $ reverse''' [1, 2, 3, 4] -- [4, 3, 2, 1]
この3つの実装の中で最適なのはfoldlを使う実装である。
再帰を使う場合 O(n^2)
再帰を使う実装の場合++を使っているが、++は左側引数のリストを全走査してコピーするので
$$
\sum_{k=0}^{n-1} k
$$
となり、$O(n^2)$となる
foldlを使う場合 O(n)
flipを使う書き方にクセがあるが、各操作が$O(1)$で実施できるため、計算量は$O(n)$となる。
foldrを使う場合 O(n^2)
再帰同様に++を使うため、$O(n^2)$となる。
foldrを使うと無限リストが扱える
-- 無限リスト
nums :: [Int]
nums = [1 ..]
-- 先頭の偶数を合計する(1つでも奇数が出たら終了)foldr版
sumEvenUntilOddr :: Int
sumEvenUntilOddr = foldr step 0 nums
where
step x acc
| even x = x + acc
| otherwise = 0 -- 奇数が出たら fold を終了
-- foldl版
sumEvenUntilOddl :: Int
sumEvenUntilOddl = foldl step 0 nums
where
step x acc
| even x = x + acc
| otherwise = 0 -- 奇数が出たら fold を終了
main :: IO ()
main = do
print sumEvenUntilOddr
print sumEvenUntilOddl -- 終了しない
このような無限リストを使い、途中で打ち切る場合にはfoldlはメモリ効率を優先するが、左から畳み込まれる関係上、無限リストが終了しないため使用できない。
一方foldrであれば要素が右から畳み込まれるため、途中で中断される処理に対して無限リストを適用する場合には、Haskellの遅延評価の恩恵を受けられ、結果がかえる。
実際に問題に使用してみる
この問題では、与えられた文字列に対して1文字のものと2文字のものを判別する必要がある。
文字列が各文字を何文字含んだマップを作る部分でfoldlを使用している。
import qualified Data.Map as Map
-- 文字列を受け取り、各文字が何文字含まれているかのMapを返す。
countChars :: String -> Map.Map Char Int
countChars s = foldl (\m c -> Map.insertWith (+) c 1 m) Map.empty s
findSingleChar :: String -> Char
findSingleChar s = fst $ head $ filter (\(c, n) -> n == 1) $ Map.toList $ countChars s
main :: IO ()
main = do
s <- getLine
let c = findSingleChar s
putStrLn [c]
これは、文字列(HaskellではString型は[Char])に含まれる文字全てに対して関数を適用しているのだが、foldlを使うことで以下のような再帰をシンプルに表現できる。
countChars :: String -> Map.Map Char Int
countChars = go Map.empty
where
go m [] = m
go m (c:cs) = go (Map.insertWith (+) c 1 m) cs
この問題の場合、もっと簡単なアルゴリズムがあるがご愛嬌。
oddChar :: String -> Char
oddChar s
| s!!0 == s!!1 = head (filter (/= s!!0) s)
| s!!0 == s!!2 = s!!1
| otherwise = s!!0
main :: IO ()
main = getLine >>= putStrLn . pure . oddChar