Make model_averaging mutually exclusive in federated learning using Python threads

I am creating no_of_client threads (one per client connection) using the following code:

import socket

no_of_client = 1

all_data = b""
while True:
    try:
        sockets_thread = []  # fresh list each round: a Thread object can only be started once
        for i in range(no_of_client):
            # soc is the listening socket, presumably created with a timeout elsewhere (not shown)
            connection, client_info = soc.accept()
            print("\nNew Connection from {client_info}.".format(client_info=client_info))
            socket_thread = SocketThread(connection=connection,
                                         client_info=client_info,
                                         buffer_size=1024,
                                         recv_timeout=100)
            sockets_thread.append(socket_thread)
        for i in range(no_of_client):
            sockets_thread[i].start()
            sockets_thread[i].join()
    except socket.timeout:  # catch only the accept() timeout, not every exception
        soc.close()
        print("(Timeout) Socket Closed Because no Connections Received.\n")
        break

The run method of SocketThread, and the reply method it calls, look like this:

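import threading

class SocketThread(threading.Thread):  # must subclass Thread for start()/join() to exist
    # __init__(), recv() and model_averaging() are omitted here
    def run(self):
        while True:
            received_data, status = self.recv()
            if status == 0:
                self.connection.close()
                break

            self.reply(received_data)

    def reply(self, received_data):
        model = SimpleASR()
        # all threads must average the model before going to the next line
        # (as written, model_instance is read before it is assigned, which raises
        # UnboundLocalError; it presumably refers to a shared model defined elsewhere)
        model_instance = self.model_averaging(model, model_instance)
        print("All threads completed model averaging.")
        # now do the rest of the things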

In reply(), I call model_averaging(). I want every thread to average the model and then wait until all the other threads have done the same before any of them proceeds to the next line.

I understand that I need a Python condition variable for this. How can I do that?

The following call must be mutually exclusive (only one thread may execute it at a time):

model_instance = self.model_averaging(model, model_instance)

No thread should proceed to the next line until every thread has executed it.

I am writing this code as part of implementing a federated learning algorithm.
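A minimal sketch of both requirements (mutually exclusive averaging, then no thread proceeds until all have averaged) using threading.Condition. Assumptions, none of which come from the question: each client's model is a plain list of floats, averaging is an unweighted element-wise mean (a simplification of FedAvg, which normally weights clients by dataset size), and AveragingRound, contribute and client are hypothetical names. The Condition's lock makes the averaging step mutually exclusive, and wait_for blocks each thread until the last client has contributed and the average exists.

```python
import threading

class AveragingRound:
    """Hypothetical helper: collects one model per client, averages them,
    and releases the threads only once every client has contributed."""

    def __init__(self, n_clients):
        self.n_clients = n_clients
        self.cond = threading.Condition()
        self.models = []
        self.average = None

    def contribute(self, weights):
        # The Condition's underlying lock makes this block mutually exclusive.
        with self.cond:
            self.models.append(list(weights))
            if len(self.models) == self.n_clients:
                # The last thread to arrive computes the element-wise mean...
                self.average = [sum(col) / self.n_clients for col in zip(*self.models)]
                # ...and wakes every thread waiting below.
                self.cond.notify_all()
            # wait_for releases the lock while waiting and re-checks the predicate.
            self.cond.wait_for(lambda: self.average is not None)
            return self.average

def client(round_, weights, results, i):
    results[i] = round_.contribute(weights)
    # Only reached after all clients have contributed.

round_ = AveragingRound(3)
client_weights = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
results = [None] * 3
threads = [threading.Thread(target=client, args=(round_, w, results, i))
           for i, w in enumerate(client_weights)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # [[3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
```

This helper handles a single round; for repeated rounds its state would need resetting, which is one reason threading.Barrier combined with a threading.Lock around the averaging may be simpler. Either way, all client threads must be running at the same time: if each thread is started and joined before the next one is started, the first thread waits forever.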

Answer by Solomon Slow

t.join() probably does not do what you think it does:

for i in range(no_of_client):    
    sockets_thread[i].start()
    sockets_thread[i].join()

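When your main thread calls sockets_thread[i].join(), that call does not return until that thread has finished. So your loop starts each thread only after the previous thread has completely finished whatever it was supposed to do; the threads never run concurrently. That also means any step where a thread waits for all the other threads (such as your model averaging) can never complete: the first thread waits for peers that have not been started yet, while the main thread waits in join() for that first thread.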
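A small sketch that makes the sequential behaviour visible, assuming each thread just sleeps as a stand-in for handling one client (worker, run_start_then_join_each and the 0.1 s delay are illustrative, not from the question). With start() followed immediately by join() in the same loop, the event log never interleaves and the total time is at least the number of threads times the delay.

```python
import threading
import time

def worker(log, i, delay):
    log.append(("start", i))
    time.sleep(delay)  # stand-in for handling one client
    log.append(("end", i))

def run_start_then_join_each(n, delay):
    log = []
    threads = [threading.Thread(target=worker, args=(log, i, delay)) for i in range(n)]
    t0 = time.perf_counter()
    for t in threads:
        t.start()
        t.join()  # blocks until this thread finishes, before the next one starts
    return time.perf_counter() - t0, log

elapsed, log = run_start_then_join_each(3, 0.1)
print(log)
# [('start', 0), ('end', 0), ('start', 1), ('end', 1), ('start', 2), ('end', 2)]
print(f"{elapsed:.2f} s")  # roughly n * delay, since nothing overlaps
```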

If you want the threads to run concurrently, then don't join any of them until after you have started all of them:

for i in range(no_of_client):    
    sockets_thread[i].start()
for i in range(no_of_client):    
    sockets_thread[i].join()
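To see why this matters for the averaging step in the question, here is a sketch that assumes the "wait for all threads" step is a threading.Barrier (make_worker and run are hypothetical names, and the 0.5 s timeout exists only so the demo cannot hang). When each thread is started and joined in turn, the first thread waits at the barrier for peers that have not been started, the barrier times out and breaks, and every later thread fails too. Starting all threads before joining any lets them all pass.

```python
import threading

def make_worker(barrier, passed):
    def worker():
        try:
            barrier.wait()  # each thread blocks here until all n have arrived
            passed.append(True)
        except threading.BrokenBarrierError:
            passed.append(False)  # barrier timed out or was already broken
    return worker

def run(n, start_all_first):
    barrier = threading.Barrier(n, timeout=0.5)
    passed = []
    threads = [threading.Thread(target=make_worker(barrier, passed)) for _ in range(n)]
    if start_all_first:
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    else:
        for t in threads:
            t.start()
            t.join()  # the question's original pattern
    return passed

print(run(3, start_all_first=False))  # [False, False, False]
print(run(3, start_all_first=True))   # [True, True, True]
```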

P.S. Some people might find it more Pythonic to write it like this:

for t in sockets_thread:
    t.start()
for t in sockets_thread:
    t.join()

It would probably be more Pythonic still to use Python's map function, but I'm not going there right now.