これは何?
本Advent Calendarの8日目の記事です。
本記事では、Haskellで競技プログラミングの問題を解く際によく使われる高階関数である、foldrとfoldlを紹介します。
そもそもfoldとは畳み込み(reduce)である
上記で説明されていることをかいつまんで説明する。
リストは空もしくは、あるxと別のxのリストを結合したものである
[1,2,3]というリストはHaskellでは内部的に以下のように表されている。
ghci> (:) 1 ((:) 2 ((:) 3 []))
[1,2,3]
このようなリストの合計sumを求めるには以下のように再帰を使うことで実現できる。
sumList :: Num a => [a] -> a
sumList [] = 0
sumList (x:xs) = x + sumList xs
このことは、sumの計算は一般的な再帰パターンと箱のなかの部分とを貼り合せることでモジュール化できることを意味している。この再帰パターンは慣例と しては reduce と呼ばれている。
reduceをわかりやすく理解するため、同参考文献では、リストの定義からの置き換えを使って説明されている。
(:) 1 ((:) 2 ((:) 3 []))
この空のリストをaという初期値、:の部分を関数に置き換えることがreduceである。
sumの場合には以下のようになる。
add 1 (add 2 (add 3 0)) -- 6
このreduceを行う関数の代表としてHaskellにはfoldlとfoldrが存在する。
Haskell Wikiにあった図と同じようなもの↓
foldr(右再帰)
(:)
/ \
1 (:)
/ \
2 (:)
/ \
3 (:)
/ \
4 []
(+)
/ \
1 (+)
/ \
2 (+)
/ \
3 4
foldl(左再帰または、末尾再帰とも)
(:)
/ \
1 (:)
/ \
2 (:)
/ \
3 (:)
/ \
4 []
(+)
/ \
(+) 4
/ \
(+) 3
/ \
(+) 2
/ \
0 1
foldl、foldrはFoldable型クラスのインスタンスである
ghciを使ってfoldlの型を確認してみると型コンストラクタがFoldable型クラスのインスタンスである必要があるというのがわかる。
ghci> :t foldl
foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b
-
Foldableクラスは要素を一度に1つずつ要約値へと畳み込む操作1を提供するためのの型クラスである。
- Foldableは型引数を1つ取る型コンストラクタ(kind が Type -> Type のもの1に対して定義される。
代表的なインスタンスは、Maybe や [] のように、型を1つ受け取って新しい型を作る型コンストラクタである
型引数とはMaybeや[]のように型を1つ受け取って新しい型を作る型コンストラクタに対して渡される値の型が型引数である。
e.g. [Char]の場合は型コンストラクタは[]であり、型引数はChar
foldlとfoldrを実際に使ってみる
main :: IO ()
main = do
print $ foldl (+) 5 [1, 2, 3, 4] -- 15
print $ foldr (+) 5 [1, 2, 3, 4] -- 15
どちらも結果は15になる。
これは自前でfoldrとfoldlを実装してみるとわかりやすい。
foldrとfoldlはそれぞれ要素を右と左から畳み込むという違いがある。
foldr → 右畳み込み → 結果を右側から構築
foldl → 左畳み込み → 結果を左側から構築
foldr' :: (a -> b -> b) -> b -> [a] -> b
foldr' _ e [] = e -- 初期値
foldr' f e (x : xs) = f x (foldr' f e xs) -- 受け取る引数をlist構築式の形で処理している。:はリストを作る演算子である。
-- e.g. foldr' (+) 5 [1, 2, 3, 4]の場合
-- foldr' (+) 5 (1: [2, 3, 4])
-- = 1 + (foldr' 5 [2, 3, 4])
-- = 1 + 2 + (foldr' 5 [3, 4])
-- = 1 + 2 + 3 + (foldr' 5 [4])
-- = 1 + 2 + 3 + 4 + (foldr' 5 [])
-- = 1 + 2 + 3 + 4 + 5
foldl' :: (b -> a -> b) -> b -> [a] -> b
foldl' _ acc [] = acc -- 累積値
foldl' f acc (x : xs) = foldl' f (f acc x) xs
-- e.g. foldl' (+) 5 [1, 2, 3, 4]の場合
-- = foldl' (+) 5 (1: [2, 3, 4])
-- = foldl' (+) ((+) 5 1) [2, 3, 4]
-- = foldl' (+) (6) [2, 3, 4]
-- = foldl' (+) ((+) 6 2) [3, 4]
-- = foldl' (+) (8) [3, 4]
-- = foldl' (+) ((+) 8 3) [4]
-- = foldl' (+) (11) [4]
-- = foldl' (+) ((+) 11 4) []
-- = foldl' (+) (15) []
-- = 15
main :: IO ()
main = do
print $ foldr' (+) 5 [1, 2, 3, 4]
print $ foldl' (+) 5 [1, 2, 3, 4]
foldlのメモリ効率が良く実装できる例reverse
試しにreverseというリストを逆順にして返す関数を考えてみる。
module Reverse (reverse', reverse'', reverse''') where
-- 再帰バージョン
reverse' :: [a] -> [a]
reverse' [] = []
reverse' (a : as) = reverse' as ++ [a] -- 配列の先頭を取り出して、末尾に連結する
-- foldlバージョン
reverse'' :: [a] -> [a]
reverse'' as = foldl (flip (:)) [] as
-- NOTE: flipは関数に渡す引数の順序を入れ替えて返す関数である。
-- ghci> (:) 1 [2,3]
-- [1,2,3]
--- ghci> flip (:) [2,3] 1
-- [1,2,3]
-- e.g. foldl (flip (:)) [] [1,2,3]の場合
-- foldl (flip (:)) ((flip (:)) [] 1) [2,3]
-- = foldl (flip (:)) ([1]) [2,3]
-- = foldl (flip (:)) ((flip (:)) [1] 2) [3]
-- = foldl (flip (:)) ([2, 1]) [3]
-- = foldl (flip (:)) ((flip (:)) [2, 1] 3) []
-- = foldl (flip (:)) ([3, 2, 1]) []
-- = [3, 2, 1]
-- foldrバージョン
reverse''' :: [a] -> [a]
reverse''' as = foldr (\x acc -> acc ++ [x]) [] as -- foldrにより末尾から順に値を取り出し、accに結合することで逆順にする。
-- e.g. foldr (\x acc -> acc ++ [x]) [] [1, 2, 3]の場合
-- (\x [] -> [] ++ [x]) 1 (foldr (\x acc -> acc ++ [x]) [] [2, 3])
-- (\x [] -> [] ++ [x]) 1 ((\x [] -> [] ++ [x]) 2 (foldr (\x acc -> acc ++ [x]) [] [3]))
-- = (\x acc -> acc ++ [x]) 1 ((\x acc -> acc ++ [x]) 2 ((\x acc -> acc ++ [x]) 3 (foldr (\x acc -> acc ++ [x]) [] [])))
-- = (\x acc -> acc ++ [x]) 1 ((\x acc -> acc ++ [x]) 2 ((\x acc -> acc ++ [x]) 3 [])
-- = (\x acc -> acc ++ [x]) 1 (\x acc -> acc ++ [x]) 2 [3]
-- = (\x acc -> acc ++ [x]) 1 ([3, 2])
-- = [3, 2, 1]
main :: IO ()
main = do
print $ reverse' [1, 2, 3, 4] -- [4, 3, 2, 1]
print $ reverse' "Hello, World!" -- "!dlroW ,olleH"
print $ reverse' ["MIKE", "KEN"] -- ["KEN","MIKE"]
print $ reverse'' [1, 2, 3, 4] -- [4, 3, 2, 1]
print $ reverse''' [1, 2, 3, 4] -- [4, 3, 2, 1]
この3つの実装の中で最適なのはfoldlを使う実装である。
再帰を使う場合 O(n^2)
再帰を使う実装の場合++を使っているが、++は左側引数のリストを全走査してコピーするので
$$
\sum_{k=0}^{n-1} k
$$
となり、$O(n^2)$となる
foldlを使う場合 O(n)
flipを使う書き方にクセがあるが、各操作が$O(1)$で実施できるため、計算量は$O(n)$となる。
foldrを使う場合 O(n^2)
再帰同様に++を使うため、$O(n^2)$となる。
foldrを使うと無限リストが扱える
-- 無限リスト
nums :: [Int]
nums = [1 ..]
-- 先頭の偶数を合計する(1つでも奇数が出たら終了)foldr版
sumEvenUntilOddr :: Int
sumEvenUntilOddr = foldr step 0 nums
where
step x acc
| even x = x + acc
| otherwise = 0 -- 奇数が出たら fold を終了
-- foldl版
sumEvenUntilOddl :: Int
sumEvenUntilOddl = foldl step 0 nums
where
step x acc
| even x = x + acc
| otherwise = 0 -- 奇数が出たら fold を終了
main :: IO ()
main = do
print sumEvenUntilOddr
print sumEvenUntilOddl -- 終了しない
このような無限リストを使い、途中で打ち切る場合にはfoldlはメモリ効率を優先するが、左から畳み込まれる関係上、無限リストが終了しないため使用できない。
一方foldrであれば要素が右から畳み込まれるため、途中で中断される処理に対して無限リストを適用する場合には、Haskellの遅延評価の恩恵を受けられ、結果がかえる。
AtCoderの問題に使用してみる
この問題では、与えられた文字列に対して1文字のものと2文字のものを判別する必要がある。
文字列が各文字を何文字含んだマップを作る部分でfoldlを使用している。
import qualified Data.Map as Map
-- 文字列を受け取り、各文字が何文字含まれているかのMapを返す。
countChars :: String -> Map.Map Char Int
countChars s = foldl (\m c -> Map.insertWith (+) c 1 m) Map.empty s
findSingleChar :: String -> Char
findSingleChar s = fst $ head $ filter (\(c, n) -> n == 1) $ Map.toList $ countChars s
main :: IO ()
main = do
s <- getLine
let c = findSingleChar s
putStrLn [c]
これは、文字列(HaskellではString型は[Char])に含まれる文字全てに対して関数を適用しているのだが、foldlを使うことで以下のような再帰をシンプルに表現できる。
countChars :: String -> Map.Map Char Int
countChars = go Map.empty
where
go m [] = m
go m (c:cs) = go (Map.insertWith (+) c 1 m) cs
この問題の場合、もっと簡単なアルゴリズムがあるがご愛嬌。
oddChar :: String -> Char
oddChar s
| s!!0 == s!!1 = head (filter (/= s!!0) s)
| s!!0 == s!!2 = s!!1
| otherwise = s!!0
main :: IO ()
main = getLine >>= putStrLn . pure . oddChar
追記: 最近解いたやつでもう少し良い例があったので記載
カードの強さを競うゲームで、勝敗に応じてポイントが付与されるのでその最終結果を計算する問題。
初期値[0,0]に対してduel関数をcardsを引数にして呼び出している。
duel :: [Int] -> [String] -> [Int]
duel [taroPoint, hanakoPoint] [taroCard, hanakoCard]
| taroCard > hanakoCard = [taroPoint + 3, hanakoPoint]
| taroCard == hanakoCard = [taroPoint + 1, hanakoPoint + 1]
| otherwise = [taroPoint, hanakoPoint + 3]
solve :: [[String]] -> [Int]
solve cards = foldl duel [0, 0] cards
main :: IO ()
main = interact $ \inputs ->
let ls = lines inputs
-- _ = read (head ls) :: Int
cards = map words $ tail ls
in (unwords . map show $ solve cards) ++ "\n"
この問題では、1〜無限のベクトルについて距離をもとめる問題だったので、foldlを使ってまとめて計算した。
import Text.Printf (printf)
solve :: [Double] -> [Double] -> [Double]
solve xs ys = foldl (\acc p -> acc ++ [distance xs ys p]) [] [1, 2, 3, 8] -- foldlの第一、第二引数は固定にして第三引数だけを変えて折りたたむ
distance :: [Double] -> [Double] -> Int -> Double
distance xs ys p
| p == 8 =
maximum
[ abs (xs !! i - ys !! i)
| i <- [0 .. length xs - 1]
]
| otherwise =
let tmp =
sum $
[ abs $ (xs !! i - ys !! i) ^ p
| i <- [0 .. length xs - 1]
]
in tmp ** (1 / fromIntegral p)
main :: IO ()
main = interact $ \inputs ->
let ls = lines inputs
xs = map read . words $ ls !! 1 :: [Double]
ys = map read . words $ ls !! 2 :: [Double]
in unlines $ map (printf "%.8f") $ solve xs ys