在非尾递归函数中使用Scala内部函数时是否存在效率损失?

时间:2009-07-21 14:27:35

标签: scala

我对Scala相当陌生,我仍然在努力开发一种方法,这种方法有效,哪些方法可能包含隐藏的性能成本。

如果我定义一个包含内部函数的(非尾部)递归函数。是否为每个递归调用实例化了内部函数的功能对象的多个副本?

例如以下内容:

def sumDoubles(n: Int): Int = {
    def dbl(a: Int) = 2 * a;
    if(n > 0)
        dbl(n) + sumDoubles(n - 1)
    else
        0               
}

...对于dbl的调用,堆栈中是否存在sumDoubles(15)个对象的副本数量?

3 个答案:

答案 0 :(得分:24)

在字节码级别

def sumDoubles(n: Int): Int = {
  def dbl(a: Int) = 2 * a;
  if(n > 0)
    dbl(n) + sumDoubles(n - 1)
  else
    0               
}

完全相同
private[this] def dbl(a: Int) = 2 * a;

def sumDoubles(n: Int): Int = {
  if(n > 0)
    dbl(n) + sumDoubles(n - 1)
  else
    0               
}

但是不要相信我的话

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial   #10; //Method java/lang/Object."":()V
   4:   return

private final int dbl$1(int);
  Code:
   0:   iconst_2
   1:   iload_1
   2:   imul
   3:   ireturn

public int sumDoubles(int);
  Code:
   0:   iload_1
   1:   iconst_0
   2:   if_icmple   21
   5:   aload_0
   6:   iload_1
   7:   invokespecial   #22; //Method dbl$1:(I)I
   10:  aload_0
   11:  iload_1
   12:  iconst_1
   13:  isub
   14:  invokevirtual   #24; //Method sumDoubles:(I)I
   17:  iadd
   18:  goto    22
   21:  iconst_0
   22:  ireturn

}

如果内部函数捕获不可变变量,那么就有翻译。这段代码

def foo(n: Int): Int = {
  def dbl(a: Int) = a * n;
    if(n > 0)
      dbl(n) + foo(n - 1)
    else
      0               
}

获取翻译成

private[this] def dbl(a: Int, n: Int) = a * n;

def foo(n: Int): Int = {
  if(n > 0)
    dbl(n, n) + foo(n - 1)
  else
    0               
}

同样,这些工具适合你

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial   #10; //Method java/lang/Object."":()V
   4:   return

private final int dbl$1(int, int);
  Code:
   0:   iload_1
   1:   iload_2
   2:   imul
   3:   ireturn

public int foo(int);
  Code:
   0:   iload_1
   1:   iconst_0
   2:   if_icmple   22
   5:   aload_0
   6:   iload_1
   7:   iload_1
   8:   invokespecial   #23; //Method dbl$1:(II)I
   11:  aload_0
   12:  iload_1
   13:  iconst_1
   14:  isub
   15:  invokevirtual   #25; //Method foo:(I)I
   18:  iadd
   19:  goto    23
   22:  iconst_0
   23:  ireturn

}

如果捕获了可变变量,那么它必须装箱,这可能更贵。

def bar(_n : Int) : Int = {
   var n = _n
   def subtract() = n = n - 1

   if (n > 0) {
      subtract
      n
   }
   else
      0
}

转换为类似

的内容
private[this] def subtract(n : IntRef]) = n.value = n.value - 1

def bar(_n : Int) : Int = {
   var n = _n
   if (n > 0) {
      val nRef = IntRef(n)
      subtract(nRef)
      n = nRef.get()
      n
   }
   else
      0
}
~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial   #10; //Method java/lang/Object."":()V
   4:   return

private final void subtract$1(scala.runtime.IntRef);
  Code:
   0:   aload_1
   1:   aload_1
   2:   getfield    #18; //Field scala/runtime/IntRef.elem:I
   5:   iconst_1
   6:   isub
   7:   putfield    #18; //Field scala/runtime/IntRef.elem:I
   10:  return

