In federated learning, I want to get the weights of each local model every round and then cluster the clients based on those weights, but `training_process.get_model_weights(train_state)` only gives me the global model weights.
I have used `training_process.get_model_weights(train_state)` to get the global weights, but I haven't found any library or function to get the weights of each client yet.
This is definitely possible. To do so, you would just need to write a `tff.federated_computation` that returns the `CLIENTS`-placed model weights.

For brevity, I'll illustrate this in a much simpler setting, but the same principle applies to model training. For example, let's say that for each client, I'm going to take some integer broadcast from the server, add it to the client's locally held integer, and return the results. I could do:
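A minimal sketch of such a computation might look like the following (the type constructors `tff.type_at_server` and `tff.type_at_clients` are assumed here; the exact names can vary with your TFF version):

```python
import tensorflow as tf
import tensorflow_federated as tff


@tff.tf_computation(tf.int32, tf.int32)
def add(x, y):
  return x + y


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tf.int32))
def add_across_clients(server_value, client_values):
  # Send the server's integer to every client.
  broadcast_value = tff.federated_broadcast(server_value)
  # Each client adds the broadcast integer to its locally held integer.
  # The result is not aggregated, so it stays placed at tff.CLIENTS.
  return tff.federated_map(add, (broadcast_value, client_values))
```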
Calling `add_across_clients(3, [1, 2, 5])` will then return the value `[4, 5, 8]`. In other words, it returns a `tff.CLIENTS`-placed value, representing the collection as a list.

You can do the same kind of thing with your model training code: broadcast some weights, apply local training (via `tff.federated_map`), and return the result.
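As a toy illustration of that pattern (not your actual training loop), here is a sketch in which the "model" is a single float weight, the local "training" just nudges that weight toward the mean of the client's data, and the names `client_update` and `train_and_return_client_weights` are purely illustrative. In a real setup you would substitute your model's weights type and your local training logic:

```python
import tensorflow as tf
import tensorflow_federated as tff

# Toy stand-ins: a single float "weight" and client data as a sequence of floats.
WEIGHTS_TYPE = tff.TensorType(tf.float32)
DATA_TYPE = tff.SequenceType(tff.TensorType(tf.float32))


@tff.tf_computation(WEIGHTS_TYPE, DATA_TYPE)
def client_update(initial_weight, dataset):
  # Stand-in for local training: nudge the weight toward the mean of the
  # client's data. Replace this with your real per-client training step.
  total, count = dataset.reduce(
      (0.0, 0.0), lambda state, x: (state[0] + x, state[1] + 1.0))
  return initial_weight + 0.1 * (total / count - initial_weight)


@tff.federated_computation(
    tff.type_at_server(WEIGHTS_TYPE),
    tff.type_at_clients(DATA_TYPE))
def train_and_return_client_weights(server_weight, client_datasets):
  # Broadcast the server weight to every client.
  broadcast_weight = tff.federated_broadcast(server_weight)
  # No federated aggregation here: the mapped result keeps its CLIENTS
  # placement, so the caller receives one locally updated weight per client.
  return tff.federated_map(client_update, (broadcast_weight, client_datasets))


# In the local runtime this returns a Python list with one entry per client,
# which you could then feed into your clustering step.
client_weights = train_and_return_client_weights(0.0, [[1.0, 2.0, 3.0], [10.0, 20.0]])
print(client_weights)
```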