Wednesday, May 4, 2016

IResourceLoader: Balancing Semaphores

Recently I needed to balance requests for resources across a restricted number of sources. For example:

I need resources of type R, and three factories, A, B and C, create them. Each factory has very limited capacity and can only create two Rs at a time. It is easy to put each factory behind a semaphore and limit how many threads can request resources from it at once.

The challenge is balancing the workload evenly across all three factories. Note that you can't simply round-robin between the semaphores, because there is no guarantee that every operation takes the same amount of time: a factory whose requests happen to run long would build up a backlog while the others sit idle.

To do this I created a generic IResourceLoader interface with two implementations: one that wraps a semaphore, and another that wraps and balances a collection of IResourceLoaders. Below is the implementation, complete with unit tests; let's take a look!

Interface

/// <summary>
/// Hands out resources of type <typeparamref name="T"/> while bounding how many loads
/// an implementation runs at the same time.
/// </summary>
public interface IResourceLoader<T>
{
    /// <summary>Requests a resource; the returned task completes with the loaded value.</summary>
    /// <param name="cancelToken">Token the implementation passes along to the load.</param>
    Task<T> GetAsync(CancellationToken cancelToken = default(CancellationToken));

    /// <summary>Requests a resource only if the implementation can accept it without waiting.</summary>
    /// <param name="resource">The pending load when the request is accepted.</param>
    /// <param name="cancelToken">Token the implementation passes along to the load.</param>
    /// <returns><c>true</c> if the request was accepted; otherwise <c>false</c>.</returns>
    bool TryGet(out Task<T> resource, CancellationToken cancelToken = default(CancellationToken));

    /// <summary>Number of loads that could be started right now.</summary>
    int Available { get; }

    /// <summary>Number of requests the implementation has accepted so far.</summary>
    int Count { get; }

    /// <summary>Upper bound on concurrently running loads.</summary>
    int MaxConcurrency { get; }
}

ResourceLoader for a Single Resource

/// <summary>
/// An <see cref="IResourceLoader{T}"/> that runs a single loader delegate, allowing at most
/// <see cref="MaxConcurrency"/> invocations at the same time. Requests made through
/// <see cref="GetAsync"/> wait for a free slot; <see cref="TryGet"/> refuses when none is free.
/// </summary>
public class ResourceLoader<T> : IResourceLoader<T>
{
    // Serializes GetAsync and TryGet so TryGet's free-slot check is not interleaved with
    // another caller's slot acquisition on this instance.
    private readonly object _lock = new object();

    private readonly Func<CancellationToken, Task<T>> _loader;

    // Limits concurrent _loader invocations; CurrentCount is the number of free slots.
    private readonly SemaphoreSlim _semaphore;

    // Number of requests accepted so far (every GetAsync call and every successful TryGet).
    private int _count;

    /// <summary>
    /// Creates a loader that runs <paramref name="loader"/> at most
    /// <paramref name="maxConcurrency"/> times concurrently.
    /// </summary>
    /// <param name="loader">Produces one resource; receives the token given to GetAsync/TryGet.</param>
    /// <param name="maxConcurrency">Maximum number of concurrent loads; must be greater than zero.</param>
    /// <exception cref="ArgumentNullException"><paramref name="loader"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxConcurrency"/> is zero or negative.</exception>
    public ResourceLoader(Func<CancellationToken, Task<T>> loader, int maxConcurrency)
    {
        // Fail here rather than with a NullReferenceException inside the first load.
        if (loader == null)
            throw new ArgumentNullException(nameof(loader));
        if (maxConcurrency <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxConcurrency), maxConcurrency, "Must be greater than zero.");

        _loader = loader;
        _semaphore = new SemaphoreSlim(maxConcurrency, maxConcurrency);
        MaxConcurrency = maxConcurrency;
    }

    /// <summary>Number of slots that are free right now.</summary>
    public int Available => _semaphore.CurrentCount;

    /// <summary>Number of requests accepted so far, including ones that were later canceled.</summary>
    public int Count => _count;

    /// <summary>Maximum number of concurrent loads.</summary>
    public int MaxConcurrency { get; }

    /// <summary>Loads a resource, waiting for a free slot if all are busy.</summary>
    /// <param name="cancelToken">Cancels the wait for a slot and is passed to the loader.</param>
    public Task<T> GetAsync(CancellationToken cancelToken = default(CancellationToken))
    {
        lock (_lock)
            return WaitAndLoadAsync(cancelToken);
    }

    /// <summary>
    /// Starts a load only if a slot is free right now. Returns <c>false</c> and sets
    /// <paramref name="resource"/> to <c>null</c> when every slot is busy.
    /// </summary>
    /// <param name="resource">The pending load when a slot was free.</param>
    /// <param name="cancelToken">Cancels the wait for a slot and is passed to the loader.</param>
    public bool TryGet(out Task<T> resource, CancellationToken cancelToken = default(CancellationToken))
    {
        lock (_lock)
        {
            if (_semaphore.CurrentCount == 0)
            {
                resource = null;
                return false;
            }

            // NOTE(review): relies on UseWaitAsync taking the free slot before it yields, so the
            // check above and the acquisition act as one step under _lock — confirm.
            resource = WaitAndLoadAsync(cancelToken);
            return true;
        }
    }

    // Counts the request, waits for a slot, and runs the loader while holding that slot.
    private async Task<T> WaitAndLoadAsync(CancellationToken cancelToken)
    {
        Interlocked.Increment(ref _count);

        // UseWaitAsync is defined elsewhere in the project; presumably it returns a handle whose
        // Dispose releases the slot — verify against its definition.
        using (await _semaphore.UseWaitAsync(cancelToken).ConfigureAwait(false))
            return await _loader(cancelToken).ConfigureAwait(false);
    }
}