public int bar(int);
  Code:
   0:   new #14; //class scala/runtime/IntRef
   3:   dup
   4:   iload_1
   5:   invokespecial   #23; //Method scala/runtime/IntRef."":(I)V
   8:   astore_2
   9:   aload_2
   10:  getfield    #18; //Field scala/runtime/IntRef.elem:I
   13:  iconst_0
   14:  if_icmple   29
   17:  aload_0
   18:  aload_2
   19:  invokespecial   #27; //Method subtract$1:(Lscala/runtime/IntRef;)V
   22:  aload_2
   23:  getfield    #18; //Field scala/runtime/IntRef.elem:I
   26:  goto    30
   29:  iconst_0
   30:  ireturn

}

编辑:添加第一类函数

要获得对象分配,您需要以更一流的方式使用函数

def sumWithFunction(n : Int, f : Int => Int) : Int = {
  if(n > 0)
    f(n) + sumWithFunction(n - 1, f)
  else
    0               
}  

def sumDoubles(n: Int) : Int = {
  def dbl(a: Int) = 2 * a
  sumWithFunction(n, dbl)
}

这有点像

def sumWithFunction(n : Int, f : Int => Int) : Int = {
  if(n > 0)
    f(n) + sumWithFunction(n - 1, f)
  else
    0               
}  

private[this] def dbl(a: Int) = 2 * a

def sumDoubles(n: Int) : Int = {
  sumWithFunction(n, new Function0[Int,Int] {
    def apply(x : Int) = dbl(x)
  })
}

这是字节码

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial   #10; //Method java/lang/Object."":()V
   4:   return

public final int dbl$1(int);
  Code:
   0:   iconst_2
   1:   iload_1
   2:   imul
   3:   ireturn

public int sumDoubles(int);
  Code:
   0:   aload_0
   1:   iload_1
   2:   new #20; //class Foo$$anonfun$sumDoubles$1
   5:   dup
   6:   aload_0
   7:   invokespecial   #23; //Method Foo$$anonfun$sumDoubles$1."":(LFoo;)V
   10:  invokevirtual   #29; //Method sumWithFunction:(ILscala/Function1;)I
   13:  ireturn

public int sumWithFunction(int, scala.Function1);
  Code:
   0:   iload_1
   1:   iconst_0
   2:   if_icmple   30
   5:   aload_2
   6:   iload_1
   7:   invokestatic    #36; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
   10:  invokeinterface #42,  2; //InterfaceMethod scala/Function1.apply:(Ljava/lang/Object;)Ljava/lang/Object;
   15:  invokestatic    #46; //Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
   18:  aload_0
   19:  iload_1
   20:  iconst_1
   21:  isub
   22:  aload_2
   23:  invokevirtual   #29; //Method sumWithFunction:(ILscala/Function1;)I
   26:  iadd
   27:  goto    31
   30:  iconst_0
   31:  ireturn

}

