My code does the following:
- Starts processes to collect data
- Starts processes to test model
- One thread takes care of training (it reads data from the collect processes)
- One thread takes care of testing (it reads data from the test processes)
- Every time the training thread does a step, it waits for the testing thread to also do one step
- Before doing a step, the testing thread waits for a training step
I need to have reproducible results, but there is randomness in both the processes and the threads. I naively fix the seeds in each process and thread, but the results are always different.
Is it possible to have reproducible results? I know threads are non-deterministic, but I don't launch multiple threads from the same pool: I have two pools, each launching only one thread.
Below is a simple MWE. I need the output to be the same every time I run this program.
EDIT
Using the initializer argument in all pools, I can get deterministic behavior within the threads and processes. However, the order in which the processes write the data is still random, due to the non-deterministic scheduling of the processes: sometimes one process reads from the queue and writes first, sometimes it is another.
How can I fix it?
import logging
import traceback
import torch
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from torch import multiprocessing as mp

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)


def fix_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def collect(id, queue, data):
    # log.info('Collect %i started ...', id)
    while True:
        idx = queue.get()
        if idx is None:
            break
        data[idx] = torch.rand(1)
        log.info(f'Collector {id} got idx {idx} and sampled {data[idx]}')
        queue.task_done()
    # log.info('Collect %i completed', id)


def test(id, queue, data):
    # log.info('Test %i started ...', id)
    while True:
        idx = queue.get()
        if idx is None:
            break
        data[idx] = torch.rand(1)
        log.info(f'Tester {id} got idx {idx} and sampled {data[idx]}')
        queue.task_done()
    # log.info('Test %i completed', id)


def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    ctx = mp.get_context('spawn')
    manager = mp.Manager()
    collect_queue = manager.JoinableQueue()
    test_queue = manager.JoinableQueue()
    train_test_queue = manager.JoinableQueue()

    collect_pool = ProcessPoolExecutor(
        num_collect_procs,
        mp_context=ctx,
        initializer=fix_seed,
        initargs=(1,)
    )
    test_pool = ProcessPoolExecutor(
        num_test_procs,
        mp_context=ctx,
        initializer=fix_seed,
        initargs=(1,)
    )

    for i in range(num_collect_procs):
        future = collect_pool.submit(collect, i, collect_queue, data_collect)
    for i in range(num_test_procs):
        future = test_pool.submit(test, i, test_queue, data_test)

    def run_train():
        nonlocal steps
        # log.info('Training thread started ...')
        while steps < max_steps:
            train_test_queue.put(True)
            train_test_queue.join()
            for idx in range(num_collect_procs):
                collect_queue.put(idx)
            log.info('Training, %i %f', steps, data_collect.sum() + torch.rand(1))
            collect_queue.join()
            steps += 1
        # log.info('Training ended')
        for i in range(num_collect_procs):
            collect_queue.put(None)
        train_test_queue.put(None)

    def run_test():
        nonlocal steps
        # log.info('Testing thread started ...')
        while steps < max_steps:
            status = train_test_queue.get()
            if status is None:
                break
            for idx in range(num_test_procs):
                test_queue.put(idx)
            log.info('Testing, %i %f', steps, data_test.sum() + torch.rand(1))
            test_queue.join()
            train_test_queue.task_done()
        # log.info('Testing ended')
        for i in range(num_test_procs):
            test_queue.put(None)

    training_thread = ThreadPoolExecutor(1, initializer=fix_seed, initargs=(1,))
    testing_thread = ThreadPoolExecutor(1, initializer=fix_seed, initargs=(1,))
    training_thread.submit(run_train)
    testing_thread.submit(run_test)


if __name__ == '__main__':
    run()
I am not familiar with torch, and I could not easily tell whether its random number generator is sharable across all processes or whether each process has its own generator that will generate the same sequence of numbers if they are all seeded identically.
Let's first assume the generator is sharable, i.e. each process is effectively making calls to the same, shared random number generator seeded with 0, and that the first two random numbers generated from that seed are 9 and 11. Let's also assume you have only two collect processes, p1 and p2. When you run the program the first time, this is the order of events:
- p1 gets the idx value of 0 from the queue
- data[0] = 9 is stored
- p2 gets the idx value of 1 from the queue
- data[1] = 11 is stored
The next time you run, this is the order in which events occur:
- p1 gets the idx value of 0 from the queue and then loses control of the CPU before it has a chance to get a random number and store it
- p2 gets the idx value of 1 from the queue
- data[1] = 9 is stored
- data[0] = 11 is stored
Already we see the results are not duplicated.
The only way to ensure duplication would be to serialize all the code between idx = queue.get() and data[idx] = torch.rand(1) with a multiprocessing.Semaphore, so that you guarantee any process that retrieves the Nth index also retrieves the Nth random number from the seeded sequence. Assuming that you are doing "real" work in your actual code and that the result for a given index depends only on the random number used, this should be doable without any performance impact. You would allocate a semaphore, use the initializer and initargs arguments to initialize each pool process with it, and place the previously described critical section within a with semaphore: block, as sketched below.
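A minimal sketch of what that could look like for the collect worker, assuming the semaphore is created by the same Manager that owns the queues (the initializer name fix_seed_and_semaphore and the module-level _semaphore variable are illustrative, not taken from the original code):

import torch
from concurrent.futures import ProcessPoolExecutor
from torch import multiprocessing as mp

_semaphore = None  # set in every pool worker by the initializer


def fix_seed_and_semaphore(seed, semaphore):
    # Pool initializer: fix the seed and stash the shared semaphore.
    global _semaphore
    _semaphore = semaphore
    torch.manual_seed(seed)


def collect(id, queue, data):
    while True:
        # Hold the semaphore across "get the Nth index" and "draw the Nth
        # random number" so the two always happen as one atomic step.
        with _semaphore:
            idx = queue.get()
            if idx is None:
                break
            data[idx] = torch.rand(1)
        queue.task_done()


# Pool construction (sketch):
# manager = mp.Manager()
# semaphore = manager.Semaphore(1)  # a manager proxy pickles cleanly under 'spawn'
# collect_pool = ProcessPoolExecutor(
#     num_collect_procs,
#     mp_context=mp.get_context('spawn'),
#     initializer=fix_seed_and_semaphore,
#     initargs=(1, semaphore),
# )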
Now let's repeat the same two scenarios, this time assuming each process has its own random number generator.
First run:
- p1 gets the idx value of 0 from the queue
- data[0] = 9 is stored
- p2 gets the idx value of 1 from the queue
- data[1] = 9 is stored
Second run:
- p1 gets the idx value of 0 from the queue and then loses control of the CPU before it has a chance to get a random number and store it
- p2 gets the idx value of 1 from the queue
- data[1] = 9 is stored
- data[0] = 11 is stored
The only possible way that I can see of ensuring duplicate runs is if you split all the possible indices into two groups and use two input queues.
You pass to p1 one of the queues, to which you put half of the indices, and you pass to p2 the other input queue, to which you have written the remaining indices. That ensures that the random number used for any given index does not vary from run to run, i.e. the Nth index retrieved by a given process is always paired with the Nth random number. In this case you should seed each process differently, to avoid computing the same results for different indices.
Update
Your logic seems to be overly complicated, with pools (which have their own internal queues), processes, threads, queues, etc. Frankly, I am having some difficulty following what you are doing. But the following is my idea for achieving repeatability of results. Here I assume that the random number generator is not sharable across processes, and I use an implementation for which I know that to be the case. I therefore create N processes, and each one seeds its own random number generator uniquely, to minimize the probability of multiple processes generating the same random number (which I assume would not be fatal if it were to happen). A shared array initialized to zeros is created and passed to each process. But each process has its own input queue from which it retrieves indices, and I write 1/N of the indices to each queue, so that a process always pairs the ith index it retrieves with the ith random number produced by its own seeded generator. A sketch of this arrangement follows.
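A minimal sketch of that scheme, substituting torch.rand for the per-process generator the answer refers to (on the assumption that torch's global generator behaves as per-process state when each worker seeds it itself), and with illustrative names such as worker and base_seed:

import torch
from torch import multiprocessing as mp


def worker(proc_id, queue, data, base_seed):
    # Seed each process differently, to make it unlikely that two
    # processes compute the same value for different indices.
    torch.manual_seed(base_seed + proc_id)
    while True:
        idx = queue.get()
        if idx is None:  # sentinel: no more work
            break
        # The i-th index this process receives is always paired with the
        # i-th number drawn from its own seeded generator.
        data[idx] = torch.rand(1)


def main():
    num_procs = 3
    num_indices = 9
    data = torch.zeros(num_indices).share_memory_()

    ctx = mp.get_context('spawn')
    queues = [ctx.SimpleQueue() for _ in range(num_procs)]
    procs = [
        ctx.Process(target=worker, args=(i, queues[i], data, 1))
        for i in range(num_procs)
    ]
    for p in procs:
        p.start()

    # Statically split the indices: process i always receives the same
    # 1/N of the indices, in the same order, regardless of scheduling.
    for idx in range(num_indices):
        queues[idx % num_procs].put(idx)
    for q in queues:
        q.put(None)

    for p in procs:
        p.join()
    print(data)


if __name__ == '__main__':
    main()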