Hello everybody. I have to implement a simple thread pool in C# for my studies. I have already written one, but it performs very badly.
Here is the algorithm I have to implement:
- An external thread places tasks into a shared queue
- Workers place tasks into their own local queue
Workers look for tasks in the following order:
- First in their local queue
- Then in the shared queue
- And finally in the local queues of other Workers.
- If no task is found, the Worker will wait for a new task to appear in any of the queues.
To get started we are given a ready-made interface, IThreadPool.cs, a work-stealing queue class, WorkStealingQueue.cs, and DotNetThreadPoolWrapper.cs. There is also a file with tests that I run my implementation against. None of this code may be changed; I only have to implement the thread pool.
My results: while running a test, my laptop bogs down and becomes extremely slow.
This is my code:
public class ThreadPoolCustom : IThreadPool
{
    private readonly ConcurrentQueue<Action> _globalQueue = new ConcurrentQueue<Action>();
    private readonly WorkStealingQueue<Action>[] _localQueues;
    private readonly List<Thread> _workers = new List<Thread>();
    private volatile int _counter = 0;
    private object _lock = new object();

    public ThreadPoolCustom(int workerCount)
    {
        _localQueues = new WorkStealingQueue<Action>[workerCount];
        for (int i = 0; i < workerCount; i++)
        {
            _localQueues[i] = new WorkStealingQueue<Action>();
            var worker = new Thread(Worker);
            worker.IsBackground = true;
            worker.Start(i);
            _workers.Add(worker);
        }
    }

    public ThreadPoolCustom() : this(Environment.ProcessorCount)
    {
    }

    public void EnqueueAction(Action action)
    {
        // Put the task in the global queue
        _globalQueue.Enqueue(action);
        // ...and in every local queue
        foreach (var localQueue in _localQueues)
        {
            localQueue.LocalPush(action);
        }
        lock (_lock)
        {
            Monitor.PulseAll(_lock);
        }
    }

    public long GetTasksProcessedCount()
    {
        return _counter;
    }

    private void Worker(object data)
    {
        var workerId = (int)data;
        while (true)
        {
            Action task = null;
            // Try to retrieve a task from the local queue
            if (_localQueues[workerId].LocalPop(ref task))
            {
                ExecuteTask(task);
                continue;
            }
            // Try to retrieve a task from the global queue
            if (TryDequeueFromGlobalQueue(ref task))
            {
                ExecuteTask(task);
                continue;
            }
            // Try to steal a task from the other local queues
            if (TryStealFromOtherQueues(workerId, ref task))
            {
                ExecuteTask(task);
                continue;
            }
            // If there are no tasks, wait for a signal that work has appeared in one of the queues
            lock (_lock)
            {
                Monitor.Wait(_lock);
            }
        }
    }

    private bool TryDequeueFromGlobalQueue(ref Action task)
    {
        return _globalQueue.TryDequeue(out task);
    }

    private bool TryStealFromOtherQueues(int workerId, ref Action task)
    {
        // Go through every local queue except my own and try to steal
        for (int i = 0; i < _localQueues.Length; i++)
        {
            if (i == workerId) continue;
            if (_localQueues[i]?.TrySteal(ref task) ?? false)
            {
                return true;
            }
        }
        return false;
    }

    private void ExecuteTask(Action task)
    {
        task.Invoke();
        Interlocked.Increment(ref _counter);
    }
}
Do not modify the interface code
public interface IThreadPool
{
    void EnqueueAction(Action action);
    long GetTasksProcessedCount();
}
Do not modify the wrapper code
public class DotNetThreadPoolWrapper : IThreadPool
{
    private long processedTask = 0L;

    public void EnqueueAction(Action action)
    {
        ThreadPool.UnsafeQueueUserWorkItem(delegate
        {
            action.Invoke();
            Interlocked.Increment(ref processedTask);
        }, null);
    }

    public long GetTasksProcessedCount() => processedTask;
}
Do not modify the WorkStealingQueue
public class WorkStealingQueue<T>
{
    private const int INITIAL_SIZE = 32;
    private T[] m_array = new T[INITIAL_SIZE];
    private int m_mask = INITIAL_SIZE - 1;
    private volatile int m_headIndex = 0;
    private volatile int m_tailIndex = 0;
    private readonly object m_foreignLock = new object();

    public bool IsEmpty => m_headIndex >= m_tailIndex;
    public int Count => m_tailIndex - m_headIndex;

    public void LocalPush(T obj)
    {
        var tail = m_tailIndex;
        if(tail < m_headIndex + m_mask)
        {
            m_array[tail & m_mask] = obj;
            m_tailIndex = tail + 1;
        }
        else
        {
            lock (m_foreignLock)
            {
                var head = m_headIndex;
                var count = m_tailIndex - m_headIndex;
                if(count >= m_mask)
                {
                    var newArray = new T[m_array.Length << 1];
                    for(var i = 0; i < m_array.Length; i++)
                    {
                        newArray[i] = m_array[(i + head) & m_mask];
                    }
                    m_array = newArray;
                    m_headIndex = 0;
                    m_tailIndex = tail = count;
                    m_mask = (m_mask << 1) | 1;
                }
                m_array[tail & m_mask] = obj;
                m_tailIndex = tail + 1;
            }
        }
    }

    public bool LocalPop(ref T obj)
    {
        var tail = m_tailIndex;
        if(m_headIndex >= tail)
        {
            return false;
        }
        tail -= 1;
        Interlocked.Exchange(ref m_tailIndex, tail);
        if(m_headIndex <= tail)
        {
            obj = m_array[tail & m_mask];
            return true;
        }
        else
        {
            lock (m_foreignLock)
            {
                if(m_headIndex <= tail)
                {
                    obj = m_array[tail & m_mask];
                    return true;
                }
                else
                {
                    m_tailIndex = tail + 1;
                    return false;
                }
            }
        }
    }

    public bool TrySteal(ref T obj)
    {
        var taken = false;
        try
        {
            taken = Monitor.TryEnter(m_foreignLock);
            if(taken)
            {
                var head = m_headIndex;
                Interlocked.Exchange(ref m_headIndex, head + 1);
                if(head < m_tailIndex)
                {
                    obj = m_array[head & m_mask];
                    return true;
                }
                else
                {
                    m_headIndex = head;
                    return false;
                }
            }
        }
        finally
        {
            if(taken)
            {
                Monitor.Exit(m_foreignLock);
            }
        }
        return false;
    }
}
And a class testing our algorithm
public class ThreadPoolTests
{
    public static void Run<TThreadPool>() where TThreadPool : IThreadPool, new()
    {
        Run(() => new TThreadPool());
    }

    public static void Run(Func<IThreadPool> threadPoolFactory)
    {
        var name = threadPoolFactory().GetType().Name.Replace("ThreadPool", "", StringComparison.OrdinalIgnoreCase);
        Console.WriteLine($"----------======={name} ThreadPool tests=======----------");
        RunTest(LongCalculations);
        RunTest(ShortCalculations);
        RunTest(ExtremelyShortCalculations);
        RunTest(InnerShortCalculations);
        RunTest(InnerExtremelyShortCalculations);
        Console.WriteLine("\n");
        void RunTest(Action<IThreadPool> test) => test(threadPoolFactory());
    }

    private static void LongCalculations(IThreadPool threadPool)
    {
        Console.Write("LongCalculations test: ");
        var timer = Stopwatch.StartNew();
        long enqueueMs;
        const int actionsCount = 1 * 1000;
        using(var cev = new CountdownEvent(actionsCount))
        {
            Action sumAction = () =>
            {
                cev.Signal();
                Thread.SpinWait(1000 * 1000);
            };
            for(int i = 0; i < actionsCount; i++)
            {
                threadPool.EnqueueAction(sumAction);
            }
            enqueueMs = timer.ElapsedMilliseconds;
            cev.Wait();
        }
        timer.Stop();
        Console.WriteLine($" total {timer.ElapsedMilliseconds} ms, enqueue {enqueueMs} ms [tasks processed ~{threadPool.GetTasksProcessedCount()}]");
    }

    private static void ShortCalculations(IThreadPool threadPool)
    {
        Console.Write("ShortCalculations test: ");
        var timer = Stopwatch.StartNew();
        long enqueueMs;
        const int actionsCount = 1 * 1000 * 1000;
        using(var cev = new CountdownEvent(actionsCount))
        {
            Action sumAction = () =>
            {
                cev.Signal();
                Thread.SpinWait(1000);
            };
            for(var i = 0; i < actionsCount; i++)
            {
                threadPool.EnqueueAction(sumAction);
            }
            enqueueMs = timer.ElapsedMilliseconds;
            cev.Wait();
        }
        timer.Stop();
        Console.WriteLine($" total {timer.ElapsedMilliseconds} ms, enqueue {enqueueMs} ms [tasks processed ~{threadPool.GetTasksProcessedCount()}]");
    }

    private static void ExtremelyShortCalculations(IThreadPool threadPool)
    {
        Console.Write("ExtremelyShortCalculations test: ");
        var timer = Stopwatch.StartNew();
        long enqueueMs;
        const int actionsCount = 1 * 1000 * 1000;
        using(var cev = new CountdownEvent(actionsCount))
        {
            Action sumAction = () =>
            {
                cev.Signal();
            };
            for(int i = 0; i < actionsCount; i++)
            {
                threadPool.EnqueueAction(sumAction);
            }
            enqueueMs = timer.ElapsedMilliseconds;
            cev.Wait();
        }
        timer.Stop();
        Console.WriteLine($" total {timer.ElapsedMilliseconds} ms, enqueue {enqueueMs} ms [tasks processed ~{threadPool.GetTasksProcessedCount()}]");
    }

    private static void InnerShortCalculations(IThreadPool threadPool)
    {
        Console.Write("InnerCalculations test: ");
        var timer = Stopwatch.StartNew();
        long enqueueMs;
        const int actionsCount = 1 * 1000;
        const int subactionsCount = 1 * 1000;
        using(CountdownEvent outerEvent = new CountdownEvent(actionsCount))
        using(CountdownEvent innerEvent = new CountdownEvent(actionsCount * subactionsCount))
        {
            Action innerAction = () =>
            {
                innerEvent.Signal();
                Thread.SpinWait(1000);
            };
            Action outerAction = () =>
            {
                for(int i = 0; i < subactionsCount; i++)
                {
                    threadPool.EnqueueAction(innerAction);
                }
                outerEvent.Signal();
            };
            for(int i = 0; i < actionsCount; i++)
            {
                threadPool.EnqueueAction(outerAction);
            }
            outerEvent.Wait();
            enqueueMs = timer.ElapsedMilliseconds;
            innerEvent.Wait();
        }
        timer.Stop();
        Console.WriteLine($" total {timer.ElapsedMilliseconds} ms, enqueue {enqueueMs} ms [tasks processed ~{threadPool.GetTasksProcessedCount()}]");
    }

    private static void InnerExtremelyShortCalculations(IThreadPool threadPool)
    {
        Console.Write("InnerExtremelyShortCalculations test: ");
        var timer = Stopwatch.StartNew();
        long enqueueMs;
        const int actionsCount = 1 * 1000;
        const int subactionsCount = 1 * 1000;
        using(CountdownEvent outerEvent = new CountdownEvent(actionsCount))
        using(CountdownEvent innerEvent = new CountdownEvent(actionsCount * subactionsCount))
        {
            Action innerAction = () =>
            {
                innerEvent.Signal();
            };
            Action outerAction = () =>
            {
                for(int i = 0; i < subactionsCount; i++)
                {
                    threadPool.EnqueueAction(innerAction);
                }
                outerEvent.Signal();
            };
            for(int i = 0; i < actionsCount; i++)
            {
                threadPool.EnqueueAction(outerAction);
            }
            outerEvent.Wait();
            enqueueMs = timer.ElapsedMilliseconds;
            innerEvent.Wait();
        }
        timer.Stop();
        Console.WriteLine($" total {timer.ElapsedMilliseconds} ms, enqueue {enqueueMs} ms [tasks processed ~{threadPool.GetTasksProcessedCount()}]");
    }
}
Can anyone tell me what the problem with my code is? Why does the computer bog down while the tests are running?

A basic thread pool should be fairly trivial to implement with a blocking collection. Sure, you will miss out on all the work stealing etc., but if the number of threads is low it should work fine. If nothing else, I would use this as a reference implementation to confirm that the test code works as intended, and to measure the advantages of the more complex work-stealing approach.
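A minimal sketch of such a reference pool, assuming .NET's `BlockingCollection<T>` with its default FIFO backing store. The class name `BlockingCollectionThreadPool` is made up for illustration, and the interface is repeated from the question only so that the snippet compiles on its own:

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;

// Repeated from the question so that the sketch is self-contained.
public interface IThreadPool
{
    void EnqueueAction(Action action);
    long GetTasksProcessedCount();
}

// Hypothetical baseline pool: one shared queue, no local queues, no stealing.
public class BlockingCollectionThreadPool : IThreadPool
{
    private readonly BlockingCollection<Action> _queue = new BlockingCollection<Action>();
    private long _processed;

    public BlockingCollectionThreadPool() : this(Environment.ProcessorCount) { }

    public BlockingCollectionThreadPool(int workerCount)
    {
        for (var i = 0; i < workerCount; i++)
        {
            var worker = new Thread(WorkerLoop) { IsBackground = true };
            worker.Start();
        }
    }

    public void EnqueueAction(Action action) => _queue.Add(action);

    public long GetTasksProcessedCount() => Interlocked.Read(ref _processed);

    private void WorkerLoop()
    {
        // GetConsumingEnumerable blocks while the collection is empty,
        // so idle workers sleep instead of burning CPU in a loop.
        foreach (var action in _queue.GetConsumingEnumerable())
        {
            action();
            Interlocked.Increment(ref _processed);
        }
    }
}
```

Since it has a parameterless constructor, it should plug into the test class from the question as `ThreadPoolTests.Run<BlockingCollectionThreadPool>()`. The workers are background threads that never exit, which is acceptable for a test harness but not for real use.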
It seems odd that you are pushing the action into multiple queues. The action should only be executed once, so some mechanism would be needed to remove it from the other queues, and that seems counterproductive.
My understanding is that you should check which thread is calling EnqueueAction: if it is one of the pooled threads, push the action to that thread's local queue; otherwise it goes into the global queue.
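A rough sketch of that routing, under a few assumptions: each worker registers its own queue in a `ThreadLocal<T>` when it starts, and `LocalQueueStandIn<T>` is only a lock-based placeholder for the `WorkStealingQueue<T>` from the question so the snippet compiles on its own. `EnqueueRoutingSketch` and `RegisterCurrentThreadAsWorker` are made-up names:

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;

// Placeholder for the question's WorkStealingQueue<T>; only what the sketch needs.
public class LocalQueueStandIn<T>
{
    private readonly Stack<T> _items = new Stack<T>();

    public int Count { get { lock (_items) return _items.Count; } }

    public void LocalPush(T item) { lock (_items) _items.Push(item); }
}

public class EnqueueRoutingSketch
{
    private readonly ConcurrentQueue<Action> _globalQueue = new ConcurrentQueue<Action>();

    // Holds the calling thread's own local queue; null on any thread
    // that never registered, i.e. on threads outside the pool.
    private readonly ThreadLocal<LocalQueueStandIn<Action>> _myLocalQueue =
        new ThreadLocal<LocalQueueStandIn<Action>>();

    public int GlobalCount => _globalQueue.Count;

    // Each worker would call this once, from its own thread, before its loop.
    public void RegisterCurrentThreadAsWorker(LocalQueueStandIn<Action> queue)
    {
        _myLocalQueue.Value = queue;
    }

    public void EnqueueAction(Action action)
    {
        var local = _myLocalQueue.Value;
        if (local != null)
        {
            // Called from inside a task running on a worker: keep it local.
            local.LocalPush(action);
        }
        else
        {
            // Called from an external thread: use the shared queue.
            _globalQueue.Enqueue(action);
        }
    }
}
```

With this shape every action lands in exactly one queue, and `LocalPush` is only ever called by the thread that owns the queue, which is what the name of that method suggests it expects.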
Next potential problem
I don't think this works as you probably intend it to. My understanding is this will just cause the thread to:
So chances are that all of your workers are just looping, not doing any actual work, and possibly crowding out any other thread that wants to enqueue work. I'm guessing that you want to wait until any work has been queued, and for that you might want to take a look at AutoResetEvent or ManualResetEvent.
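One way the waiting could look with an `AutoResetEvent`, as a sketch only. To keep it short it uses a single shared queue; in the real pool the same wait would come after the local queue, the global queue and the steal attempts have all come up empty. `SignalledWorkerSketch` is a made-up name:

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;

public class SignalledWorkerSketch
{
    private readonly ConcurrentQueue<Action> _globalQueue = new ConcurrentQueue<Action>();
    private readonly AutoResetEvent _workAvailable = new AutoResetEvent(false);
    private long _processed;

    public SignalledWorkerSketch(int workerCount)
    {
        for (var i = 0; i < workerCount; i++)
        {
            new Thread(WorkerLoop) { IsBackground = true }.Start();
        }
    }

    public void EnqueueAction(Action action)
    {
        _globalQueue.Enqueue(action);
        // The event stays signalled until one worker consumes it, so a worker
        // that is just about to call WaitOne does not miss this wake-up.
        _workAvailable.Set();
    }

    public long GetTasksProcessedCount() => Interlocked.Read(ref _processed);

    private void WorkerLoop()
    {
        while (true)
        {
            if (_globalQueue.TryDequeue(out var action))
            {
                // Several Set calls can collapse into a single signal, so pass
                // the wake-up on to another worker if more work is queued.
                if (!_globalQueue.IsEmpty) _workAvailable.Set();
                action();
                Interlocked.Increment(ref _processed);
                continue;
            }
            // Nothing to do: block here without spinning until work is signalled.
            _workAvailable.WaitOne();
        }
    }
}
```

The event is set after the item is enqueued, never before, so a woken worker always re-checks the queue with the new item already visible. A `SemaphoreSlim` that is released once per enqueued item would be an alternative if the collapsing of signals turns out to cost too much parallelism.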
Multithreading is difficult in general, and trying to write a correct and high performance thread pool would be very difficult. Doing it for learning is great, but I would caution against use in production.