BalancedResourceLoader that wraps other ResourceLoaders

/// <summary>
/// An <see cref="IResourceLoader{T}"/> that spreads requests over several inner loaders.
/// Each search for a free loader starts just after the loader used most recently and visits
/// every loader once; <see cref="GetAsync"/> queues the request when none has a free slot.
/// </summary>
/// <remarks>
/// Queued requests are served when a resource obtained through this instance completes.
/// Capacity freed by callers that use the inner loaders directly does not, by itself,
/// wake the queue.
/// </remarks>
public class BalancedResourceLoader<T> : IResourceLoader<T>
{
    private readonly object _lock = new object();

    // Waiting GetAsync requests in arrival order, paired with the token each was issued with.
    private readonly Queue<Tuple<TaskCompletionSource<T>, CancellationToken>> _queue
        = new Queue<Tuple<TaskCompletionSource<T>, CancellationToken>>();

    private readonly IList<IResourceLoader<T>> _resourceLoaders;

    // Loader index the next search starts at; always kept in [0, _resourceLoaders.Count).
    private int _index;

    /// <summary>Creates a balancer over the given loaders.</summary>
    public BalancedResourceLoader(params IResourceLoader<T>[] resourceLoaders)
        : this((IList<IResourceLoader<T>>) resourceLoaders)
    {
    }

    /// <summary>Creates a balancer over the given loaders.</summary>
    /// <param name="resourceLoaders">Loaders to balance; copied, must be non-empty and contain no nulls.</param>
    /// <exception cref="ArgumentNullException"><paramref name="resourceLoaders"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="resourceLoaders"/> is empty or contains null.</exception>
    public BalancedResourceLoader(IList<IResourceLoader<T>> resourceLoaders)
    {
        if (resourceLoaders == null)
            throw new ArgumentNullException(nameof(resourceLoaders));
        if (resourceLoaders.Count == 0)
            throw new ArgumentException("At least one resource loader is required.", nameof(resourceLoaders));
        if (resourceLoaders.Any(r => r == null))
            throw new ArgumentException("Resource loaders must not be null.", nameof(resourceLoaders));

        // Private copy: later changes to the caller's list must not invalidate _index.
        _resourceLoaders = resourceLoaders.ToArray();
    }

    /// <summary>Sum of the inner loaders' free slots.</summary>
    public int Available => _resourceLoaders.Sum(r => r.Available);

    /// <summary>Sum of the inner loaders' accepted-request counts.</summary>
    public int Count => _resourceLoaders.Sum(r => r.Count);

    /// <summary>Sum of the inner loaders' concurrency limits.</summary>
    public int MaxConcurrency => _resourceLoaders.Sum(r => r.MaxConcurrency);

    /// <summary>
    /// Starts a load on the next inner loader with a free slot, or queues the request until a
    /// resource handed out by this instance completes. Canceling
    /// <paramref name="cancelToken"/> cancels a queued request.
    /// </summary>
    public Task<T> GetAsync(CancellationToken cancelToken = default(CancellationToken))
    {
        lock (_lock)
        {
            Task<T> resource;
            GetOrQueue(out resource, cancelToken, true);
            return resource;
        }
    }

    /// <summary>
    /// Starts a load on the next inner loader with a free slot; returns <c>false</c> with
    /// <paramref name="resource"/> set to <c>null</c> when every inner loader is busy.
    /// </summary>
    public bool TryGet(out Task<T> resource, CancellationToken cancelToken = default(CancellationToken))
    {
        lock (_lock)
            return GetOrQueue(out resource, cancelToken, false);
    }

