自定义线程池支持异步操作

时间:2014-09-02 21:15:49

标签: c# .net multithreading asynchronous threadpool

我希望有一个满足以下要求的自定义线程池:

  1. 根据池容量预先分配实际线程。如果需要产生并发任务,实际工作可以自由使用标准.NET线程池。
  2. 池必须能够返回空闲线程数。返回的数字可能小于空闲线程的实际数量,但不能更大。当然,数字越准确越好。
  3. 将工作排队到池中应返回相应的Task,这应该与基于任务的API很好。
  4. 最大作业容量(或并行度)应动态调整。试图降低容量不必立即生效,但增加它应该立即生效。
  5. 第一项的基本原理如下所示:

    • 机器不应同时运行超过N个工作项,其中N相对较小 - 介于10到30之间。
    • 从数据库中获取工作,如果获取了K项,那么我们希望确保有K个空闲线程立即开始工作。从数据库中获取工作但仍在等待下一个可用线程的情况是不可接受的。

    最后一项还解释了空闲线程数的原因 - 我将从数据库中获取那么多工作项。它还解释了为什么报告的空闲线程数永远不会高于实际线程数 - 否则我可能会获得更多可立即启动的工作。

    无论如何,这是我的实现以及一个小程序来测试它(BJE代表后台作业引擎):

    using System;
    using System.Collections.Concurrent;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Threading;
    using System.Threading.Tasks;
    
    namespace TaskStartLatency
    {
        public class BJEThreadPool
        {
            private sealed class InternalTaskScheduler : TaskScheduler
            {
                private int m_idleThreadCount;
                private readonly BlockingCollection<Task> m_bus;
    
                public InternalTaskScheduler(int threadCount, BlockingCollection<Task> bus)
                {
                    m_idleThreadCount = threadCount;
                    m_bus = bus;
                }
    
                public void RunInline(Task task)
                {
                    Interlocked.Decrement(ref m_idleThreadCount);
                    try
                    {
                        TryExecuteTask(task);
                    }
                    catch
                    {
                        // The action is responsible itself for the error handling, for the time being...
                    }
                    Interlocked.Increment(ref m_idleThreadCount);
                }
    
                public int IdleThreadCount
                {
                    get { return m_idleThreadCount; }
                }
    
                #region Overrides of TaskScheduler
    
                protected override void QueueTask(Task task)
                {
                    m_bus.Add(task);
                }
    
                protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
                {
                    return TryExecuteTask(task);
                }
    
                protected override IEnumerable<Task> GetScheduledTasks()
                {
                    throw new NotSupportedException();
                }
    
                #endregion
    
                public void DecrementIdleThreadCount()
                {
                    Interlocked.Decrement(ref m_idleThreadCount);
                }
            }
    
            private class ThreadContext
            {
                private readonly InternalTaskScheduler m_ts;
                private readonly BlockingCollection<Task> m_bus;
                private readonly CancellationTokenSource m_cts;
                public readonly Thread Thread;
    
                public ThreadContext(string name, InternalTaskScheduler ts, BlockingCollection<Task> bus, CancellationTokenSource cts)
                {
                    m_ts = ts;
                    m_bus = bus;
                    m_cts = cts;
                    Thread = new Thread(Start)
                    {
                        IsBackground = true,
                        Name = name
                    };
                    Thread.Start();
                }
    
                private void Start()
                {
                    try
                    {
                        foreach (var task in m_bus.GetConsumingEnumerable(m_cts.Token))
                        {
                            m_ts.RunInline(task);
                        }
                    }
                    catch (OperationCanceledException)
                    {
                    }
                    m_ts.DecrementIdleThreadCount();
                }
            }
    
            private readonly InternalTaskScheduler m_ts;
            private readonly CancellationTokenSource m_cts = new CancellationTokenSource();
            private readonly BlockingCollection<Task> m_bus = new BlockingCollection<Task>();
            private readonly List<ThreadContext> m_threadCtxs = new List<ThreadContext>();
    
            public BJEThreadPool(int threadCount)
            {
                m_ts = new InternalTaskScheduler(threadCount, m_bus);
                for (int i = 0; i < threadCount; ++i)
                {
                    m_threadCtxs.Add(new ThreadContext("BJE Thread " + i, m_ts, m_bus, m_cts));
                }
            }
    
            public void Terminate()
            {
                m_cts.Cancel();
                foreach (var t in m_threadCtxs)
                {
                    t.Thread.Join();
                }
            }
    
            public Task Run(Action<CancellationToken> action)
            {
                return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
            }
            public Task Run(Action action)
            {
                return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
            }
    
            public int IdleThreadCount
            {
                get { return m_ts.IdleThreadCount; }
            }
        }
    
        class Program
        {
            static void Main()
            {
                const int THREAD_COUNT = 32;
                var pool = new BJEThreadPool(THREAD_COUNT);
                var tcs = new TaskCompletionSource<bool>();
                var tasks = new List<Task>();
                var allRunning = new CountdownEvent(THREAD_COUNT);
    
                for (int i = pool.IdleThreadCount; i > 0; --i)
                {
                    var index = i;
                    tasks.Add(pool.Run(cancellationToken =>
                    {
                        Console.WriteLine("Started action " + index);
                        allRunning.Signal();
                        tcs.Task.Wait(cancellationToken);
                        Console.WriteLine("  Ended action " + index);
                    }));
                }
    
                Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
    
                allRunning.Wait();
                Debug.Assert(pool.IdleThreadCount == 0);
    
                int expectedIdleThreadCount = THREAD_COUNT;
                Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
                switch (Console.ReadKey().KeyChar)
                {
                case 'c':
                    Console.WriteLine("Cancel All");
                    tcs.TrySetCanceled();
                    break;
                case 'e':
                    Console.WriteLine("Error All");
                    tcs.TrySetException(new Exception("Failed"));
                    break;
                case 'a':
                    Console.WriteLine("Abort All");
                    pool.Terminate();
                    expectedIdleThreadCount = 0;
                    break;
                default:
                    Console.WriteLine("Done All");
                    tcs.TrySetResult(true);
                    break;
                }
                try
                {
                    Task.WaitAll(tasks.ToArray());
                }
                catch (AggregateException exc)
                {
                    Console.WriteLine(exc.Flatten().InnerException.Message);
                }
    
                Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);
    
                pool.Terminate();
                Console.WriteLine("Press any key");
                Console.ReadKey();
            }
        }
    }
    

    这是一个非常简单的实现,它似乎正在工作。但是,存在一个问题 - BJEThreadPool.Run方法不接受异步方法。即我的实现不允许我添加以下重载:

    public Task Run(Func<CancellationToken, Task> action)
    {
        return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
    }
    public Task Run(Func<Task> action)
    {
        return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
    }
    

    我在InternalTaskScheduler.RunInline中使用的模式在这种情况下不起作用。

    那么,我的问题是如何添加对异步工作项的支持?只要支持帖子开头的要求,我就可以改变整个设计。

    修改

    我想澄清所需池的意图用法。请注意以下代码:

    if (pool.IdleThreadCount == 0)
    {
      return;
    }
    
    foreach (var jobData in FetchFromDB(pool.IdleThreadCount))
    {
      pool.Run(CreateJobAction(jobData));
    }
    

    注意:

    1. 代码将定期运行,比如每1分钟一次。
    2. 代码将由观看同一数据库的多台计算机同时运行。
    3. FetchFromDB将使用Using SQL Server as a DB queue with multiple clients中描述的技术以原子方式从数据库中获取并锁定工作。
    4. CreateJobAction将调用由jobData表示的代码(作业代码),并在完成该代码后关闭工作。作业代码不受我的控制,它几乎可以是任何东西 - 繁重的CPU绑定代码或轻型异步IO绑定代码,编写严重的同步IO绑定代码或混合使用。它可以运行几分钟,它可以运行几个小时。关闭工作是我的代码,它将通过异步IO绑定代码。因此,返回的作业操作的签名是异步方法的签名。
    5. 第2项强调了正确识别空闲线程数量的重要性。如果有900个待处理工作项和10个代理计算机,我不能允许代理获取300个工作项并在线程池上排队。为什么?因为,代理最不可能同时运行300个工作项。它会运行一些,当然,其他人将在线程池工作队列中等待。假设它将运行100并让200等待(即使100可能是远程的)。这使用3个满载代理和7个空闲代理。但实际上只有900个工作项目在同时处理!

      我的目标是最大限度地扩大可用代理商的工作范围。理想情况下,我应该评估一个代理人的负担和#34;沉重的&#34;待完成的工作,但它是一项艰巨的任务,并保留给未来的版本。现在,我希望为每个代理分配最大作业容量,以便在不重新启动代理的情况下提供动态增加/减少它的方法。

      接下来的观察。这项工作可能需要很长时间才能运行,它可能是所有同步代码。据我所知,利用线程池线程进行此类工作是不可取的。

      编辑2

      有一条声明TaskScheduler仅适用于CPU绑定工作。但是,如果我不知道工作的性质怎么办?我的意思是它是一个通用的后台作业引擎,它可以运行数千种不同的工作。我没有办法告诉&#34;那份工作是CPU限制的&#34;和&#34; on on是同步IO绑定&#34;另一个是异步IO绑定。我希望我能,但我不能。

      编辑3

      最后,我不使用SemaphoreSlim,但我也没有使用TaskScheduler - 它最终在我的厚厚的头骨上滴下来,它是不合适的,而且是完全错误的,加上它使得代码过于复杂。

      尽管如此,我还是没有看到SemaphoreSlim的方式。建议的模式:

      public async Task Enqueue(Func<Task> taskGenerator)
      {
          await semaphore.WaitAsync();
          try
          {
              await taskGenerator();
          }
          finally
          {
              semaphore.Release();
          }
      }
      

      期望taskGenerator是异步IO绑定代码,否则打开新线程。但是,我无法确定要执行的工作是一个还是另一个。另外,正如我从SemaphoreSlim.WaitAsync continuation code学到的,如果信号量被解锁,WaitAsync()之后的代码将在同一个线程上运行,这对我来说不是很好。

      无论如何,以下是我的实施,以防任何人幻想。不幸的是,我还没有理解如何动态地减少池线程数,但这是另一个问题的主题。

      using System;
      using System.Collections.Concurrent;
      using System.Collections.Generic;
      using System.Diagnostics;
      using System.Threading;
      using System.Threading.Tasks;
      
      namespace TaskStartLatency
      {
          public interface IBJEThreadPool
          {
              void SetThreadCount(int threadCount);
              void Terminate();
              Task Run(Action action);
              Task Run(Action<CancellationToken> action);
              Task Run(Func<Task> action);
              Task Run(Func<CancellationToken, Task> action);
              int IdleThreadCount { get; }
          }
      
          public class BJEThreadPool : IBJEThreadPool
          {
              private interface IActionContext
              {
                  Task Run(CancellationToken ct);
                  TaskCompletionSource<object> TaskCompletionSource { get; }
              }
      
              private class ActionContext : IActionContext
              {
                  private readonly Action m_action;
      
                  public ActionContext(Action action)
                  {
                      m_action = action;
                      TaskCompletionSource = new TaskCompletionSource<object>();
                  }
      
                  #region Implementation of IActionContext
      
                  public Task Run(CancellationToken ct)
                  {
                      m_action();
                      return null;
                  }
      
                  public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
      
                  #endregion
              }
              private class CancellableActionContext : IActionContext
              {
                  private readonly Action<CancellationToken> m_action;
      
                  public CancellableActionContext(Action<CancellationToken> action)
                  {
                      m_action = action;
                      TaskCompletionSource = new TaskCompletionSource<object>();
                  }
      
                  #region Implementation of IActionContext
      
                  public Task Run(CancellationToken ct)
                  {
                      m_action(ct);
                      return null;
                  }
      
                  public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
      
                  #endregion
              }
              private class AsyncActionContext : IActionContext
              {
                  private readonly Func<Task> m_action;
      
                  public AsyncActionContext(Func<Task> action)
                  {
                      m_action = action;
                      TaskCompletionSource = new TaskCompletionSource<object>();
                  }
      
                  #region Implementation of IActionContext
      
                  public Task Run(CancellationToken ct)
                  {
                      return m_action();
                  }
      
                  public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
      
                  #endregion
              }
              private class AsyncCancellableActionContext : IActionContext
              {
                  private readonly Func<CancellationToken, Task> m_action;
      
                  public AsyncCancellableActionContext(Func<CancellationToken, Task> action)
                  {
                      m_action = action;
                      TaskCompletionSource = new TaskCompletionSource<object>();
                  }
      
                  #region Implementation of IActionContext
      
                  public Task Run(CancellationToken ct)
                  {
                      return m_action(ct);
                  }
      
                  public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
      
                  #endregion
              }
      
              private readonly CancellationTokenSource m_ctsTerminateAll = new CancellationTokenSource();
              private readonly BlockingCollection<IActionContext> m_bus = new BlockingCollection<IActionContext>();
              private readonly LinkedList<Thread> m_threads = new LinkedList<Thread>();
              private int m_idleThreadCount;
      
              private static int s_threadCount;
      
              public BJEThreadPool(int threadCount)
              {
                  ReserveAdditionalThreads(threadCount);
              }
      
              private void ReserveAdditionalThreads(int n)
              {
                  for (int i = 0; i < n; ++i)
                  {
                      var index = Interlocked.Increment(ref s_threadCount) - 1;
      
                      var t = new Thread(Start)
                      {
                          IsBackground = true,
                          Name = "BJE Thread " + index
                      };
                      Interlocked.Increment(ref m_idleThreadCount);
                      t.Start();
      
                      m_threads.AddLast(t);
                  }
              }
      
              private void Start()
              {
                  try
                  {
                      foreach (var actionContext in m_bus.GetConsumingEnumerable(m_ctsTerminateAll.Token))
                      {
                          RunWork(actionContext).Wait();
                      }
                  }
                  catch (OperationCanceledException)
                  {
                  }
                  catch
                  {
                      // Should never happen - log the error
                  }
      
                  Interlocked.Decrement(ref m_idleThreadCount);
              }
      
              private async Task RunWork(IActionContext actionContext)
              {
                  Interlocked.Decrement(ref m_idleThreadCount);
                  try
                  {
                      var task = actionContext.Run(m_ctsTerminateAll.Token);
                      if (task != null)
                      {
                          await task;
                      }
                      actionContext.TaskCompletionSource.SetResult(null);
                  }
                  catch (OperationCanceledException)
                  {
                      actionContext.TaskCompletionSource.TrySetCanceled();
                  }
                  catch (Exception exc)
                  {
                      actionContext.TaskCompletionSource.TrySetException(exc);
                  }
                  Interlocked.Increment(ref m_idleThreadCount);
              }
      
              private Task PostWork(IActionContext actionContext)
              {
                  m_bus.Add(actionContext);
                  return actionContext.TaskCompletionSource.Task;
              }
      
              #region Implementation of IBJEThreadPool
      
              public void SetThreadCount(int threadCount)
              {
                  if (threadCount > m_threads.Count)
                  {
                      ReserveAdditionalThreads(threadCount - m_threads.Count);
                  }
                  else if (threadCount < m_threads.Count)
                  {
                      throw new NotSupportedException();
                  }
              }
              public void Terminate()
              {
                  m_ctsTerminateAll.Cancel();
                  foreach (var t in m_threads)
                  {
                      t.Join();
                  }
              }
      
              public Task Run(Action action)
              {
                  return PostWork(new ActionContext(action));
              }
              public Task Run(Action<CancellationToken> action)
              {
                  return PostWork(new CancellableActionContext(action));
              }
              public Task Run(Func<Task> action)
              {
                  return PostWork(new AsyncActionContext(action));
              }
              public Task Run(Func<CancellationToken, Task> action)
              {
                  return PostWork(new AsyncCancellableActionContext(action));
              }
      
              public int IdleThreadCount
              {
                  get { return m_idleThreadCount; }
              }
      
              #endregion
          }
      
          public static class Extensions
          {
              public static Task WithCancellation(this Task task, CancellationToken token)
              {
                  return task.ContinueWith(t => t.GetAwaiter().GetResult(), token);
              }
          }
      
          class Program
          {
              static void Main()
              {
                  const int THREAD_COUNT = 16;
                  var pool = new BJEThreadPool(THREAD_COUNT);
                  var tcs = new TaskCompletionSource<bool>();
                  var tasks = new List<Task>();
                  var allRunning = new CountdownEvent(THREAD_COUNT);
      
                  for (int i = pool.IdleThreadCount; i > 0; --i)
                  {
                      var index = i;
                      tasks.Add(pool.Run(async ct =>
                      {
                          Console.WriteLine("Started action " + index);
                          allRunning.Signal();
                          await tcs.Task.WithCancellation(ct);
                          Console.WriteLine("  Ended action " + index);
                      }));
                  }
      
                  Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
      
                  allRunning.Wait();
                  Debug.Assert(pool.IdleThreadCount == 0);
      
                  int expectedIdleThreadCount = THREAD_COUNT;
                  Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
                  switch (Console.ReadKey().KeyChar)
                  {
                  case 'c':
                      Console.WriteLine("ancel All");
                      tcs.TrySetCanceled();
                      break;
                  case 'e':
                      Console.WriteLine("rror All");
                      tcs.TrySetException(new Exception("Failed"));
                      break;
                  case 'a':
                      Console.WriteLine("bort All");
                      pool.Terminate();
                      expectedIdleThreadCount = 0;
                      break;
                  default:
                      Console.WriteLine("Done All");
                      tcs.TrySetResult(true);
                      break;
                  }
      
                  try
                  {
                      Task.WaitAll(tasks.ToArray());
                  }
                  catch (AggregateException exc)
                  {
                      Console.WriteLine(exc.Flatten().InnerException.Message);
                  }
      
                  Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);
      
                  pool.Terminate();
                  Console.WriteLine("Press any key");
                  Console.ReadKey();
              }
          }
      }
      

