PyTorch: How to properly parallelize forward passes through an ensemble of networks?

Disclaimer: This is not about data parallelism.

In PyTorch, I have an ensemble of models in a torch.nn.ModuleDict that maps string model names to torch.nn.Module objects. I place each network on its own GPU (cuda:0, cuda:1, ...), and I want to run simultaneous, parallel forward passes through all of the models to compute each model's loss. How do I do this?
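
For reference, the setup looks roughly like this (the wrapper class, loss, and names below are simplified placeholders for my real code):

    import torch

    # Placeholder wrapper: owns one network on one GPU and computes its loss.
    class ModelWrapper(torch.nn.Module):
        def __init__(self, device: torch.device):
            super().__init__()
            self.device = device
            self.model = torch.nn.Linear(128, 10).to(device)  # stand-in network

        def compute_loss(self, image: torch.Tensor) -> torch.Tensor:
            # Move the batch to this wrapper's GPU and compute a stand-in loss.
            output = self.model(image.to(self.device, non_blocking=True))
            return output.mean()

    # One model per GPU: cuda:0, cuda:1, ...
    models_dict = torch.nn.ModuleDict({
        f"model_{i}": ModelWrapper(torch.device(f"cuda:{i}"))
        for i in range(torch.cuda.device_count())
    })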

I am currently looping:

    def compute_loss(
        self, image: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        losses_per_model: Dict[str, torch.Tensor] = {}
        total_loss = torch.zeros(1, requires_grad=True, device="cpu")
        for model_name, model_wrapper in self.models_dict.items():
            loss_for_model = model_wrapper.compute_loss(
                image=image,
            )
            losses_per_model[model_name] = loss_for_model
            total_loss = total_loss + loss_for_model.cpu()
        avg_loss = total_loss / len(self.models_dict)
        losses_per_model["avg"] = avg_loss
        return losses_per_model

Does looping perform the forward passes in parallel? If not, what changes to my code do I need to make?
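
For example, would dispatching each model's forward pass from its own Python thread (relying on CUDA kernel launches being asynchronous) give me the overlap I want, or is there a more idiomatic approach such as torch.jit.fork or CUDA streams? A rough, untested sketch of what I have in mind:

    from concurrent.futures import ThreadPoolExecutor
    from typing import Dict

    import torch

    def compute_loss(
        self, image: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        def run_one(item):
            model_name, model_wrapper = item
            # Each call launches work on that model's own GPU.
            return model_name, model_wrapper.compute_loss(image=image)

        # One worker thread per model so the per-GPU work can overlap.
        with ThreadPoolExecutor(max_workers=len(self.models_dict)) as pool:
            losses_per_model = dict(pool.map(run_one, self.models_dict.items()))

        # Gather the per-model losses on the CPU and average them.
        total_loss = sum(loss.cpu() for loss in losses_per_model.values())
        losses_per_model["avg"] = total_loss / len(self.models_dict)
        return losses_per_model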
