线程代码

时间:2016-05-16 14:11:08

标签: c# multithreading unit-testing

我的应用程序中有一个线程模块,用于管理启动并行操作。它添加了各种时序和日志记录,原因很复杂,我最近发现我编写了一个bug,它在双嵌套线程上启动了一些任务。

即。它叫相当于:

Task.Run(
    async () => await Task.Run(
        () => DoStuff();
    );
).Wait()

现在一方面,该代码有效...目标代码运行,等待代码不会继续,直到目标代码完成。

另一方面,它使用2个线程来执行此操作而不是1,因为我们遇到线程饥饿问题,这是一个问题。

我知道如何修复代码,但是我想编写一个单元测试以确保A)我已经修复了所有这些错误/在所有场景中修复它。 并且B)将来没有人再创造这个错误。

但我看不出如何掌握“我创造的所有主题”。 CurrentProcess.Threads给了我线程的LOADS,没有明显的方法来确定哪些是我关心的。

有什么想法吗?

3 个答案:

答案 0 :(得分:2)

  

如何掌握"我创建的所有主题"

Task.Run不创建任何线程;它会调度作业以在当前配置的线程池上运行。见https://msdn.microsoft.com/library/system.threading.tasks.taskscheduler.aspx

如果您的意思是"如何计算我已加入的任务数量",我认为您需要创建一个TaskScheduler的自定义实施,它会对传入的任务进行计数并配置您的测试代码使用它。上面链接的页面上显示了一个自定义TaskScheduler的示例。

答案 1 :(得分:2)

通常情况下,单元测试的解决方案涉及静态方法(在这种情况下为Task.Run),您可能需要将某些内容作为依赖项传递给包含此内容的类,并且您可以然后在测试中添加行为。

正如@Rich在答案中建议的那样,你可以通过传递TaskScheduler来做到这一点。然后,您的测试版本可以在排队时保留任务的计数。

由于保护级别,进行测试TaskScheduler实际上有点难看,但在这篇文章的底部我已经包含了一个包含现有TaskScheduler的内容(例如,你可以使用TaskScheduler.Default })。

不幸的是,您还需要更改您的电话,例如

Task.Run(() => DoSomething);

类似

Task.Factory.StartNew(
    () => DoSomething(),
    CancellationToken.None,
    TaskCreationOptions.DenyChildAttach,
    myTaskScheduler);

basically what Task.Run does under the hood,除了TaskScheduler.Default。你当然可以在某个地方用一个辅助方法把它包起来。

或者,如果您对测试代码中某些风险较高的反映不感到羞怯,则可能会劫持TaskScheduler.Default属性,因此您仍然可以使用Task.Run

var defaultSchedulerField = typeof(TaskScheduler).GetField("s_defaultTaskScheduler", BindingFlags.Static | BindingFlags.NonPublic);
var scheduler = new TestTaskScheduler(TaskScheduler.Default);
defaultSchedulerField.SetValue(null, scheduler);

(私人字段名称来自TaskScheduler.cs line 285。)

因此,例如,此测试将使用下面的TestTaskScheduler和反射技巧传递:

[Test]
public void Can_count_tasks()
{
    // Given
    var originalScheduler = TaskScheduler.Default;
    var defaultSchedulerField = typeof(TaskScheduler).GetField("s_defaultTaskScheduler", BindingFlags.Static | BindingFlags.NonPublic);
    var testScheduler = new TestTaskScheduler(originalScheduler);
    defaultSchedulerField.SetValue(null, testScheduler);

    // When
    Task.Run(() => {});
    Task.Run(() => {});
    Task.Run(() => {});

    // Then
    testScheduler.TaskCount.Should().Be(3);

    // Clean up
    defaultSchedulerField.SetValue(null, originalScheduler);
}

这是测试任务调度程序:

using System.Collections.Generic;
using System.Reflection;
using System.Threading.Tasks;

public class TestTaskScheduler : TaskScheduler
{
    private static readonly MethodInfo queueTask = GetProtectedMethodInfo("QueueTask");
    private static readonly MethodInfo tryExecuteTaskInline = GetProtectedMethodInfo("TryExecuteTaskInline");
    private static readonly MethodInfo getScheduledTasks = GetProtectedMethodInfo("GetScheduledTasks");

    private readonly TaskScheduler taskScheduler;

    public TestTaskScheduler(TaskScheduler taskScheduler)
    {
        this.taskScheduler = taskScheduler;
    }

    public int TaskCount { get; private set; }

    protected override void QueueTask(Task task)
    {
        TaskCount++;
        CallProtectedMethod(queueTask, task);
    }

    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        return (bool)CallProtectedMethod(tryExecuteTaskInline, task, taskWasPreviouslyQueued);
    }

    protected override IEnumerable<Task> GetScheduledTasks()
    {
        return (IEnumerable<Task>)CallProtectedMethod(getScheduledTasks);
    }

    private object CallProtectedMethod(MethodInfo methodInfo, params object[] args)
    {
        return methodInfo.Invoke(taskScheduler, args);
    }

    private static MethodInfo GetProtectedMethodInfo(string methodName)
    {
        return typeof(TaskScheduler).GetMethod(methodName, BindingFlags.Instance | BindingFlags.NonPublic);
    }
}

或使用@hgcummings在评论中建议使用RelflectionMagic进行整理:

var scheduler = new TestTaskScheduler(TaskScheduler.Default);
typeof(TaskScheduler).AsDynamicType().s_defaultTaskScheduler = scheduler;
using System.Collections.Generic;
using System.Threading.Tasks;
using ReflectionMagic;

public class TestTaskScheduler : TaskScheduler
{
    private readonly dynamic taskScheduler;

    public TestTaskScheduler(TaskScheduler taskScheduler)
    {
        this.taskScheduler = taskScheduler.AsDynamic();
    }

    public int TaskCount { get; private set; }

    protected override void QueueTask(Task task)
    {
        TaskCount++;
        taskScheduler.QueueTask(task);
    }

    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        return taskScheduler.TryExecuteTaskInline(task, taskWasPreviouslyQueued);
    }

    protected override IEnumerable<Task> GetScheduledTasks()
    {
        return taskScheduler.GetScheduledTasks();
    }
}

答案 2 :(得分:0)

线程类确实有一个Name属性,可用于帮助识别您创建的所有线程。这意味着一个简单的linq或for循环,它可以让你跟踪哪些线程是你的。

https://msdn.microsoft.com/en-us/library/system.threading.thread.name(v=vs.110).aspx