如何防止Task上的同步延续?

时间:2014-03-22 14:56:55

标签: c# .net task-parallel-library task async-await

我有一些库(套接字网络)代码,它基于Task为待处理的请求响应提供基于TaskCompletionSource<T>的API。然而,TPL中的一个烦恼是,似乎不可能阻止同步延续。我喜欢能够做的是:

  • 告诉TaskCompletionSource<T>不应允许来电者附加TaskContinuationOptions.ExecuteSynchronously,或
  • 设置结果(SetResult / TrySetResult),指定应忽略TaskContinuationOptions.ExecuteSynchronously,而不是使用池

具体来说,我遇到的问题是传入的数据正在由专用的阅读器处理,如果调用者可以使用TaskContinuationOptions.ExecuteSynchronously进行附加,则他们可以使阅读器停止运行(这不仅影响它们)。以前,我通过一些hackery来解决这个问题,该hackery检测是否存在任何延续,如果是,它会将完成推送到ThreadPool,但是如果调用者有,则会产生重大影响他们的工作队列已经饱和,因为完成工作不会得到及时处理。如果他们使用Task.Wait()(或类似),那么他们将基本上陷入僵局。同样,这就是读者使用专用线程而不是使用工作者的原因。

因此;在我尝试唠叨TPL团队之前:我错过了一个选项吗?

关键点:

  • 我不希望外部呼叫者能够劫持我的线程
  • 我无法使用ThreadPool作为实现,因为它需要在池饱和时才能工作

以下示例生成输出(排序可能因时间而异):

Continuation on: Main thread
Press [return]
Continuation on: Thread pool

问题在于随机调用者设法在“主线程”上获得延续。在实际代码中,这将打断主要读者;坏事!

代码:

using System;
using System.Threading;
using System.Threading.Tasks;

static class Program
{
    static void Identify()
    {
        var thread = Thread.CurrentThread;
        string name = thread.IsThreadPoolThread
            ? "Thread pool" : thread.Name;
        if (string.IsNullOrEmpty(name))
            name = "#" + thread.ManagedThreadId;
        Console.WriteLine("Continuation on: " + name);
    }
    static void Main()
    {
        Thread.CurrentThread.Name = "Main thread";
        var source = new TaskCompletionSource<int>();
        var task = source.Task;
        task.ContinueWith(delegate {
            Identify();
        });
        task.ContinueWith(delegate {
            Identify();
        }, TaskContinuationOptions.ExecuteSynchronously);
        source.TrySetResult(123);
        Console.WriteLine("Press [return]");
        Console.ReadLine();
    }
}

6 个答案:

答案 0 :(得分:48)

.NET 4.6中的新功能:

.NET 4.6包含一个新的TaskCreationOptionsRunContinuationsAsynchronously


由于您愿意使用Reflection访问私有字段...

您可以使用TASK_STATE_THREAD_WAS_ABORTED标记标记TCS的任务,这将导致所有延续不被内联。

const int TASK_STATE_THREAD_WAS_ABORTED = 134217728;

var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.NonPublic | BindingFlags.Instance);
stateField.SetValue(task, (int) stateField.GetValue(task) | TASK_STATE_THREAD_WAS_ABORTED);

修改

我建议您使用表达式,而不是使用Reflection emit。这更具可读性,并且具有与PCL兼容的优势:

var taskParameter = Expression.Parameter(typeof (Task));
const string stateFlagsFieldName = "m_stateFlags";
var setter =
    Expression.Lambda<Action<Task>>(
        Expression.Assign(Expression.Field(taskParameter, stateFlagsFieldName),
            Expression.Or(Expression.Field(taskParameter, stateFlagsFieldName),
                Expression.Constant(TASK_STATE_THREAD_WAS_ABORTED))), taskParameter).Compile();

不使用反射:

如果有人感兴趣,我已经找到了一种没有反思的方法,但它有点“肮脏”#34;同时,当然还有不可忽视的性能惩罚:

try
{
    Thread.CurrentThread.Abort();
}
catch (ThreadAbortException)
{
    source.TrySetResult(123);
    Thread.ResetAbort();
}

答案 1 :(得分:9)

我认为TPL中没有任何内容可以提供对<{1}}延续的显式 API控制。我决定保留initial answer来控制TaskCompletionSource.SetResult场景的此行为。

这是另一个在async/await上强加异步的解决方案,如果ContinueWith触发的继续发生在调用tcs.SetResult的同一个线程上:

SetResult

已更新以解决评论:

  

