Weighing the losses for a supervised VAE classifier

I am working in the field of audio classification. Recently I have been trying to use a supervised VAE classifier. Here is the architecture I am using:

import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_shape, latent_size):
        super(VAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU()
        )
        # Flatten the 512x8x1 encoder output (assumes an input whose spatial
        # dimensions reduce to 8x1 after the four stride-2 convolutions)
        self.fc_mu = nn.Linear(512*8*1, latent_size)
        self.fc_logvar = nn.Linear(512*8*1, latent_size)

        # Decoder
        self.decoder_input = nn.Linear(latent_size, 512*8*1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, input_shape[0], kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),
            nn.Sigmoid()
        )

        # Classifier
        self.clf = nn.Sequential(
            nn.Linear(latent_size, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(0.25),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),

            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),

            nn.Linear(128, 7),
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(-1, 512*8*1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
        # which keeps sampling differentiable with respect to mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(-1, 512, 8, 1)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        clf = self.clf(z)
        return reconstruction, mu, logvar, clf
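
For reference, here is a quick shape sanity check of the model above. The (1, 128, 15) input shape is my assumption, chosen so that the four stride-2 convolutions reduce it to the 512x8x1 feature map the fully connected layers expect:

# Hypothetical smoke test with a batch of 4 mono spectrogram patches
model = VAE(input_shape=(1, 128, 15), latent_size=64)
model.eval()  # use BatchNorm running stats so any batch size works
x = torch.rand(4, 1, 128, 15)
with torch.no_grad():
    reconstruction, mu, logvar, clf = model(x)
print(reconstruction.shape)    # torch.Size([4, 1, 128, 15])
print(mu.shape, logvar.shape)  # torch.Size([4, 64]) each
print(clf.shape)               # torch.Size([4, 7]), one logit per class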

Typically, VAEs are trained with the sum of the BCE (reconstruction) loss and the KL divergence. For a supervised VAE classifier, where a cross-entropy loss is also required, how should all of the losses be weighted?
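
In other words, the combined objective has the general form (the weights $\alpha$, $\beta$, $\gamma$ are placeholders I am introducing, not values from any particular paper):

$$\mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{BCE}} + \beta \, \mathcal{L}_{\mathrm{KLD}} + \gamma \, \mathcal{L}_{\mathrm{CE}}$$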

I have seen some research that uses seemingly arbitrary weights, such as:

import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar, clf, target):
    input_size = x.size(1) * x.size(2) * x.size(3)
    # Reconstruction loss; note the 'mean' reduction here vs. the sum in KLD
    # below, which changes the relative scale of the two terms
    BCE = F.binary_cross_entropy(recon_x.view(-1, input_size), x.view(-1, input_size), reduction='mean')
    # Kullback-Leibler divergence, summed over batch and latent dimensions
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Classification (cross-entropy) loss
    clf_loss = F.cross_entropy(clf, target)
    return 0.001 * (BCE + 3 * KLD) + clf_loss
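
For completeness, this is a minimal sketch of how the loss would be used in a training step; the optimizer choice, learning rate, and train_loader are my assumptions for illustration:

model = VAE(input_shape=(1, 128, 15), latent_size=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for x, target in train_loader:  # assumed to yield (spectrogram, label) batches
    # x must be scaled to [0, 1], since the decoder ends in Sigmoid and BCE is used
    optimizer.zero_grad()
    reconstruction, mu, logvar, clf = model(x)
    loss = vae_loss(reconstruction, x, mu, logvar, clf, target)
    loss.backward()
    optimizer.step()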

Is there a principled way to find the optimal weights?
