Unit testing concurrent code using a custom TaskScheduler

Today I needed to test an action that runs inside a Task:
using System;
using System.Threading;
using System.Threading.Tasks;

public class ClassToTest
{
    private readonly IMessageBus _messageBus;
    private CancellationTokenSource _cancellationTokenSource;
    public event EventHandler OnNewMessage;

    public ClassToTest(IMessageBus messageBus)
    {
        _messageBus = messageBus;
    }

    public void Start()
    {
        _cancellationTokenSource = new CancellationTokenSource();
        Task.Run(() =>
        {
            var message = _messageBus.GetNextMessage();

            // Do work

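            // Copy the delegate so an unsubscribe on another thread cannot null it between the check and the call
            var handler = OnNewMessage;
            if (handler != null)
            {
                handler(this, EventArgs.Empty);
            }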

        }, _cancellationTokenSource.Token);
    }

    public void Stop()
    {
        if (_cancellationTokenSource != null)
        {
            _cancellationTokenSource.Cancel();
            _cancellationTokenSource = null;
        }
    }
}
When faced with code like this, a developer has a “tiny” problem: how to force the code inside Task.Run to execute before the end of the test is reached.
The obvious test below (A.Fake<IMessageBus>() creates a fake using the FakeItEasy library) fails most of the time, because the assertion usually runs before the task does:
[TestMethod]
public void BadTest()
{
    var fakeMessageBus = A.Fake<IMessageBus>();
    var wasCalled = false;

    var cut = new ClassToTest(fakeMessageBus);
    cut.OnNewMessage += (sender, args) => wasCalled = true;

    cut.Start();

    Assert.IsTrue(wasCalled);
}

Why you should not use Thread.Sleep

The trivial fix would be to add a Thread.Sleep (or a similar Task.Delay) to make sure that the action runs before we reach the end of the test. Unfortunately, adding this sleep only makes the test pass most of the time.
[TestMethod]
public void WorseTest()
{
    var fakeMessageBus = A.Fake<IMessageBus>();
    var wasCalled = false;

    var cut = new ClassToTest(fakeMessageBus);
    cut.OnNewMessage += (sender, args) => wasCalled = true;

    cut.Start();
    Thread.Sleep(1000);

    Assert.IsTrue(wasCalled);
}
The problem is that by adding a sleep we end up with a test that is still inconsistent, while also making sure that it takes a staggering 1 second to run. If that doesn’t sound like a lot, think how long 60 such tests would take (a full minute). And what about 1000 tests, each taking a single second (more than 16 minutes)?
In any case, a test that passes or fails inconsistently is a huge problem. It only takes a few such failures to convince your team to stop paying attention to that test, because “it’s the test that always fails”. Soon enough they’ll stop paying attention to the continuous integration results, and all of your tests will become useless (at least for alerting the team when a bug is introduced).

Creating a custom TaskScheduler

Looking for a solution, I found a good post by Kjetil Klaussen in which he suggests creating a custom TaskScheduler that runs the task on the same thread as the test, making the test predictable and simple:
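using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Runs every task synchronously on the thread that schedules it
public class CurrentThreadTaskScheduler : TaskScheduler
{
    // Called when a task is started: execute it immediately instead of queuing it
    protected override void QueueTask(Task task)
    {
        TryExecuteTask(task);
    }

    // Called when the TPL asks to run a task inline (for example from Task.Wait): run it right away
    protected override bool TryExecuteTaskInline(
       Task task,
       bool taskWasPreviouslyQueued)
    {
        return TryExecuteTask(task);
    }

    // Nothing is ever queued, so there are no pending tasks to report
    protected override IEnumerable<Task> GetScheduledTasks()
    {
        return Enumerable.Empty<Task>();
    }

    public override int MaximumConcurrencyLevel { get { return 1; } }
}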
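To see what this scheduler does in isolation, the sketch below starts a task on it. It reports whether the task ran on the calling thread and whether it had already finished by the time StartNew returned. InlineDemo is an illustrative name and is not part of Kjetil's post. The scheduler is repeated so that the snippet compiles on its own.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Copy of the CurrentThreadTaskScheduler above, repeated so this snippet compiles on its own
public class CurrentThreadTaskScheduler : TaskScheduler
{
    protected override void QueueTask(Task task) { TryExecuteTask(task); }
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) { return TryExecuteTask(task); }
    protected override IEnumerable<Task> GetScheduledTasks() { return Enumerable.Empty<Task>(); }
    public override int MaximumConcurrencyLevel { get { return 1; } }
}

// Hypothetical demo class, for illustration only
public static class InlineDemo
{
    // Item1: did the task run on the calling thread?
    // Item2: had the task already completed when StartNew returned?
    public static Tuple<bool, bool> Run(TaskScheduler scheduler)
    {
        int callerThreadId = Thread.CurrentThread.ManagedThreadId;
        int taskThreadId = -1;

        Task task = Task.Factory.StartNew(
            () => { taskThreadId = Thread.CurrentThread.ManagedThreadId; },
            CancellationToken.None,
            TaskCreationOptions.None,
            scheduler);

        bool completedWhenStartNewReturned = task.IsCompleted;
        task.Wait(); // needed when another scheduler is passed in; a no-op for the inline one
        return Tuple.Create(taskThreadId == callerThreadId, completedWhenStartNewReturned);
    }
}
```

With new CurrentThreadTaskScheduler() both values are true, because QueueTask executes the task before StartNew returns. Passing TaskScheduler.Default instead usually reports a different thread, but that result depends on timing, which is exactly what makes the earlier tests unreliable.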
Now we need to pass that scheduler to the code that starts the task, either through a wrapper object or by using dependency injection.
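For comparison, here is a minimal sketch of the dependency-injection option. It assumes a hypothetical ClassToTestWithScheduler that takes an optional TaskScheduler in its constructor and falls back to TaskScheduler.Default in production. The post never shows IMessageBus, so its shape (GetNextMessage returning object) is an assumption. StubMessageBus is also an assumption, used in place of a mocking library so the code compiles on its own.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Assumed shape of IMessageBus; the original post does not show it
public interface IMessageBus
{
    object GetNextMessage();
}

// Hypothetical hand-written stub used instead of a FakeItEasy fake
public class StubMessageBus : IMessageBus
{
    public object GetNextMessage() { return "message"; }
}

// Copy of the CurrentThreadTaskScheduler above, repeated so this snippet compiles on its own
public class CurrentThreadTaskScheduler : TaskScheduler
{
    protected override void QueueTask(Task task) { TryExecuteTask(task); }
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) { return TryExecuteTask(task); }
    protected override IEnumerable<Task> GetScheduledTasks() { return Enumerable.Empty<Task>(); }
    public override int MaximumConcurrencyLevel { get { return 1; } }
}

// Hypothetical variant of ClassToTest that receives its scheduler through the constructor
public class ClassToTestWithScheduler
{
    private readonly IMessageBus _messageBus;
    private readonly TaskScheduler _scheduler;
    private CancellationTokenSource _cancellationTokenSource;
    public event EventHandler OnNewMessage;

    // Production callers omit the scheduler and get the thread pool scheduler
    public ClassToTestWithScheduler(IMessageBus messageBus, TaskScheduler scheduler = null)
    {
        _messageBus = messageBus;
        _scheduler = scheduler ?? TaskScheduler.Default;
    }

    public void Start()
    {
        _cancellationTokenSource = new CancellationTokenSource();
        Task.Factory.StartNew(() =>
        {
            var message = _messageBus.GetNextMessage();

            // Do work

            var handler = OnNewMessage;
            if (handler != null)
            {
                handler(this, EventArgs.Empty);
            }
        }, _cancellationTokenSource.Token, TaskCreationOptions.None, _scheduler);
    }
}
```

A test passes new CurrentThreadTaskScheduler() and can assert as soon as Start returns, while production code simply omits the argument. The cost is that every class between the test and the task has to accept and forward the scheduler.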

Changing the current TaskScheduler

The problem is that this requires changing your production code, which is not always simple. In my case, the action I needed to run was 5 calls deep, and I didn’t want to pass the new TaskScheduler to each class on the call stack.
I started looking for a way to replace the current task scheduler, but Tamir D had a better idea: run the entire test inside a Task that uses CurrentThreadTaskScheduler.
Following his advice, I changed the Start method to call Task.Factory.StartNew with TaskScheduler.Current instead of Task.Run, which always uses the default (thread pool) scheduler:
public void Start()
{
    _cancellationTokenSource = new CancellationTokenSource();
    Task.Factory.StartNew(() =>
    {
        var message = _messageBus.GetNextMessage();

        // Do work

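        // Copy the delegate so an unsubscribe on another thread cannot null it between the check and the call
        var handler = OnNewMessage;
        if (handler != null)
        {
            handler(this, EventArgs.Empty);
        }

    }, _cancellationTokenSource.Token,
    TaskCreationOptions.None,
    // The scheduler of the task that called Start (TaskScheduler.Default when there is none)
    TaskScheduler.Current);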
}
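The change works because TaskScheduler.Current returns the scheduler of the task that is currently executing, or TaskScheduler.Default when called outside any task. Task.Run, by contrast, always schedules on TaskScheduler.Default. The sketch below checks both facts from inside a task running on CurrentThreadTaskScheduler. SchedulerProbe is a hypothetical name used only for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Copy of the CurrentThreadTaskScheduler above, repeated so this snippet compiles on its own
public class CurrentThreadTaskScheduler : TaskScheduler
{
    protected override void QueueTask(Task task) { TryExecuteTask(task); }
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) { return TryExecuteTask(task); }
    protected override IEnumerable<Task> GetScheduledTasks() { return Enumerable.Empty<Task>(); }
    public override int MaximumConcurrencyLevel { get { return 1; } }
}

public static class SchedulerProbe
{
    public static bool[] Probe()
    {
        var scheduler = new CurrentThreadTaskScheduler();
        bool outerSeesCustom = false;      // TaskScheduler.Current inside the outer (test) task
        bool startNewUsesCustom = false;   // StartNew(..., TaskScheduler.Current), like the new Start
        bool startNewDoneOnReturn = false; // did that inner task finish before StartNew returned?
        bool taskRunUsesDefault = false;   // Task.Run, like the original Start

        Task.Factory.StartNew(() =>
        {
            outerSeesCustom = TaskScheduler.Current == scheduler;

            Task inner = Task.Factory.StartNew(
                () => { startNewUsesCustom = TaskScheduler.Current == scheduler; },
                CancellationToken.None,
                TaskCreationOptions.None,
                TaskScheduler.Current);
            startNewDoneOnReturn = inner.IsCompleted;

            // Task.Run ignores the ambient scheduler, so this runs on the default scheduler
            Task.Run(() => { taskRunUsesDefault = TaskScheduler.Current == TaskScheduler.Default; }).Wait();
        }, CancellationToken.None, TaskCreationOptions.None, scheduler).Wait();

        return new[] { outerSeesCustom, startNewUsesCustom, startNewDoneOnReturn, taskRunUsesDefault };
    }
}
```

All four values should come back true. The StartNew call mirrors the modified Start method and runs synchronously on the custom scheduler. The Task.Run call mirrors the original Start and still ends up on the thread pool scheduler.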
Now I can write the following test, which passes every single time:
[TestMethod]
public void GoodTest()
{
    var fakeMessageBus = A.Fake<IMessageBus>();
    var wasCalled = false;

    Task.Factory.StartNew(() =>
    {
        var cut = new ClassToTest(fakeMessageBus);
        cut.OnNewMessage += (sender, args) => wasCalled = true;

        cut.Start();
    }, CancellationToken.None, 
    TaskCreationOptions.None, 
    new CurrentThreadTaskScheduler());

    Assert.IsTrue(wasCalled);
}
Another benefit is that I can now also test for negative conditions (that something did not happen), since all of the code runs synchronously.
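One caveat applies to GoodTest as written. If the outer lambda throws (for example, because an assertion placed inside it fails), the exception is captured in the Task returned by StartNew. Since nothing observes that Task, the test fails later with a less informative message, if at all. Below is a minimal sketch of a hypothetical helper, SyncTestRunner.RunOnCurrentThread, that runs the test body on CurrentThreadTaskScheduler and then rethrows any exception on the test thread.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Copy of the CurrentThreadTaskScheduler above, repeated so this snippet compiles on its own
public class CurrentThreadTaskScheduler : TaskScheduler
{
    protected override void QueueTask(Task task) { TryExecuteTask(task); }
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) { return TryExecuteTask(task); }
    protected override IEnumerable<Task> GetScheduledTasks() { return Enumerable.Empty<Task>(); }
    public override int MaximumConcurrencyLevel { get { return 1; } }
}

// Hypothetical test helper, not part of any library
public static class SyncTestRunner
{
    // Runs the action on CurrentThreadTaskScheduler, so code inside it that starts tasks
    // with TaskScheduler.Current also runs synchronously on this thread
    public static void RunOnCurrentThread(Action action)
    {
        Task task = Task.Factory.StartNew(
            action,
            CancellationToken.None,
            TaskCreationOptions.None,
            new CurrentThreadTaskScheduler());

        // The task has already run; GetResult rethrows its exception (unwrapped from
        // AggregateException) instead of leaving it unobserved inside the Task
        task.GetAwaiter().GetResult();
    }
}
```

GoodTest could then wrap its body in SyncTestRunner.RunOnCurrentThread(() => { ... }) and keep the Assert outside it. Exceptions thrown inside the task started by Start are a separate matter, because the helper does not observe that inner task.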

Want to learn more?

There are many other ways to unit test concurrent code. If you’re interested in learning more of them, I’ll be speaking about this topic at the upcoming NDC London next week.

Until then - Happy coding…
