Java pseudo tail-call recursion yields better performance

Time: 2016-02-16 03:50:24

Tags: java performance data-structures jvm

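I just wrote a simple utility to compute the length of a linkedList, for a linkedList that doesn't maintain an "internal" counter of its size/length. With that in mind, I had three simple approaches: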

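  1. Iterate over the linkedList until you hit the "end"
  2. Compute the length recursively
  3. Compute the length recursively without handing control back to the calling function (using a form of tail-call recursion)

Here is some code that captures these three cases: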

    // 1. iterative approach
    public static <T> int getLengthIteratively(LinkedList<T> ll) {
    
        int length = 0;
        for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
            length++;
        }
    
        return length;
    }
    
    // 2. recursive approach
    public static <T> int getLengthRecursively(LinkedList<T> ll) {
        return getLengthRecursively(ll.getHead());
    }
    
    private static <T> int getLengthRecursively(Node<T> ptr) {
    
        if (ptr == null) {
            return 0;
        } else {
            return 1 + getLengthRecursively(ptr.getNext());
        }
    }
    
    // 3. Pseudo tail-recursive approach
    public static <T> int getLengthWithFakeTailRecursion(LinkedList<T> ll) {
        return getLengthWithFakeTailRecursion(ll.getHead());
    }
    
    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr) {
        return getLengthWithFakeTailRecursion(ptr, 0);
    }
    
    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr, int result) {
        if (ptr == null) {
            return result;
        } else {
            return getLengthWithFakeTailRecursion(ptr.getNext(), result + 1);
        }
    }
    

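Now, I know the JVM doesn't support tail recursion out of the box. However, I ran some simple tests on linked lists of strings with ~10k nodes, and getLengthWithFakeTailRecursion consistently outperformed getLengthRecursively (by about 40%). Can the difference be attributed solely to the fact that in case #2 control is passed back for every node, so we are forced to go back through all the stack frames?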
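To make the "no tail recursion out of the box" point concrete, consider what a compiler that eliminates tail calls would do with approach #3. It would, in effect, turn the self-call into a reassignment of the parameters followed by a jump back to the top, which is a plain loop. javac and HotSpot don't perform this rewrite, so the "fake tail" version still pushes one stack frame per node, just like approach #2. The sketch below assumes a bare-bones `Node` class, which is hypothetical because the question doesn't show one. It compares the recursive and hand-converted versions on a short list and on a 1,000,000-node list.

```java
public class TailCallDemo {

    // Hypothetical minimal node type; the question's Node<T> implementation is not shown.
    static final class Node {
        final Node next;
        Node(Node next) { this.next = next; }
    }

    // Same shape as getLengthWithFakeTailRecursion: the self-call is in tail position,
    // yet HotSpot still pushes a new stack frame for every node.
    static int lengthTailRecursive(Node ptr, int result) {
        if (ptr == null) {
            return result;
        }
        return lengthTailRecursive(ptr.next, result + 1);
    }

    // Manual tail-call elimination: instead of calling itself, rebind the
    // parameters (ptr, result) to the arguments of the tail call and loop.
    static int lengthTailCallAsLoop(Node ptr, int result) {
        while (ptr != null) {
            result = result + 1;
            ptr = ptr.next;
        }
        return result;
    }

    // Builds an n-node list by prepending, so no recursion is involved.
    static Node build(int n) {
        Node head = null;
        for (int i = 0; i < n; i++) {
            head = new Node(head);
        }
        return head;
    }

    // Reports whether the recursive version runs out of stack on this list.
    static boolean recursiveOverflows(Node head) {
        try {
            lengthTailRecursive(head, 0);
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        Node small = build(5_000);
        System.out.println(lengthTailRecursive(small, 0));   // 5000
        System.out.println(lengthTailCallAsLoop(small, 0));  // 5000

        Node big = build(1_000_000);
        System.out.println(lengthTailCallAsLoop(big, 0));    // 1000000
        System.out.println(recursiveOverflows(big));         // true with the default -Xss
    }
}
```

The loop form has constant stack usage. The recursive form's depth grows with the list, which is also why the 100000-node setting in the test below overflows the stack.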

Edit: here is the simple test I used to check the performance numbers:

    import org.junit.Assert;
    import org.junit.Test;

    public class LengthCheckerTest {

        @Test
        public void testLengthChecking() {

            LinkedList<String> ll = new LinkedList<String>();
            int sizeOfList = 12000;
            // int sizeOfList = 100000; // Danger: This causes a stackOverflow in recursive methods!
            for (int i = 1; i <= sizeOfList; i++) {
                ll.addNode(String.valueOf(i));
            }

            // A System.nanoTime() difference divided by 1000 is in microseconds, hence "us" below.
            long currTime = System.nanoTime();
            Assert.assertEquals(sizeOfList, LengthChecker.getLengthIteratively(ll));
            long totalTime = System.nanoTime() - currTime;
            System.out.println("totalTime taken with iterative approach: " + (totalTime / 1000) + "us");

            currTime = System.nanoTime();
            Assert.assertEquals(sizeOfList, LengthChecker.getLengthRecursively(ll));
            totalTime = System.nanoTime() - currTime;
            System.out.println("totalTime taken with recursive approach: " + (totalTime / 1000) + "us");

            // Interestingly, the fakeTailRecursion always runs faster than the vanillaRecursion
            // TODO: Look into whether stack-frame collapsing has anything to do with this
            currTime = System.nanoTime();
            Assert.assertEquals(sizeOfList, LengthChecker.getLengthWithFakeTailRecursion(ll));
            totalTime = System.nanoTime() - currTime;
            System.out.println("totalTime taken with fake TCR approach: " + (totalTime / 1000) + "us");
        }
    }
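The question doesn't include its `Node`, `LinkedList`, or `LengthChecker` classes. For readers who want to run the test, here is a minimal sketch of what it appears to assume: `getHead()`, `getNext()`, and an `addNode(T)` that appends. The tail reference and the `getValue()` accessor are assumptions of this sketch, not the asker's code. The list still keeps no size counter, in line with the premise of the question.

```java
public class SizelessListDemo {

    // Hypothetical node type matching the getNext() call used in the question.
    static final class Node<T> {
        private final T value;
        private Node<T> next;

        Node(T value) { this.value = value; }

        T getValue() { return value; }
        Node<T> getNext() { return next; }
    }

    // Hypothetical list: keeps head and tail references but deliberately no size counter.
    static final class LinkedList<T> {
        private Node<T> head;
        private Node<T> tail;

        Node<T> getHead() { return head; }

        // Appends at the end in O(1) thanks to the (assumed) tail reference.
        void addNode(T value) {
            Node<T> node = new Node<>(value);
            if (head == null) {
                head = node;
            } else {
                tail.next = node;
            }
            tail = node;
        }
    }

    // Iterative length, same logic as approach #1 in the question.
    static <T> int getLengthIteratively(LinkedList<T> ll) {
        int length = 0;
        for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
            length++;
        }
        return length;
    }

    public static void main(String[] args) {
        LinkedList<String> ll = new LinkedList<>();
        for (int i = 1; i <= 12000; i++) {
            ll.addNode(String.valueOf(i));
        }
        System.out.println(getLengthIteratively(ll)); // 12000
        System.out.println(ll.getHead().getValue());  // 1
    }
}
```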
    

1 Answer:

Answer 0 (score: 3)

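Your benchmarking methodology is flawed. You run all three tests in the same JVM, so they are not on an equal footing. By the time the fake-tail test executes, the code of the LinkedList and Node classes has already been JIT-compiled, so it runs faster. If you change the order of the tests, you will see different numbers. Each test should be executed in a separate JVM.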
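Before reaching for JMH, some readers may want a quick in-process check. Here is a minimal sketch of the usual mitigations:

- warm each variant up before timing it;
- accumulate every result in a sink so the calls can't be discarded as dead code;
- rotate the order of the measured variants.

Everything here (the class name, `nanosPerCall`, the warm-up counts) is a hypothetical illustration. It still runs in a single JVM, so it does not replace the separate-JVM forks recommended above. The printed numbers will also vary between machines and runs.

```java
import java.util.List;
import java.util.function.ToIntFunction;

public class CrudeHarness {

    // Hypothetical minimal node type standing in for the question's Node<T>.
    static final class Node {
        final Node next;
        Node(Node next) { this.next = next; }
    }

    // The three variants from the question, reduced to this Node type.
    static int iterative(Node head) {
        int length = 0;
        for (Node p = head; p != null; p = p.next) {
            length++;
        }
        return length;
    }

    static int recursive(Node p) {
        return p == null ? 0 : 1 + recursive(p.next);
    }

    static int fakeTail(Node p, int acc) {
        return p == null ? acc : fakeTail(p.next, acc + 1);
    }

    static Node build(int n) {
        Node head = null;
        for (int i = 0; i < n; i++) {
            head = new Node(head);
        }
        return head;
    }

    // Every result is added here so the JIT cannot drop the measured calls as dead code.
    static long sink;

    // Runs f `warmup` times untimed, then returns the mean nanoseconds per call over `measured` calls.
    static double nanosPerCall(ToIntFunction<Node> f, Node head, int warmup, int measured) {
        for (int i = 0; i < warmup; i++) {
            sink += f.applyAsInt(head);
        }
        long start = System.nanoTime();
        for (int i = 0; i < measured; i++) {
            sink += f.applyAsInt(head);
        }
        return (System.nanoTime() - start) / (double) measured;
    }

    public static void main(String[] args) {
        Node head = build(5_000);
        List<String> names = List.of("iterative", "recursive", "fakeTail");
        List<ToIntFunction<Node>> variants = List.of(
                CrudeHarness::iterative,
                CrudeHarness::recursive,
                h -> fakeTail(h, 0));

        // Rotate the order so that no variant always runs last, after the others have warmed shared code.
        for (int rotation = 0; rotation < variants.size(); rotation++) {
            for (int k = 0; k < variants.size(); k++) {
                int idx = (rotation + k) % variants.size();
                double us = nanosPerCall(variants.get(idx), head, 2_000, 2_000) / 1_000.0;
                System.out.printf("rotation %d, %-9s %.2f us/call%n", rotation, names.get(idx), us);
            }
        }
        System.out.println("sink = " + sink);
    }
}
```

If a variant's timing shifts noticeably depending on its position in the rotation, that points to the warm-up and ordering effect rather than to anything tail-call related.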

Let's write a simple JMH microbenchmark for your case:

import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.annotations.*;

// 5 warm-up iterations, 500 ms each, then 10 measurement iterations 500 ms each
// repeat everything three times (with JVM restart)
// output average time in microseconds
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Fork(3)
@State(Scope.Benchmark)
public class ListTest {
    // You did not supply the Node and LinkedList implementations,
    // but I assume they look like this
    static class Node<T> {
        final T value;
        Node<T> next;

        public Node(T val) {value = val;}
        public void add(Node<T> n) {next = n;}

        public Node<T> getNext() {return next;}
    }

    static class LinkedList<T> {
        Node<T> head;

        public void setHead(Node<T> h) {head = h;}
        public Node<T> getHead() {return head;}
    }

    // Code from your question follows

    // 1. iterative approach
    public static <T> int getLengthIteratively(LinkedList<T> ll) {

        int length = 0;
        for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
            length++;
        }

        return length;
    }

    // 2. recursive approach
    public static <T> int getLengthRecursively(LinkedList<T> ll) {
        return getLengthRecursively(ll.getHead());
    }

    private static <T> int getLengthRecursively(Node<T> ptr) {

        if (ptr == null) {
            return 0;
        } else {
            return 1 + getLengthRecursively(ptr.getNext());
        }
    }

    // 3. Pseudo tail-recursive approach
    public static <T> int getLengthWithFakeTailRecursion(LinkedList<T> ll) {
        return getLengthWithFakeTailRecursion(ll.getHead());
    }

    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr) {
        return getLengthWithFakeTailRecursion(ptr, 0);
    }

    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr, int result) {
        if (ptr == null) {
            return result;
        } else {
            return getLengthWithFakeTailRecursion(ptr.getNext(), result + 1);
        }
    }

    // Benchmarking code

    // Measure for different list lengths
    @Param({"10", "100", "1000", "10000"})
    int n;
    LinkedList<Integer> list;

    @Setup    
    public void setup() {
        list = new LinkedList<>();
        Node<Integer> cur = new Node<>(0);
        list.setHead(cur);
        for(int i=1; i<n; i++) {
            Node<Integer> next = new Node<>(i);
            cur.add(next);
            cur = next;
        }
    }

    // Do not forget to return result to the caller, so it's not optimized out
    @Benchmark    
    public int testIteratively() {
        return getLengthIteratively(list);
    }

    @Benchmark    
    public int testRecursively() {
        return getLengthRecursively(list);
    }

    @Benchmark    
    public int testRecursivelyFakeTail() {
        return getLengthWithFakeTailRecursion(list);
    }
}

Here are the results on my machine (x64 Win7, Java 8u71):

Benchmark                           (n)  Mode  Cnt   Score    Error  Units
ListTest.testIteratively             10  avgt   30   0.009 ±  0.001  us/op
ListTest.testIteratively            100  avgt   30   0.156 ±  0.001  us/op
ListTest.testIteratively           1000  avgt   30   2.248 ±  0.036  us/op
ListTest.testIteratively          10000  avgt   30  26.416 ±  0.590  us/op
ListTest.testRecursively             10  avgt   30   0.014 ±  0.001  us/op
ListTest.testRecursively            100  avgt   30   0.191 ±  0.003  us/op
ListTest.testRecursively           1000  avgt   30   3.599 ±  0.031  us/op
ListTest.testRecursively          10000  avgt   30  40.071 ±  0.328  us/op
ListTest.testRecursivelyFakeTail     10  avgt   30   0.015 ±  0.001  us/op
ListTest.testRecursivelyFakeTail    100  avgt   30   0.190 ±  0.002  us/op
ListTest.testRecursivelyFakeTail   1000  avgt   30   3.609 ±  0.044  us/op
ListTest.testRecursivelyFakeTail  10000  avgt   30  41.534 ±  1.186  us/op

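As you can see, the fake-tail version is as fast as the plain recursive version (within the error margin), and 20-60% slower than the iterative approach. So your results are not reproduced.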
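Both claims can be checked directly against the steady-state table. The sketch below copies the Score and Error columns (us/op) into arrays. It computes the overhead of the plain recursive version relative to the iterative one. It also runs a rough overlap test of the Score ± Error intervals for the two recursive variants. Treating overlapping intervals as "same speed" is a simplification for illustration, not a formal statistical test.

```java
public class SteadyStateCheck {

    // Score and Error columns from the avgt table, in us/op, for n = 10, 100, 1000, 10000.
    static final int[] N = {10, 100, 1000, 10000};
    static final double[] ITER = {0.009, 0.156, 2.248, 26.416};
    static final double[] REC = {0.014, 0.191, 3.599, 40.071};
    static final double[] REC_ERR = {0.001, 0.003, 0.031, 0.328};
    static final double[] FAKE = {0.015, 0.190, 3.609, 41.534};
    static final double[] FAKE_ERR = {0.001, 0.002, 0.044, 1.186};

    // Relative slowdown against a baseline, e.g. 0.5 means 50% slower.
    static double slowdown(double score, double baseline) {
        return score / baseline - 1.0;
    }

    // True if [a - errA, a + errA] and [b - errB, b + errB] intersect.
    static boolean overlaps(double a, double errA, double b, double errB) {
        return Math.abs(a - b) <= errA + errB;
    }

    public static void main(String[] args) {
        for (int i = 0; i < N.length; i++) {
            System.out.printf("n=%d: recursive is %.0f%% slower than iterative; fake-tail interval overlaps recursive: %b%n",
                    N[i], 100 * slowdown(REC[i], ITER[i]),
                    overlaps(REC[i], REC_ERR[i], FAKE[i], FAKE_ERR[i]));
        }
    }
}
```

The overheads come out between roughly 22% and 60%, and the intervals overlap for every n. For n = 10 the scores are only a few nanoseconds, so those percentages are sensitive to rounding in the table.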

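If you actually want single-shot results (without warm-up) rather than steady-state measurements, you can launch the same benchmark with the following options: -ss -wi 0 -i 1 -f 10. The results look like this: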

Benchmark                           (n)  Mode  Cnt    Score     Error  Units
ListTest.testIteratively             10    ss   10   16.095 ±   0.831  us/op
ListTest.testIteratively            100    ss   10   19.780 ±   6.440  us/op
ListTest.testIteratively           1000    ss   10   74.316 ±  26.434  us/op
ListTest.testIteratively          10000    ss   10  366.496 ±  42.299  us/op
ListTest.testRecursively             10    ss   10   19.594 ±   7.084  us/op
ListTest.testRecursively            100    ss   10   21.973 ±   0.701  us/op
ListTest.testRecursively           1000    ss   10  165.007 ±  54.915  us/op
ListTest.testRecursively          10000    ss   10  563.739 ±  74.908  us/op
ListTest.testRecursivelyFakeTail     10    ss   10   19.454 ±   4.523  us/op
ListTest.testRecursivelyFakeTail    100    ss   10   25.518 ±  11.802  us/op
ListTest.testRecursivelyFakeTail   1000    ss   10  158.336 ±  43.646  us/op
ListTest.testRecursivelyFakeTail  10000    ss   10  755.384 ± 232.940  us/op

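As you can see, the first invocation is many times slower than subsequent ones, and your results are still not reproduced. I did notice that testRecursivelyFakeTail is actually slower for n = 10000 (although after warm-up it reaches the same peak speed as testRecursively).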
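"Many times slower" can be quantified by dividing the single-shot (cold) scores by the steady-state (warm) scores for the same n. A minimal sketch for n = 10000, with both values copied from the two tables in us/op. The single-shot error bars are wide, so treat the ratios as rough.

```java
public class ColdVsWarm {

    // Ratio of a cold, single-shot score to the warmed-up average-time score.
    static double coldToWarm(double singleShotUs, double steadyStateUs) {
        return singleShotUs / steadyStateUs;
    }

    public static void main(String[] args) {
        // n = 10000 rows: single-shot (ss) score vs steady-state (avgt) score, both in us/op.
        System.out.printf("iterative: %.1fx%n", coldToWarm(366.496, 26.416)); // iterative: 13.9x
        System.out.printf("recursive: %.1fx%n", coldToWarm(563.739, 40.071)); // recursive: 14.1x
        System.out.printf("fake tail: %.1fx%n", coldToWarm(755.384, 41.534)); // fake tail: 18.2x
    }
}
```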