How to implement the BlockingCollection.TakeFromAny equivalent for Channels?


I am trying to implement an asynchronous method that takes an array of ChannelReader<T>s and takes a value from any of the channels that has an item available. It is similar in functionality to the BlockingCollection<T>.TakeFromAny method, which has this signature:

public static int TakeFromAny(BlockingCollection<T>[] collections, out T item,
    CancellationToken cancellationToken);
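
For reference, here is a minimal sketch of how the synchronous method behaves when only the second collection holds an item. It uses the overload without a CancellationToken, and the variable names `first` and `second` are purely illustrative:

```csharp
using System;
using System.Collections.Concurrent;

// Two collections; only the second one holds an item.
var first = new BlockingCollection<string>();
var second = new BlockingCollection<string>();
second.Add("hello");

// Blocks until an item can be taken from any of the collections,
// then reports which collection the item came from.
int index = BlockingCollection<string>.TakeFromAny(
    new[] { first, second }, out var item);

Console.WriteLine($"Took '{item}' from collection #{index}");
// prints: Took 'hello' from collection #1
```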

This method returns the index in the collections array from which the item was removed. An async method cannot have out parameters, so the API that I am trying to implement is this:

public static Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default);

The TakeFromAnyAsync<T> method should asynchronously read an item and return the consumed item along with the index of the associated channel in the channelReaders array. If all the channels are completed (either successfully or with an error), or all become completed during the await, the method should asynchronously throw a ChannelClosedException.

My question is: how can I implement the TakeFromAnyAsync<T> method? The implementation looks quite tricky. Obviously, under no circumstances should the method consume more than one item from the channels. It also should not leave behind fire-and-forget tasks, or leave disposable resources undisposed. The method will typically be called in a loop, so it should also be reasonably efficient. Its complexity should be no worse than O(n), where n is the number of channels.

To get an idea of where this method can be useful, take a look at the select statement of the Go language. From the tour:

The select statement lets a goroutine wait on multiple communication operations.

A select blocks until one of its cases can run, then it executes that case. It chooses one at random if multiple are ready.

select {
case msg1 := <-c1:
    fmt.Println("received", msg1)
case msg2 := <-c2:
    fmt.Println("received", msg2)
}

In the above example either a value will be taken from the channel c1 and assigned to the variable msg1, or a value will be taken from the channel c2 and assigned to the variable msg2. The Go select statement is not restricted to reading from channels. It can include multiple heterogeneous cases like writing to bounded channels, waiting for timers etc. Replicating the full functionality of the Go select statement is beyond the scope of this question.


There are 3 best solutions below

14
alexm On BEST ANSWER

I came up with something like this:


public static async Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default)
{
    if (channelReaders == null)
    {
        throw new ArgumentNullException(nameof(channelReaders));
    }

    if (channelReaders.Length == 0)
    {
        throw new ArgumentException("The list cannot be empty.", nameof(channelReaders));
    }

    if (channelReaders.Length == 1)
    {
        return (await channelReaders[0].ReadAsync(cancellationToken), 0);
    }

    // First attempt to read an item synchronously
    for (int i = 0; i < channelReaders.Length; ++i)
    {
        if (channelReaders[i].TryRead(out var item))
        {
            return (item, i);
        }
    }

    using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
    {

        var waitToReadTasks = channelReaders
                .Select(it => it.WaitToReadAsync(cts.Token).AsTask())
                .ToArray();

        var pendingTasks = new List<Task<bool>>(waitToReadTasks);

        while (pendingTasks.Count > 1)
        {
            var t = await Task.WhenAny(pendingTasks);

            if (t.IsCompletedSuccessfully && t.Result)
            {
                int index = Array.IndexOf(waitToReadTasks, t);
                var reader = channelReaders[index];

                // Attempt to read an item synchronously
                if (reader.TryRead(out var item))
                {
                    if (pendingTasks.Count > 1)
                    {
                        // Cancel pending "wait to read" on the remaining readers
                        // then wait for the completion 
                        try
                        {
                            cts.Cancel();
                            await Task.WhenAll((IEnumerable<Task>)pendingTasks);
                        }
                        catch { }
                    }
                    return (item, index);
                }

                // Due to the race condition item is no longer available
                if (!reader.Completion.IsCompleted)
                {
                    // .. but the channel appears to be still open, so we retry
                    var waitToReadTask = reader.WaitToReadAsync(cts.Token).AsTask();
                    waitToReadTasks[index] = waitToReadTask;
                    pendingTasks.Add(waitToReadTask);
                }

            }

            // Remove all completed tasks that could not yield 
            pendingTasks.RemoveAll(tt => tt == t || 
                tt.IsCompletedSuccessfully && !tt.Result || 
                tt.IsFaulted || tt.IsCanceled);

        }

        int lastIndex = 0;
        if (pendingTasks.Count > 0)
        {
            lastIndex = Array.IndexOf(waitToReadTasks, pendingTasks[0]);
            await pendingTasks[0];
        }

        var lastItem = await channelReaders[lastIndex].ReadAsync(cancellationToken);
        return (lastItem, lastIndex);
    }
}

0
Theodor Zoulias On

Here is another approach. This implementation is conceptually the same as alexm's implementation, up to the point where no channel has an item available immediately. It then differs by avoiding the Task.WhenAny-in-a-loop pattern, and instead starts an asynchronous loop for each channel. All loops race to update a shared ValueTuple<T, int, bool> consumed variable, which is updated inside a critical region in order to prevent consuming an element from more than one channel.

/// <summary>
/// Takes an item asynchronously from any one of the specified channel readers.
/// </summary>
public static async Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default)
{
    ArgumentNullException.ThrowIfNull(channelReaders);
    if (channelReaders.Length == 0) throw new ArgumentException(
        $"The {nameof(channelReaders)} argument is a zero-length array.");
    foreach (var cr in channelReaders) if (cr is null) throw new ArgumentException(
        $"The {nameof(channelReaders)} argument contains at least one null element.");

    cancellationToken.ThrowIfCancellationRequested();

    // Fast path (at least one channel has an item available immediately)
    for (int i = 0; i < channelReaders.Length; i++)
        if (channelReaders[i].TryRead(out var item))
            return (item, i);

    // Slow path (all channels are currently empty)
    using var linkedCts = CancellationTokenSource
        .CreateLinkedTokenSource(cancellationToken);

    (T Item, int Index, bool HasValue) consumed = default;

    Task[] tasks = channelReaders.Select(async (channelReader, index) =>
    {
        while (true)
        {
            try
            {
                if (!await channelReader.WaitToReadAsync(linkedCts.Token)
                    .ConfigureAwait(false)) break;
            }
            // Only the exceptional cases below are normal.
            catch (OperationCanceledException)
                when (linkedCts.IsCancellationRequested) { break; }
            catch when (channelReader.Completion.IsCompleted
                && !channelReader.Completion.IsCompletedSuccessfully) { break; }

            // This channel has an item available now.
            lock (linkedCts)
            {
                if (consumed.HasValue)
                    return; // An item has already been consumed from another channel.

                if (!channelReader.TryRead(out var item))
                    continue; // We lost the race to consume the available item.

                consumed = (item, index, true); // We consumed an item successfully.
            }
            linkedCts.Cancel(); // Cancel the other tasks.
            return;
        }
    }).ToArray();

    // The tasks should never fail. If a task ever fails, we have a bug.
    try { foreach (var task in tasks) await task.ConfigureAwait(false); }
    catch (Exception ex) { Debug.Fail("Unexpected error", ex.ToString()); throw; }

    if (consumed.HasValue)
        return (consumed.Item, consumed.Index);
    cancellationToken.ThrowIfCancellationRequested();
    Debug.Assert(channelReaders.All(cr => cr.Completion.IsCompleted));
    throw new ChannelClosedException();
}

It should be noted that this solution, as well as alexm's solution, depends on canceling en masse all pending WaitToReadAsync operations when an element has been consumed. Unfortunately, this triggers the infamous memory leak issue that affects .NET channels with idle producers. When any async operation on a channel is canceled, the canceled operation remains in memory, attached to the internal structures of the channel, until an element is written to the channel. Microsoft has triaged this behavior as by-design, although the possibility of improving it has not been ruled out. Interestingly, this ambiguity makes the behavior ineligible for documentation. So the only way to learn about it is by chance, either by reading about it in unofficial sources or by running into it.

4
Panagiotis Kanavos On

The problem is a lot easier if channels are used the way they're used in Go: Channel(Readers) as input, Channel(Readers) as output.

IEnumerable<ChannelReader<T>> sources=....;
await foreach(var msg in sources.TakeFromAny(token))
{
....
}

or

var merged=sources.TakeFromAny(token);
...
var msg=await merged.ReadAsync(token);

In this case, the input from all channel readers is copied to a single output channel. The return value of the method is the ChannelReader of this channel.

CopyToAsync helper

A CopyToAsync function can be used to copy messages from an input source to the output channel:

static async Task CopyToAsync<T>(
        this ChannelReader<T> input,
        ChannelWriter<T> output,
        CancellationToken token=default)
{
   while (await input.WaitToReadAsync(token).ConfigureAwait(false))
   {
         //Early exit if cancellation is requested
         while (!token.IsCancellationRequested &&  input.TryRead(out T? msg))
         {
             await output.WriteAsync(msg,token);
         }
   }
}

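This code is similar to ReadAllAsync but exits immediately if cancellation is requested. ReadAllAsync will return all available items even if cancellation is requested. The reading methods used, WaitToReadAsync and TryRead, don't throw when the input channel completes normally; they just return false, which makes error handling a lot easier. WriteAsync, on the other hand, throws a ChannelClosedException if the output channel has already been completed.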

Error Handling and Railway-oriented programming

WaitToReadAsync does throw if the source faults, and that exception will be propagated to the calling method and, through Task.WhenAll, to the output channel.

This can be a bit messy because it interrupts the entire pipeline. To avoid this, the error could be swallowed or logged inside CopyToAsync. An even better option would be to use Railway-oriented programming and wrap all messages in a Result<TMsg,TError> class, e.g.:

// Result<T,Exception> is a user-defined type (not part of the BCL)
// that holds either a value or an error.
static async Task CopyToAsync<T>(
        this ChannelReader<T> input,
        ChannelWriter<Result<T,Exception>> output,
        CancellationToken token=default)
{
   try
   {
     while (await input.WaitToReadAsync(token).ConfigureAwait(false))
     {
         //Early exit if cancellation is requested
         while (!token.IsCancellationRequested &&  input.TryRead(out T? msg))
         {
             var newMsg=Result<T,Exception>.FromValue(msg);
             await output.WriteAsync(newMsg,token);
         }
     }
  }
  catch(Exception exc)
  {
    //Forward the failure as a message instead of faulting the pipeline
    output.TryWrite(Result<T,Exception>.FromError(exc));
  }
}
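
The Result type used above is not defined in this answer and is not part of the .NET base class library. A minimal sketch of what such a type could look like follows; the member names (IsSuccess, Value, Error, FromValue, FromError) are assumptions chosen for illustration, not an established API:

```csharp
using System;

// Usage: a successful result and a failed one.
var ok = Result<int, Exception>.FromValue(42);
var failed = Result<int, Exception>.FromError(new InvalidOperationException("boom"));

Console.WriteLine(ok.IsSuccess ? $"value: {ok.Value}" : $"error: {ok.Error.Message}");
Console.WriteLine(failed.IsSuccess ? $"value: {failed.Value}" : $"error: {failed.Error.Message}");

// Hypothetical minimal result type: holds either a value or an error.
public readonly struct Result<T, TError>
{
    private Result(bool isSuccess, T value, TError error)
    {
        IsSuccess = isSuccess;
        Value = value;
        Error = error;
    }

    public bool IsSuccess { get; }
    public T Value { get; }       // meaningful only when IsSuccess is true
    public TError Error { get; }  // meaningful only when IsSuccess is false

    public static Result<T, TError> FromValue(T value) => new(true, value, default!);
    public static Result<T, TError> FromError(TError error) => new(false, default!, error);
}
```

With a type like this, the consumer can inspect IsSuccess on each message and decide per message whether to continue, instead of having the whole pipeline faulted by a single failing source.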

TakeFromAny

TakeFromAny (MergeAsync may be a better name) can be:

static ChannelReader<T> TakeFromAny<T>(
        this IEnumerable<ChannelReader<T>> inputs,
        CancellationToken token=default)
{
    var outChannel=Channel.CreateBounded<T>(1);

    //One copy loop per input, all writing to the same output channel
    var readers=inputs.Select(rd=>rd.CopyToAsync(outChannel.Writer,token));

    //Complete the output once every input has been drained or has failed
    _ = Task.WhenAll(readers)
            .ContinueWith(t=>outChannel.Writer.TryComplete(t.Exception));
    return outChannel.Reader;
}

Using a bounded capacity of 1 ensures the backpressure behavior of downstream code doesn't change.

Adding a source index

This can be adjusted to emit the index of the source as well:

static async Task CopyToAsync<T>(
        this ChannelReader<T> input,int index,
        ChannelWriter<(T,int)> output,
        CancellationToken token=default)
{
  while (await input.WaitToReadAsync(token).ConfigureAwait(false))
  {
        while (!token.IsCancellationRequested &&  input.TryRead(out T? msg))
        {

            await output.WriteAsync((msg,index),token);
        }
  }
}

static ChannelReader<(T,int)> TakeFromAny<T>(
        this IEnumerable<ChannelReader<T>> inputs,
        CancellationToken token=default)
{
    //The tuple order must match the (T,int) pairs written by CopyToAsync
    var outChannel=Channel.CreateBounded<(T,int)>(1);

    var readers=inputs.Select((rd,idx)=>rd.CopyToAsync(idx,outChannel.Writer,token));

    _ = Task.WhenAll(readers)
            .ContinueWith(t=>outChannel.Writer.TryComplete(t.Exception));
    return outChannel.Reader;
}
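
To show how the indexed variant might be consumed end to end, here is a self-contained sketch under a few assumptions: the class name ChannelMerge, the two pre-filled source channels, and the named tuple elements Item and Index are illustrative and not part of the answer above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

// Two hypothetical sources, filled and completed up front.
var first = Channel.CreateUnbounded<string>();
var second = Channel.CreateUnbounded<string>();
first.Writer.TryWrite("a1");
first.Writer.TryWrite("a2");
second.Writer.TryWrite("b1");
first.Writer.Complete();
second.Writer.Complete();

var sources = new[] { first.Reader, second.Reader };
var merged = sources.TakeFromAny();

// Drain the merged reader; it completes once all sources are drained.
var received = new List<(string Item, int Index)>();
await foreach (var pair in merged.ReadAllAsync())
{
    received.Add(pair);
    Console.WriteLine($"{pair.Item} came from source #{pair.Index}");
}

public static class ChannelMerge
{
    // Copies every item of one input to the shared output, tagged with the input's index.
    static async Task CopyToAsync<T>(
        ChannelReader<T> input, int index,
        ChannelWriter<(T Item, int Index)> output,
        CancellationToken token)
    {
        while (await input.WaitToReadAsync(token).ConfigureAwait(false))
        {
            while (!token.IsCancellationRequested && input.TryRead(out var msg))
            {
                await output.WriteAsync((msg, index), token).ConfigureAwait(false);
            }
        }
    }

    public static ChannelReader<(T Item, int Index)> TakeFromAny<T>(
        this IEnumerable<ChannelReader<T>> inputs,
        CancellationToken token = default)
    {
        var outChannel = Channel.CreateBounded<(T Item, int Index)>(1);

        // Start one copy loop per input; ToArray forces the loops to start now.
        var copiers = inputs
            .Select((rd, idx) => CopyToAsync(rd, idx, outChannel.Writer, token))
            .ToArray();

        // Complete the output once every copy loop has finished or failed.
        _ = Task.WhenAll(copiers).ContinueWith(
            t => outChannel.Writer.TryComplete(t.Exception),
            TaskScheduler.Default);

        return outChannel.Reader;
    }
}
```

The relative order of items from different sources is not deterministic, but items from the same source keep their order. One design consequence worth keeping in mind: each copy loop may already have removed one item from its source while it waits to write to the bounded output, so this merge reads ahead of the consumer and is not a drop-in replacement for a method that must never take more than one item per call.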