    // Offers the request to each inner loader exactly once, starting at _index. On failure,
    // optionally queues it. Must be called while holding _lock.
    private bool GetOrQueue(out Task<T> resource, CancellationToken cancelToken, bool queueOnFailure)
    {
        var loaderCount = _resourceLoaders.Count;

        for (var offset = 0; offset < loaderCount; offset++)
        {
            var i = (_index + offset) % loaderCount;
            if (!_resourceLoaders[i].TryGet(out resource, cancelToken))
                continue;

            // Deliberately not tied to cancelToken: this continuation must run even if the request
            // is canceled, otherwise the slot it frees would never be offered to queued requests.
            resource.ContinueWith(OnResourceLoaded, CancellationToken.None,
                TaskContinuationOptions.None, TaskScheduler.Default);

            // Round robin: the next search begins after the loader just used.
            _index = (i + 1) % loaderCount;
            return true;
        }

        resource = null;

        if (queueOnFailure)
        {
            var tcs = new TaskCompletionSource<T>();
            var registration = cancelToken.Register(() => tcs.TrySetCanceled());

            // Drop the registration once the request settles so long-lived tokens do not
            // accumulate callbacks for requests that are already finished.
            tcs.Task.ContinueWith(_ => registration.Dispose(), CancellationToken.None,
                TaskContinuationOptions.None, TaskScheduler.Default);

            _queue.Enqueue(Tuple.Create(tcs, cancelToken));
            resource = tcs.Task;
        }

        return false;
    }

    // Runs after any resource handed out by this instance completes (success, fault or cancel),
    // and hands freed capacity to as many queued requests as can now be started.
    private void OnResourceLoaded(Task<T> completed)
    {
        var started = new List<Tuple<TaskCompletionSource<T>, Task<T>>>();

        lock (_lock)
        {
            while (_queue.Count > 0)
            {
                var pending = _queue.Peek();

                // Requests canceled while waiting must not consume a slot.
                if (pending.Item1.Task.IsCompleted)
                {
                    _queue.Dequeue();
                    continue;
                }

                Task<T> resource;
                if (!GetOrQueue(out resource, pending.Item2, false))
                    break;

                _queue.Dequeue();
                started.Add(Tuple.Create(pending.Item1, resource));
            }
        }

        // Linked outside the lock so that caller continuations never run while _lock is held.
        foreach (var pair in started)
        {
            var tcs = pair.Item1;
            pair.Item2.ContinueWith(t => tcs.SetFromTask(t), CancellationToken.None,
                TaskContinuationOptions.None, TaskScheduler.Default);
        }
    }
}
 
public static class Extensions
{
    /// <summary>
    /// Completes <paramref name="tcs"/> with the outcome of the already-finished
    /// <paramref name="task"/>: cancellation, the task's original exception(s), or its result.
    /// Uses the TrySet* methods, so an already-completed <paramref name="tcs"/> is left unchanged.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="tcs"/> or <paramref name="task"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="task"/> has not completed yet.</exception>
    public static void SetFromTask<T>(this TaskCompletionSource<T> tcs, Task<T> task)
    {
        if (tcs == null)
            throw new ArgumentNullException(nameof(tcs));
        if (task == null)
            throw new ArgumentNullException(nameof(task));
        if (!task.IsCompleted)
            throw new ArgumentException("Task must be complete", nameof(task));

        if (task.IsCanceled)
        {
            tcs.TrySetCanceled();
        }
        else if (task.IsFaulted)
        {
            // task.Exception is an AggregateException wrapper. Forward its inner exceptions so that
            // awaiting tcs.Task rethrows the original exception instead of the wrapper.
            var aggregate = task.Exception;
            if (aggregate != null)
                tcs.TrySetException(aggregate.InnerExceptions);
            else
                tcs.TrySetException(new InvalidOperationException("Faulted Task"));
        }
        else
        {
            tcs.TrySetResult(task.Result);
        }
    }
}

ResourceLoader Unit Tests

public class ResourceLoaderTests
{
    // Five requests through two slots all complete; each result is a distinct counter value 1..5.
    [Fact]
    public async Task Success()
    {
        var count = 0;

        var resourceLoader = new ResourceLoader<int>(async t =>
        {
            await Task.Delay(100, t).ConfigureAwait(false);
            return Interlocked.Increment(ref count);
        }, 2);

        var tasks = Enumerable
            .Range(1, 5)
            .Select(i => resourceLoader.GetAsync())
            .ToArray();

        await Task.WhenAll(tasks).ConfigureAwait(false);

        var sum = tasks.Sum(t => t.Result);
        Assert.Equal(15, sum);
    }

