This is the loss function of a personalized federated learning framework. When lambda equals 0, every client trains purely locally; otherwise the proximal term pulls each client's parameters toward the global model, weighted by lambda.
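To be explicit, the objective each client minimizes is (my notation, not the framework's):

\[
\mathcal{L}_i(w_i) \;=\; \mathrm{CE}\big(f_{w_i}(x),\, y\big) \;+\; \frac{\lambda}{2}\,\lVert w_i - w_g \rVert_2^2 ,
\]

where \(w_i\) are client \(i\)'s parameters, \(w_g\) are the global model's parameters, and CE is the mean cross-entropy over the batch.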
Version 1, where the server returns the full global_model to every client:
def mtl_loss_fn(self, logits, labels, shared_model):
    # Mean cross-entropy over the batch
    sample_loss_fn = torch.nn.CrossEntropyLoss()
    mean_batch_term = sample_loss_fn(logits, labels)
    # Move the received global model to the GPU
    shared_model = shared_model.to(device='cuda')
    # Squared L2 distance between local and global parameters
    w_diff = torch.tensor(0., device=self.device)
    for w, w_t in zip(self.model.parameters(), shared_model.parameters()):
        w_diff += torch.pow(torch.norm(w - w_t), 2)
    prox_term = 0.5 * self.lam * w_diff
    # print(f"mean batch term: {mean_batch_term}, prox_term: {prox_term}")
    return mean_batch_term + prox_term
Version 2, where the server returns only the parameters of the global_model to every client:
def mtl_loss_fn(self, logits, labels, shared_model_parameters):
    # Mean cross-entropy over the batch
    sample_loss_fn = torch.nn.CrossEntropyLoss()
    mean_batch_term = sample_loss_fn(logits, labels)
    for x in shared_model_parameters:
        x.to(device='cuda')
    # Squared L2 distance between local and global parameters
    w_diff = torch.tensor(0., device=self.device)
    for w, w_t in zip(self.model.parameters(), shared_model_parameters):
        w_diff += torch.pow(torch.norm(w - w_t), 2)
    prox_term = 0.5 * self.lam * w_diff
    # print(f"mean batch term: {mean_batch_term}, prox_term: {prox_term}")
    return mean_batch_term + prox_term
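In case it matters, both versions are invoked the same way from the client's training step, roughly like this (a simplified sketch, not my exact training loop; global_model and global_parameters stand for whatever the server sent):

# Client-side training step (sketch)
optimizer.zero_grad()
logits = self.model(inputs)
# Version 1: the server sent the full global nn.Module
loss = self.mtl_loss_fn(logits, labels, global_model)
# Version 2: the server sent a list of parameter tensors instead
# loss = self.mtl_loss_fn(logits, labels, global_parameters)
loss.backward()
optimizer.step()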
Version 1 works as expected, but Version 2 degenerates into purely local training no matter what lambda is.
I don't understand why; the two versions look identical to me.
Tensor.to() does not modify x in place; it returns a new tensor that has been moved to the specified device (unlike Module.to() in Version 1, which moves the module's parameters in place, which is why that version works). You need to assign the returned tensor to a variable and use that variable in the next step:
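For example, a minimal sketch (using self.device, which the rest of your method already uses for the target device; 'cuda' would work the same way):

# Build a new list of tensors that actually live on the target device,
# then compute the proximal term against that list.
shared_model_parameters = [x.to(self.device) for x in shared_model_parameters]

w_diff = torch.tensor(0., device=self.device)
for w, w_t in zip(self.model.parameters(), shared_model_parameters):
    w_diff += torch.pow(torch.norm(w - w_t), 2)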