Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

recursion - Merge sort from "Programming Scala" causes stack overflow

A direct cut and paste of the following algorithm:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}

causes a StackOverflowError on lists of about 5000 elements.

Is there any way to optimize this so that this doesn't occur?


1 Reply

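The overflow happens because merge isn't tail-recursive: in x :: merge(xs1, ys) the cons can only run after the recursive call returns, so every element of the merged list costs one stack frame. You can fix this either by using a non-strict collection or by making merge tail-recursive.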

The latter solution goes like this:

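def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  // acc holds the merged elements so far in reverse order (largest first),
  // which puts every recursive call to merge in tail position.
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
    (xs, ys) match {
      // One side is exhausted: put the rest of the other side, reversed, in front of acc.
      case (Nil, _) => ys.reverse ::: acc
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) merge(xs1, ys, x :: acc)
        else merge(xs, ys1, y :: acc)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    // merge returns its result reversed, so reverse it once more here.
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  }
}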
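
As a rough check of the accumulator technique above, here is a minimal standalone sketch. The names mergeSorted and loop are made up for this illustration and are not part of the original answer. The @tailrec annotation asks the compiler to reject the method if the recursive call is not in tail position, so a successful compile is some evidence that the stack stays flat.

```scala
import scala.annotation.tailrec

// Hypothetical helper for illustration only: merges two already-sorted lists.
def mergeSorted[T](less: (T, T) => Boolean)(xs: List[T], ys: List[T]): List[T] = {
  // acc collects the merged prefix in reverse order.
  @tailrec
  def loop(xs: List[T], ys: List[T], acc: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => acc.reverse ::: ys   // left side exhausted
      case (_, Nil) => acc.reverse ::: xs   // right side exhausted
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) loop(xs1, ys, x :: acc)
        else loop(xs, ys1, y :: acc)
    }
  loop(xs, ys, Nil)
}

// Two sorted inputs of 100000 elements each, far past the size from the question.
val evens = (0 until 200000 by 2).toList
val odds  = (1 until 200000 by 2).toList
val merged = mergeSorted[Int](_ < _)(evens, odds)
```

Unlike the version in the answer, this sketch reverses acc at the point where one side runs out, so the caller does not need a final reverse. Either arrangement should work.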

Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  // The tail of each Stream cell is evaluated lazily, so merge returns
  // after one step instead of recursing to the end of its inputs.
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    // toList forces the stream one element at a time.
    merge(msort(less)(ys), msort(less)(zs)).toList
  }
}
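
On Scala 2.13 and later, Stream is deprecated in favor of LazyList, so the same idea might look like the sketch below. This is an adaptation, not code from the original answer; the name msortLazy is made up here, and it assumes a 2.13+ standard library (for LazyList, #:: and to(LazyList)).

```scala
// Sketch for Scala 2.13+: same structure as the Stream version, using LazyList.
def msortLazy[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
  // #:: takes its right-hand side lazily, so each call to merge
  // does one step of work and returns.
  def merge(left: List[T], right: List[T]): LazyList[T] = (left, right) match {
    case (x :: xs1, y :: _) if less(x, y) => x #:: merge(xs1, right)
    case (_ :: _, y :: ys1)               => y #:: merge(left, ys1)
    case _ => if (left.isEmpty) right.to(LazyList) else left.to(LazyList)
  }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    // toList walks the lazy list and materialises it as a strict List.
    merge(msortLazy(less)(ys), msortLazy(less)(zs)).toList
  }
}

// A reversed list well beyond the 5000 elements mentioned in the question.
val sortedBig = msortLazy[Int](_ < _)((100000 to 1 by -1).toList)
```

The recursion in msortLazy itself only goes as deep as the number of halvings, roughly log2 of the list length, so only merge needed to change.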