我无法控制来电者 - 我无法让他们使用特定的   继续变种:如果可以的话,问题将不存在于   第一名

我不知道你不能控制来电者。但是,如果您无法控制它,您可能也不会将public static class TaskExt { static readonly ConcurrentDictionary<Task, Thread> s_tcsTasks = new ConcurrentDictionary<Task, Thread>(); // SetResultAsync static public void SetResultAsync<TResult>( this TaskCompletionSource<TResult> @this, TResult result) { s_tcsTasks.TryAdd(@this.Task, Thread.CurrentThread); try { @this.SetResult(result); } finally { Thread thread; s_tcsTasks.TryRemove(@this.Task, out thread); } } // ContinueWithAsync, TODO: more overrides static public Task ContinueWithAsync<TResult>( this Task<TResult> @this, Action<Task<TResult>> action, TaskContinuationOptions continuationOptions = TaskContinuationOptions.None) { return @this.ContinueWith((Func<Task<TResult>, Task>)(t => { Thread thread = null; s_tcsTasks.TryGetValue(t, out thread); if (Thread.CurrentThread == thread) { // same thread which called SetResultAsync, avoid potential deadlocks // using thread pool return Task.Run(() => action(t)); // not using thread pool (TaskCreationOptions.LongRunning creates a normal thread) // return Task.Factory.StartNew(() => action(t), TaskCreationOptions.LongRunning); } else { // continue on the same thread var task = new Task(() => action(t)); task.RunSynchronously(); return Task.FromResult(task); } }), continuationOptions).Unwrap(); } } 对象直接传递给调用方。从逻辑上讲,您将传递令牌部分,即TaskCompletionSource。在这种情况下,通过在上面添加另一种扩展方法,解决方案可能更容易:

tcs.Task

使用:

// ImposeAsync, TODO: more overrides
static public Task<TResult> ImposeAsync<TResult>(this Task<TResult> @this)
{
    return @this.ContinueWith(new Func<Task<TResult>, Task<TResult>>(antecedent =>
    {
        Thread thread = null;
        s_tcsTasks.TryGetValue(antecedent, out thread);
        if (Thread.CurrentThread == thread)
        {
            // continue on a pool thread
            return antecedent.ContinueWith(t => t, 
                TaskContinuationOptions.None).Unwrap();
        }
        else
        {
            return antecedent;
        }
    }), TaskContinuationOptions.ExecuteSynchronously).Unwrap();
}

这实际上适用于// library code var source = new TaskCompletionSource<int>(); var task = source.Task.ImposeAsync(); // ... // client code task.ContinueWith(delegate { Identify(); }, TaskContinuationOptions.ExecuteSynchronously); // ... // library code source.SetResultAsync(123); await fiddle )并且不受反射黑客攻击。

答案 2 :(得分:3)

而不是做什么

var task = source.Task;

你这样做

var task = source.Task.ContinueWith<Int32>( x => x.Result );

因此,您总是添加一个将以异步方式执行的延续,然后如果订阅者想要在同一个上下文中继续,则无关紧要。它有点干预任务,不是吗?

答案 3 :(得分:3)

已更新,我发布separate answer来处理ContinueWith而不是await(因为ContinueWith并不关心当前同步上下文)。

您可以使用哑同步上下文在SetResult/SetCancelled/SetException上调用TaskCompletionSource触发的延续时强加异步。我相信当前的同步上下文(在await tcs.Task点)是TPL用来决定是继续同步还是异步的标准。

以下适用于我:

if (notifyAsync)
{
    tcs.SetResultAsync(null);
}
else
{
    tcs.SetResult(null);
}

SetResultAsync的实现方式如下:

public static class TaskExt
{
    static public void SetResultAsync<T>(this TaskCompletionSource<T> tcs, T result)
    {
        FakeSynchronizationContext.Execute(() => tcs.SetResult(result));
    }

    // FakeSynchronizationContext
    class FakeSynchronizationContext : SynchronizationContext
    {
        private static readonly ThreadLocal<FakeSynchronizationContext> s_context =
            new ThreadLocal<FakeSynchronizationContext>(() => new FakeSynchronizationContext());

        private FakeSynchronizationContext() { }

        public static FakeSynchronizationContext Instance { get { return s_context.Value; } }

        public static void Execute(Action action)
        {
            var savedContext = SynchronizationContext.Current;
            SynchronizationContext.SetSynchronizationContext(FakeSynchronizationContext.Instance);
            try
            {
                action();
            }
            finally
            {
                SynchronizationContext.SetSynchronizationContext(savedContext);
            }
        }

        // SynchronizationContext methods

        public override SynchronizationContext CreateCopy()
        {
            return this;
        }

        public override void OperationStarted()
        {
            throw new NotImplementedException("OperationStarted");
        }

        public override void OperationCompleted()
        {
            throw new NotImplementedException("OperationCompleted");
        }

        public override void Post(SendOrPostCallback d, object state)
        {
            throw new NotImplementedException("Post");
        }

        public override void Send(SendOrPostCallback d, object state)
        {
            throw new NotImplementedException("Send");
        }
    }
}

SynchronizationContext.SetSynchronizationContext is very cheap就其增加的开销而言。实际上,implementation of WPF Dispatcher.BeginInvoke采用了非常类似的方法。

TPL将await点的目标同步上下文与tcs.SetResult点的目标同步上下文进行比较。如果同步上下文相同(或两个位置都没有同步上下文),则直接同步调用continuation。否则,它在目标同步上下文中使用SynchronizationContext.Post排队,即正常await行为。这种方法的作用总是强加SynchronizationContext.Post行为(如果没有目标同步上下文,则表示池线程延续)。

已更新,这对task.ContinueWith无效,因为ContinueWith并不关心当前的同步上下文。但它适用于await taskfiddle)。它也适用于await task.ConfigureAwait(false)

