如何使用C#异步/等待作为独立的CPS转换

时间:2019-05-16 14:55:27

标签: c# async-await stack-overflow combinators continuation-passing

注1:这里的CPS代表"continuation passing style"

我对了解如何挂接到C#异步机制非常感兴趣。 基本上,据我了解的C#异步/等待功能,编译器将执行CPS转换,然后将转换后的代码传递给上下文对象,该对象管理各个线程上的任务调度。

您是否认为可以利用该编译器功能来创建 功能强大的组合器,而保留默认的线程方面?

一个例子可能是可以递归化和记忆

之类的方法的东西
async MyTask<BigInteger> Fib(int n)     // hypothetical example
{
    if (n <= 1) return n;
    return await Fib(n-1) + await Fib(n-2);
}

我设法做到了:

void Fib(int n, Action<BigInteger> Ret, Action<int, Action<BigInteger>> Rec)
{
    if (n <= 1) Ret(n);
    else Rec(n-1, x => Rec(n-2, y => Ret(x + y)));
}

(不使用异步,非常笨拙...)

或使用monadWhile<X> = Either<X, While<X>>

While<X> Fib(int n) => n <= 1 ?
    While.Return((BigInteger) n) :
    from x in Fib(n-1)
    from y in Fib(n-2)
    select x + y;

好一点,但看起来不像异步语法:)


我在the blog of E. Lippert上问了这个问题,他很友好,让我知道这确实是可能的。


实现ZBDD库时对我的需求:(一种特殊的DAG)

  • 很多复杂的相互递归操作

  • 堆栈在实际示例中不断溢出

  • 仅在完全记忆后才适用

手动CPS和去递归非常繁琐且容易出错。


对我进行的酸性测试(堆栈安全)类似于:

async MyTask<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return b;
    if (n == 1) return a;
    return await Fib(n - 1, a + b, a);
}

使用默认行为在Fib(10000, 1, 0)上产生堆栈溢出。甚至更好的是,使用开头带有备注的代码来计算Fib(10000)


CPS简而言之:

通常,(忽略例外)任何形式的程序:

A Foo(B b)
{
    ...
    ...
    return x;
    ...
    return y;
}

可以按如下所示的等价连续传递样式进行转换,

void AsyncFoo(B b, Action<A> cont)
{
    ...
    ...
    cont(x);    // a tail call;
    return;
    ...
    cont(y);    // another tail call;
    return;
}

