Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


performance - Laziness and tail recursion in Haskell, why is this crashing?

I have this fairly simple function to compute the mean of elements of a big list, using two accumulators to hold the sum so far and the count so far:

mean = go 0 0
    where
      go s l []     = s / fromIntegral l
      go s l (x:xs) = go (s+x) (l+1) xs

main = do
  putStrLn (show (mean [0..10000000]))

Now, in a strict language, this would be tail-recursive, and there would be no problem. However, as Haskell is lazy, my googling has led me to understand that (s+x) and (l+1) will be passed down the recursion as thunks. So this whole thing crashes and burns:

Stack space overflow: current size 8388608 bytes.

After further googling, I found seq and $!, which it seems I don't understand, because all my attempts at using them in this context proved futile, with error messages saying something about infinite types.

Finally I found -XBangPatterns, which solves it all by changing the patterns in the recursive case:

go !s !l (x:xs) = go (s+x) (l+1) xs

But I'm not happy with this, as -XBangPatterns is currently an extension. I would like to know how to make the evaluation strict without the use of -XBangPatterns. (And maybe learn something too!)

Just so you understand my lack of understanding, here's what I tried (the only try that compiled, that is):

go s l (x:xs) = go (seq s (s+x)) (seq l (l+1)) xs

From what I could understand, seq should force the evaluation of the s and l arguments here, thus avoiding the problem caused by thunks. But I still get a stack overflow.



1 Reply


I've written extensively on this:

Firstly, yes: if you want to require strict evaluation of the accumulators, use seq, and you stay within Haskell 98:

mean = go 0 0
  where
    go s l []     = s / fromIntegral l
    go s l (x:xs) = s `seq` l `seq`
                      go (s+x) (l+1) xs

main = print $ mean [0..10000000]

*Main> main
5000000.0
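
For comparison, the attempt in the question, go (seq s (s+x)) (seq l (l+1)) xs, puts seq inside the argument expressions. Those arguments are themselves passed as unevaluated thunks, so the seq only runs when the thunk is finally demanded at the end of the list, and the chain still grows. In the version above, seq sits on the result of go's own equation, so it runs every time go is entered. As a minimal sketch of the same idea using $! (the name meanStrict and the explicit type signature are my own additions for illustration, not part of the question):

```haskell
module Main where

-- Mean of a list, forcing each new accumulator value before recursing.
meanStrict :: [Double] -> Double
meanStrict = go 0 (0 :: Int)
  where
    go s l []     = s / fromIntegral l
    -- f $! x evaluates x to weak head normal form before applying f,
    -- so (s + x) and (l + 1) are computed before the recursive call
    -- instead of being passed along as thunks.
    go s l (x:xs) = ((go $! s + x) $! l + 1) xs

main :: IO ()
main = print (meanStrict [0 .. 10000000])
```

Because $! has the lowest precedence, go $! s + x parses as go $! (s + x). This stays within Haskell 98, just like the seq version.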

Secondly: strictness analysis will kick in if you give some type annotations, and compile with -O2:

mean :: [Double] -> Double
mean = go 0 0
 where
  go :: Double -> Int -> [Double] -> Double
  go s l []     = s / fromIntegral l
  go s l (x:xs) = go (s+x) (l+1) xs

main = print $ mean [0..10000000]

$ ghc -O2 --make A.hs
[1 of 1] Compiling Main             ( A.hs, A.o )
Linking A ...

$ time ./A
5000000.0
./A  0.46s user 0.01s system 99% cpu 0.470 total

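Double is a boxed wrapper over the strict, unboxed primitive type Double#. With optimizations on and a precise type signature, GHC's strictness analysis can infer that evaluating the accumulators eagerly is safe, so it compiles the loop as if it had been written strictly.

For more speed still, the same loop can be written over an unboxed array with a strict pair as the accumulator, here using Data.Array.Vector (from the uvector package):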

import Data.Array.Vector

main = print (mean (enumFromToFracU 1 10000000))

data Pair = Pair !Int !Double

mean :: UArr Double -> Double   
mean xs = s / fromIntegral n
  where
    Pair n s       = foldlU k (Pair 0 0) xs
    k (Pair n s) x = Pair (n+1) (s+x)

$ ghc -O2 --make A.hs -funbox-strict-fields
[1 of 1] Compiling Main             ( A.hs, A.o )
Linking A ...

$ time ./A
5000000.5
./A  0.03s user 0.00s system 96% cpu 0.038 total

As described in the Real World Haskell (RWH) chapter on profiling and optimization.
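
If the Data.Array.Vector module is not available, the strict-pair idea from the example above can be sketched with nothing but base, using foldl' from Data.List over an ordinary list. This is only an illustration of the accumulator technique; it does not use unboxed arrays, so it should not be expected to match the timing shown above:

```haskell
module Main where

import Data.List (foldl')

-- Strict pair: the bangs force both fields whenever a Pair is built.
data Pair = Pair !Int !Double

mean :: [Double] -> Double
mean xs = s / fromIntegral n
  where
    -- foldl' forces the accumulator to weak head normal form at each
    -- step; the strict fields then force the count and the sum as well.
    Pair n s       = foldl' k (Pair 0 0) xs
    k (Pair m t) x = Pair (m + 1) (t + x)

main :: IO ()
main = print (mean [1 .. 10000000])
```

Strict fields in data declarations are part of Haskell 98, so this version also needs no language extensions.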