~/test$ javap -private -c "Foo\$\$anonfun\$sumDoubles\$1"
Compiled from "test.scala"
public final class Foo$$anonfun$sumDoubles$1 extends java.lang.Object implements scala.Function1,scala.ScalaObject,java.io.Serializable{
private final Foo $outer;

public Foo$$anonfun$sumDoubles$1(Foo);
  Code:
   0:   aload_1
   1:   ifnonnull   12
   4:   new #10; //class java/lang/NullPointerException
   7:   dup
   8:   invokespecial   #13; //Method java/lang/NullPointerException."":()V
   11:  athrow
   12:  aload_0
   13:  aload_1
   14:  putfield    #17; //Field $outer:LFoo;
   17:  aload_0
   18:  invokespecial   #20; //Method java/lang/Object."":()V
   21:  aload_0
   22:  invokestatic    #26; //Method scala/Function1$class.$init$:(Lscala/Function1;)V
   25:  return

public final java.lang.Object apply(java.lang.Object);
  Code:
   0:   aload_0
   1:   getfield    #17; //Field $outer:LFoo;
   4:   astore_2
   5:   aload_0
   6:   aload_1
   7:   invokestatic    #37; //Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
   10:  invokevirtual   #40; //Method apply:(I)I
   13:  invokestatic    #44; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
   16:  areturn

public final int apply(int);
  Code:
   0:   aload_0
   1:   getfield    #17; //Field $outer:LFoo;
   4:   astore_2
   5:   aload_0
   6:   getfield    #17; //Field $outer:LFoo;
   9:   iload_1
   10:  invokevirtual   #51; //Method Foo.dbl$1:(I)I
   13:  ireturn

public scala.Function1 andThen(scala.Function1);
  Code:
   0:   aload_0
   1:   aload_1
   2:   invokestatic    #56; //Method scala/Function1$class.andThen:(Lscala/Function1;Lscala/Function1;)Lscala/Function1;
   5:   areturn

public scala.Function1 compose(scala.Function1);
  Code:
   0:   aload_0
   1:   aload_1
   2:   invokestatic    #60; //Method scala/Function1$class.compose:(Lscala/Function1;Lscala/Function1;)Lscala/Function1;
   5:   areturn

public java.lang.String toString();
  Code:
   0:   aload_0
   1:   invokestatic    #65; //Method scala/Function1$class.toString:(Lscala/Function1;)Ljava/lang/String;
   4:   areturn

}

匿名类从Function1特征中获取了大量代码。这确实在类加载开销方面有成本,但不影响分配对象或执行代码的成本。另一个成本是整数的装箱和拆箱。希望这个成本将在2.8的@specialized注释中消失。

答案 1 :(得分:8)

如果您担心Scala性能,最好熟悉1)Java字节码的执行方式,以及2)Scala如何转换为Java字节码。如果您对查看原始字节码或对其进行反编译感到满意,我建议您在可能关注性能的区域进行。您很快就会感觉到Scala如何转换为字节码。如果没有,您可以使用scalac -print标志,该标志打印Scala代码的“完全脱落”版本。它基本上是一个尽可能接近Java的代码版本,就在它变成字节码之前。

我使用您发布的代码创建了一个Performance.scala文件:

jorge-ortizs-macbook-pro:sandbox jeortiz$ cat Performance.scala 
object Performance {
  def sumDoubles(n: Int): Int = {
      def dbl(a: Int) = 2 * a;
      if(n > 0)
          dbl(n) + sumDoubles(n - 1)
      else
          0               
  }
}

当我在其上运行scalac -print时,我可以看到脱落的Scala:

jorge-ortizs-macbook-pro:sandbox jeortiz$ scalac Performance.scala -print
[[syntax trees at end of cleanup]]// Scala source: Performance.scala
package <empty> {
  final class Performance extends java.lang.Object with ScalaObject {
    @remote def $tag(): Int = scala.ScalaObject$class.$tag(Performance.this);
    def sumDoubles(n: Int): Int = {
      if (n.>(0))
        Performance.this.dbl$1(n).+(Performance.this.sumDoubles(n.-(1)))
      else
        0
    };
    final private[this] def dbl$1(a: Int): Int = 2.*(a);
    def this(): object Performance = {
      Performance.super.this();
      ()
    }
  }
}

然后您会注意到dbl已被“解除”为final private[this]属于同一对象的sumDoubles方法。 dblsumDoubles实际上都是包含对象的方法,而不是函数。以非尾递归方式调用它们可能会增加堆栈,但它不会实例化堆上的对象。

答案 2 :(得分:0)

在这种特殊情况下,编译器可能会对此进行优化,但请考虑以下(伪代码)。

def func(n) = {
    def nTimes(a) = n * a
    if (n <= 1)
        1
    else
        nTimes(func(n - 1))
}

在大多数情况下,内部函数需要访问其外部函数的变量,因此必须在每次调用中重新实例化。