实际上,保证方法cont和方法AsyncFoo都不会返回,因此它们的返回类型应该是底部类型。在这种情况下,运行时在调用a时丢弃堆栈的内容是正确的 返回底部的函数(C#现在没有此功能)。

4 个答案:

答案 0 :(得分:1)

这是我的解决方案版本。它是堆栈安全的,并且不使用线程池,但是有特定的限制。特别是,它需要方法的尾部递归样式,因此Fib(n-1) + Fib(n-2)之类的构造将不起作用。另一方面,实际上以迭代方式执行的尾部递归性质不需要记忆,因为每个迭代都被调用一次。它没有边缘保护,但它是原型而不是最终解决方案:

public class RecursiveTask<T>
{
    private T _result;

    private Func<RecursiveTask<T>> _function;

    public T Result
    {
        get
        {
            var current = this;
            var last = current;

            do
            {
                last = current;
                current = current._function?.Invoke();
            } while (current != null);

            return last._result;
        }
    }

    private RecursiveTask(Func<RecursiveTask<T>> function)
    {
        _function = function;
    }

    private RecursiveTask(T result)
    {
        _result = result;
    }

    public static implicit operator RecursiveTask<T>(T result)
    {
        return new RecursiveTask<T>(result);
    }

    public static RecursiveTask<T> FromFunc(Func<RecursiveTask<T>> func) => new RecursiveTask<T>(func);
}

以及用法:

class Program
{
    static RecursiveTask<int> Fib(int n, int a, int b)
    {
        if (n == 0) return a;
        if (n == 1) return b;

        return RecursiveTask<int>.FromFunc(() => Fib(n - 1, b, a + b));
    }

    static RecursiveTask<int> Factorial(int n, int a)
    {
        if (n == 0) return a;

        return RecursiveTask<int>.FromFunc(() => Factorial(n - 1, n * a));
    }


    static void Main(string[] args)
    {
        Console.WriteLine(Factorial(5, 1).Result);
        Console.WriteLine(Fib(100000, 0, 1).Result);
    }
}

请注意,重要的是返回一个包装循环调用的函数,而不是调用本身,以避免真正的递归。

更新 下面是另一个仍然不使用CPS转换但允许使用接近代数递归的语义的实现,即它支持一个函数内的多个类似于递归的调用,并且不需要函数为尾递归。

public class RecursiveTask<T1, T2>
{
    private readonly Func<RecursiveTask<T1, T2>, T1, T2> _func;
    private readonly Dictionary<T1, RecursiveTask<T1, T2>> _allTasks;
    private readonly List<RecursiveTask<T1, T2>> _subTasks;
    private readonly RecursiveTask<T1, T2> _rootTask;
    private T1 _arg;
    private T2 _result;
    private int _runsCount;
    private bool _isCompleted;
    private bool _isEvaluating;

    private RecursiveTask(Func<RecursiveTask<T1, T2>, T1, T2> func)
    {
        _func = func;
        _allTasks = new Dictionary<T1, RecursiveTask<T1, T2>>();
        _subTasks = new List<RecursiveTask<T1, T2>>();
        _rootTask = this;
    }

    private RecursiveTask(Func<RecursiveTask<T1, T2>, T1, T2> func, T1 arg, RecursiveTask<T1, T2> rootTask) : this(func)
    {
        _arg = arg;
        _rootTask = rootTask;
    }

    public T2 Run(T1 arg)
    {
        if (!_isEvaluating)
            BuildTasks(arg);

        if (_isEvaluating)
            return EvaluateTasks(arg);

        return default;
    }

    public static RecursiveTask<T1, T2> Create(Func<RecursiveTask<T1, T2>, T1, T2> func)
    {
        return new RecursiveTask<T1, T2>(func);
    }

    private void AddSubTask(T1 arg)
    {
        if (!_allTasks.TryGetValue(arg, out RecursiveTask<T1, T2> subTask))
        {
            subTask = new RecursiveTask<T1, T2>(_func, arg, this);
            _allTasks.Add(arg, subTask);
            _subTasks.Add(subTask);
        }
    }

    private T2 Run()
    {
        if (!_isCompleted)
        {
            var runsCount = _rootTask._runsCount;
            _result = _func(_rootTask, _arg);
            _isCompleted = runsCount == _rootTask._runsCount;
        }
        return _result;
    }

    private void BuildTasks(T1 arg)
    {
        if (_runsCount++ == 0)
            _arg = arg;

        if (EqualityComparer<T1>.Default.Equals(_arg, arg))
        {
            Run();

            var processed = 0;
            var addedTasksCount = _subTasks.Count;
            while (processed < addedTasksCount)
            {
                for (var i = processed; i < addedTasksCount; i++, processed++)
                    _subTasks[i].Run();
                addedTasksCount = _subTasks.Count;
            }
            _isEvaluating = true;
        }
        else
            AddSubTask(arg);
    }

    private T2 EvaluateTasks(T1 arg)
    {
        if (EqualityComparer<T1>.Default.Equals(_arg, arg))
        {
            foreach (var task in Enumerable.Reverse(_subTasks))
                task.Run();

            return Run();
        }
        else
        {
            if (_allTasks.TryGetValue(arg, out RecursiveTask<T1, T2> task))
                return task._isCompleted ? task._result : task.Run();
            else
                return default;
        }
    }
}

用法:

class Program
{
    static int Fib(int num)
    {
        return RecursiveTask<int, int>.Create((t, n) =>
        {
            if (n == 0) return 0;
            if (n == 1) return 1;

            return t.Run(n - 1) + t.Run(n - 2);
        }).Run(num);
    }

    static void Main(string[] args)
    {
        Console.WriteLine(Fib(7));
        Console.WriteLine(Fib(100000));
    }
}

作为好处,它具有堆栈安全性,不使用线程池,不负担async await基础结构,使用备忘录并且允许使用或多或少的可读性。当前的实现意味着仅对带有单个参数的函数使用。为了使其适用于更广泛的功能,应为不同的一组通用参数提供类似的实现:

RecursiveTask<T1, T2, T3>
RecursiveTask<T1, T2, T3, T4>
...

答案 1 :(得分:0)

  

对我进行的酸性测试(堆栈安全)类似于:

async MyTask<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return b;
    if (n == 1) return a;
    return await Fib(n - 1, a + b, a);
}