2 个答案:

答案 0 :(得分:3)

异步&#34;工作项目&#34;通常基于异步IO。 Async IO在运行时不使用线程。任务调度程序用于执行CPU工作(基于委托的任务)。概念TaskScheduler不适用。您不能使用自定义TaskScheduler来影响异步代码的作用。

让你的工作项目自行节制:

static SemaphoreSlim sem = new SemaphoreSlim(maxDegreeOfParallelism); //shared object

async Task MyWorkerFunction()
{
    await sem.WaitAsync();
    try
    {
        MyWork();
    }
    finally
    {
        sem.Release();
    }
}

答案 1 :(得分:1)

As mentioned in another answer by usr你不能用TaskScheduler做到这一点,因为它只适用于CPU绑定工作,不限制所有类型工作的并行化水平,无论是否并行。他还向您展示了如何使用SemaphoreSlim异步限制并行度。

您可以通过以下几种方式对其进行扩展以概括这些概念。看起来对您最有利的一种方法是创建一种特殊类型的队列,该队列采用返回Task的操作并以实现给定最大并行度的方式执行它们。

public class FixedParallelismQueue
{
    private SemaphoreSlim semaphore;
    public FixedParallelismQueue(int maxDegreesOfParallelism)
    {
        semaphore = new SemaphoreSlim(maxDegreesOfParallelism,
            maxDegreesOfParallelism);
    }

    public async Task<T> Enqueue<T>(Func<Task<T>> taskGenerator)
    {
        await semaphore.WaitAsync();
        try
        {
            return await taskGenerator();
        }
        finally
        {
            semaphore.Release();
        }
    }
    public async Task Enqueue(Func<Task> taskGenerator)
    {
        await semaphore.WaitAsync();
        try
        {
            await taskGenerator();
        }
        finally
        {
            semaphore.Release();
        }
    }
}

这允许您为应用程序创建一个队列(如果需要,您甚至可以拥有多个单独的队列),这些队列具有固定的并行化程度。然后,您可以在完成后提供返回Task的操作,队列将在可能的情况下安排它并返回表示该工作单元完成时间的Task