I am working in the field of audio classification. Recently I have been trying to use a supervised VAE classifier. Here is the architecture I am using:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_shape, latent_size):
        super(VAE, self).__init__()
        # Encoder: four stride-2 convolutions, each halving the spatial dimensions
        self.encoder = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU()
        )
        # The 512*8*1 flatten assumes the encoder output is 8x1 spatially
        self.fc_mu = nn.Linear(512 * 8 * 1, latent_size)
        self.fc_logvar = nn.Linear(512 * 8 * 1, latent_size)
        # Decoder
        self.decoder_input = nn.Linear(latent_size, 512 * 8 * 1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, input_shape[0], kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),
            nn.Sigmoid()
        )
        # Classifier head on the latent code (7 classes)
        self.clf = nn.Sequential(
            nn.Linear(latent_size, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Linear(128, 7),
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(-1, 512 * 8 * 1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(-1, 512, 8, 1)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        clf = self.clf(z)
        return reconstruction, mu, logvar, clf
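For reference, the hard-coded 512*8*1 flatten and the decoder's final output_padding=(1, 0) imply an input of shape (channels, 128, 15); the sanity check below assumes a single-channel 128x15 spectrogram and latent_size=64, both of which are placeholders of mine rather than something fixed by the question:

# Shape sanity check. The input size (1, 128, 15) is an assumption: it is what the
# 512*8*1 flatten and output_padding=(1, 0) imply
# (128 -> 64 -> 32 -> 16 -> 8 and 15 -> 8 -> 4 -> 2 -> 1 through the encoder).
model = VAE(input_shape=(1, 128, 15), latent_size=64)
x = torch.rand(4, 1, 128, 15)          # batch of 4 fake spectrograms in [0, 1]
reconstruction, mu, logvar, logits = model(x)
print(reconstruction.shape)            # torch.Size([4, 1, 128, 15])
print(mu.shape, logvar.shape)          # torch.Size([4, 64]) each
print(logits.shape)                    # torch.Size([4, 7])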
Typically, VAEs are trained with the sum of the BCE (reconstruction) loss and the KL divergence. For a supervised VAE classifier, where a cross-entropy loss is also required, how do you weight all of the losses?
I have seen some research that uses arbitrary weights such as:
def vae_loss(recon_x, x, mu, logvar, clf, target):
    input_size = x.size(1) * x.size(2) * x.size(3)
    # BCE (reconstruction) loss
    BCE = F.binary_cross_entropy(recon_x.view(-1, input_size), x.view(-1, input_size), reduction='mean')  # sum
    # Kullback-Leibler divergence
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Cross-entropy (classification) loss
    clf_loss = F.cross_entropy(clf, target)
    return 0.001 * (BCE + 3 * KLD) + clf_loss
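One detail worth noting in the snippet above is that the BCE term uses a mean reduction while the KLD is summed over the whole batch, so the 0.001 and 3 weights are also compensating for very different scales. To make the weighting explicit, a variant with both terms averaged per sample and the weights pulled out as hyperparameters might look like this (beta and gamma are my own names, not taken from any paper):

def weighted_vae_loss(recon_x, x, mu, logvar, clf, target,
                      recon_weight=1.0, beta=1.0, gamma=1.0):
    # Reconstruction term, averaged over all elements
    bce = F.binary_cross_entropy(recon_x, x, reduction='mean')
    # KL divergence, summed over latent dimensions but averaged over the batch
    kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    # Classification term (already a per-sample mean)
    clf_loss = F.cross_entropy(clf, target)
    return recon_weight * bce + beta * kld + gamma * clf_loss

This does not answer which values to pick, but at least each weight then keeps a consistent meaning across input sizes and batch sizes.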
Is there a principled way to find optimal weights for these loss terms?