不是那么简单

public static Task<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return Task.FromResult(b);
    if (n == 1) return Task.FromResult(a);

    return Task.Run(() => Fib(n - 1, a + b, a));
}


或者,不使用线程池,

public static async Task<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return b;
    if (n == 1) return a;

    return await Task.FromResult(a + b).ContinueWith(t => Fib(n - 1, t.Result, a), TaskScheduler.FromCurrentSynchronizationContext()).Unwrap();
}

,除非我严重误解了某些东西。

答案 2 :(得分:0)

如果不查看您的MyTask<T>并查看该异常的堆栈跟踪,就不可能知道发生了什么。

看起来您正在寻找的是Generalized async return types

您可以浏览the source来了解ValueTaskValueTask<T>的操作。

答案 3 :(得分:0)

以下是我所追求但仍未完全令人满意的解决方案。 它基于GSerg提出的堆栈安全解决方案的见解,并添加了备注。

Pro 算法的核心(FibAux方法使用的是干净的async / await语法)。

缺点:它仍在使用线程池执行。

    // Core algorithm using the cute async/await syntax
    // (n.b. this would be exponential without memoization.)
    private static async Task<BigInteger> FibAux(int n)
    {
        if (n <= 1) return n;
        return await Rec(n - 1) + await Rec(n - 2);
    }

    public static Func<int, Task<BigInteger>> Rec { get; }
        = Utils.StackSafeMemoize<int, BigInteger>(FibAux);

    public static BigInteger Fib(int n)
        => FibAux(n).Result;

    [Test]
    public void Test()
    {
        Console.WriteLine(Fib(100000));
    }

    public static class Utils
    {
        // the combinator (still using the thread pool for execution)
        public static Func<X, Task<Y>> StackSafeMemoize<X, Y>(Func<X, Task<Y>> func)
        {
            var memo = new Dictionary<X, Y>();
            return x =>
            {
                Y result;
                if (!memo.TryGetValue(x, out result))
                {
                    return Task.Run(() => func(x).ContinueWith(task =>
                    {
                        var y = task.Result;
                        memo[x] = y;
                        return y;
                    }));
                }

                return Task.FromResult(result);
            };
        }
    } 

为了进行比较,这是不使用async / await的cps版本。


    public static BigInteger Fib(int n)
    {
        var fib = Memo<int, BigInteger>((m, rec, cont) =>
        {
            if (m <= 1) cont(m);
            else rec(m - 1, x => rec(m - 2, y => cont(x + y)));
        });

        return fib(n);
    }

    [Test]
    public void Test()
    {
        Console.WriteLine(Fib(100000));
    }

    // ---------

    public static Func<X, Y> Memo<X, Y>(Action<X, Action<X, Action<Y>>, Action<Y>> func)
    {
        var memo = new Dictionary<X, Y>(); // can be a Lru cache
        var stack = new Stack<Action>();

        Action<X, Action<Y>> rec = null;
        rec = (x, cont) =>
        {
            stack.Push(() =>
            {
                Y res;
                if (memo.TryGetValue(x, out res))
                {
                    cont(res);
                }
                else
                {
                    func(x, rec, y =>
                    {
                        memo[x] = y;
                        cont(y);
                    });
                }
            });
        };

        return x =>
        {
            var res = default(Y);
            rec(x, y => res = y);
            while (stack.Count > 0)
            {
                var next = stack.Pop();
                next();
            }

            return res;
        };
    }