OTOH,this approach适用于ContinueWith

答案 4 :(得分:3)

如果你能并且准备好使用反射,那么应该这样做;

public static class MakeItAsync
{
    static public void TrySetAsync<T>(this TaskCompletionSource<T> source, T result)
    {
        var continuation = typeof(Task).GetField("m_continuationObject", BindingFlags.NonPublic | BindingFlags.GetField | BindingFlags.Instance);
        var continuations = (List<object>)continuation.GetValue(source.Task);

        foreach (object c in continuations)
        {
            var option = c.GetType().GetField("m_options", BindingFlags.NonPublic | BindingFlags.GetField | BindingFlags.Instance);
            var options = (TaskContinuationOptions)option.GetValue(c);

            options &= ~TaskContinuationOptions.ExecuteSynchronously;
            option.SetValue(c, options);
        }

        source.TrySetResult(result);
    }        
}

答案 5 :(得分:3)

simulate abort方法看起来非常好,但导致了TPL劫持线程in some scenarios

然后我有一个类似于checking the continuation object的实现,但只是检查任何延续,因为实际上有太多场景让给定代码运行良好,但这意味着甚至像Task.Wait之类的东西也会导致线程池查找。

最终,在检查了大量IL之后,唯一安全且有用的场景是SetOnInvokeMres场景(手动重置 - 事件 - 苗条延续)。还有很多其他场景:

  • 有些不安全,导致线程劫持
  • 其余的没用,因为它们最终会导致线程池

所以最后,我选择检查一个非null的continuation-object;如果它是null,那么(没有延续);如果它是非空的,SetOnInvokeMres的特殊情况检查 - 如果是:罚款(可以安全地调用);否则,让线程池执行TrySetComplete,而不告诉任务执行任何特殊操作,如欺骗中止。 Task.Wait使用SetOnInvokeMres方法,这是我们想要尝试的特定方案真的很难不死锁。

Type taskType = typeof(Task);
FieldInfo continuationField = taskType.GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
Type safeScenario = taskType.GetNestedType("SetOnInvokeMres", BindingFlags.NonPublic);
if (continuationField != null && continuationField.FieldType == typeof(object) && safeScenario != null)
{
    var method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
    var il = method.GetILGenerator();
    var hasContinuation = il.DefineLabel();
    il.Emit(OpCodes.Ldarg_0);
    il.Emit(OpCodes.Ldfld, continuationField);
    Label nonNull = il.DefineLabel(), goodReturn = il.DefineLabel();
    // check if null
    il.Emit(OpCodes.Brtrue_S, nonNull);
    il.MarkLabel(goodReturn);
    il.Emit(OpCodes.Ldc_I4_1);
    il.Emit(OpCodes.Ret);

    // check if is a SetOnInvokeMres - if so, we're OK
    il.MarkLabel(nonNull);
    il.Emit(OpCodes.Ldarg_0);
    il.Emit(OpCodes.Ldfld, continuationField);
    il.Emit(OpCodes.Isinst, safeScenario);
    il.Emit(OpCodes.Brtrue_S, goodReturn);

    il.Emit(OpCodes.Ldc_I4_0);
    il.Emit(OpCodes.Ret);

    IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));