    // TryGet succeeds while slots are free and refuses (with a null task) once both are taken.
    [Fact]
    public async Task Failure()
    {
        var count = 0;

        var resourceLoader = new ResourceLoader<int>(async t =>
        {
            await Task.Delay(100, t).ConfigureAwait(false);
            return Interlocked.Increment(ref count);
        }, 2);

        Task<int> task1, task2, task3;

        Assert.True(resourceLoader.TryGet(out task1));
        Assert.NotNull(task1);

        Assert.True(resourceLoader.TryGet(out task2));
        Assert.NotNull(task2);

        Assert.False(resourceLoader.TryGet(out task3));
        Assert.Null(task3);

        await Task.WhenAll(task1, task2).ConfigureAwait(false);

        var sum = task1.Result + task2.Result;
        Assert.Equal(3, sum);
    }

    // Canceling after the first batch finishes cancels the two running loads and the one still waiting.
    [Fact]
    public async Task Cancel()
    {
        var count = 0;

        var resourceLoader = new ResourceLoader<int>(async t =>
        {
            var result = Interlocked.Increment(ref count);
            await Task.Delay(100, t).ConfigureAwait(false);
            return result;
        }, 2);

        using (var cancelSource = new CancellationTokenSource())
        {
            var tasks = Enumerable
                .Range(1, 5)
                .Select(i => resourceLoader.GetAsync(cancelSource.Token))
                .ToArray();

            await Task.Delay(150, cancelSource.Token).ConfigureAwait(false);

            cancelSource.Cancel();

            await Assert
                .ThrowsAsync<TaskCanceledException>(() => Task.WhenAll(tasks))
                .ConfigureAwait(false);

            Assert.Equal(3, tasks.Count(t => t.IsCanceled));
        }
    }

    // Even-numbered loads throw; the faults surface on their own tasks without affecting the others.
    [Fact]
    public async Task Fault()
    {
        var count = 0;

        var resourceLoader = new ResourceLoader<int>(async cancelToken =>
        {
            await Task.Delay(100, cancelToken).ConfigureAwait(false);

            var result = Interlocked.Increment(ref count);
            if (result%2 == 0)
                throw new InvalidProgramException();

            return result;
        }, 2);

        var tasks = Enumerable
            .Range(1, 5)
            .Select(i => resourceLoader.GetAsync())
            .ToArray();

        await Assert
            .ThrowsAsync<InvalidProgramException>(() => Task.WhenAll(tasks))
            .ConfigureAwait(false);

        Assert.Equal(2, tasks.Count(t => t.IsFaulted));
        Assert.Equal(3, tasks.Count(t => !t.IsFaulted));
        Assert.Equal(9, tasks.Where(t => !t.IsFaulted).Sum(t => t.Result));
    }
}

BalancedResourceLoader Unit Tests

public class BalancedResourceLoaderTests
{
    // Ten requests over three loaders with different capacities and latencies: checks how many
    // each loader served, the combined results, and the order in which loads were started.
    [Fact]
    public async Task Success()
    {
        // Tag of the loader that started each load, in start order.
        var startOrder = new ConcurrentQueue<char>();

        // Loader A: two slots, 50 ms per load, results 1, 2, 3, ...
        var seedA = 0;
        var loaderA = new ResourceLoader<int>(async token =>
        {
            startOrder.Enqueue('A');
            await Task.Delay(50, token).ConfigureAwait(false);
            return Interlocked.Increment(ref seedA);
        }, 2);

        // Loader B: three slots, 150 ms per load, results 101, 102, ...
        var seedB = 100;
        var loaderB = new ResourceLoader<int>(async token =>
        {
            startOrder.Enqueue('B');
            await Task.Delay(150, token).ConfigureAwait(false);
            return Interlocked.Increment(ref seedB);
        }, 3);

        // Loader C: one slot, 75 ms per load, results 10001, 10002, ...
        var seedC = 10000;
        var loaderC = new ResourceLoader<int>(async token =>
        {
            startOrder.Enqueue('C');
            await Task.Delay(75, token).ConfigureAwait(false);
            return Interlocked.Increment(ref seedC);
        }, 1);

        var balanced = new BalancedResourceLoader<int>(loaderA, loaderB, loaderC);

        var requests = new Task<int>[10];
        for (var n = 0; n < requests.Length; n++)
            requests[n] = balanced.GetAsync();

        await Task.WhenAll(requests).ConfigureAwait(false);

        Assert.Equal(5, loaderA.Count);
        Assert.Equal(3, loaderB.Count);
        Assert.Equal(2, loaderC.Count);
        Assert.Equal(10, balanced.Count);

        Assert.Equal(20324, requests.Sum(request => request.Result));

        Assert.Equal("ABCABBAACA", new string(startOrder.ToArray()));
    }
}

Enjoy,
Tom

1 comment:

  1. Hi Tom, recently I have developed a tool which could automatically incorporate logger into the app, if you use logger very often probably it might be interesting for you. If so write me to soldmitr /at\ Gmail.com

    ReplyDelete

Real Time Web Analytics