KL divergence loss goes to zero while training VAE


I am trying to train a supervised variational autoencoder (VAE) to perform classification on a noisy dataset. I am using a fully connected encoder and decoder, and the latent code z is fed as input to an MLP classifier. I'm using the Adam optimizer with a learning rate of 1e-3. However, the KL loss of my network reaches a value of 4.4584e-04 after 5 epochs, and the network does not learn anything after that. What could be the reason? Do I need to use stratified batches?

I used Keras and TensorFlow for the implementation and tried various embedding dimensions for the latent space of the VAE.
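
For reference, here is a minimal sketch of the kind of setup I mean (not my exact code: the input dimension, layer sizes, latent dimension, and number of classes are placeholders):

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 16      # placeholder; I have tried several values here
input_dim = 128      # placeholder: dimensionality of one input sample
num_classes = 10     # placeholder number of classes


class Sampling(layers.Layer):
    """Reparameterization trick; also adds beta * KL(q(z|x) || N(0, I)) as a loss."""

    def __init__(self, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # KL divergence between the Gaussian posterior and a standard normal prior
        kl = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        self.add_loss(self.beta * kl)
        eps = tf.random.normal(tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * eps


# Fully connected encoder producing the parameters of q(z|x)
inputs = keras.Input(shape=(input_dim,))
h = layers.Dense(256, activation="relu")(inputs)
z_mean = layers.Dense(latent_dim)(h)
z_log_var = layers.Dense(latent_dim)(h)
z = Sampling(beta=1.0)([z_mean, z_log_var])

# Fully connected decoder reconstructing x from z
h_dec = layers.Dense(256, activation="relu")(z)
reconstruction = layers.Dense(input_dim, name="reconstruction")(h_dec)

# MLP classifier that takes z as input
h_clf = layers.Dense(64, activation="relu")(z)
class_probs = layers.Dense(num_classes, activation="softmax", name="classifier")(h_clf)

model = keras.Model(inputs, [reconstruction, class_probs])
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss={"reconstruction": "mse", "classifier": "sparse_categorical_crossentropy"},
)

# model.fit(x_train, {"reconstruction": x_train, "classifier": y_train}, epochs=50)
```

The KL term that collapses to ~4e-04 is the one added by the `Sampling` layer above.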

1 Answer

TheEngineerProgrammer:

From my experience, this can happen when beta (the coefficient that multiplies the KL term) is too big, so the network gives too much importance to the KL loss. Just shrink it to make your network focus on the reconstruction loss instead.
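
For example, here is a minimal sketch of what that could look like in Keras (the names are illustrative, not from the question's code, and it assumes the sampling layer stores beta as a non-trainable `tf.Variable` so it can be changed between epochs). The simplest fix is to pass a small constant beta; a common refinement is to ramp beta up from zero over the first epochs (KL warm-up):

```python
import tensorflow as tf
from tensorflow import keras

# Simplest fix: build the model with a small constant beta, e.g.
#   z = Sampling(beta=0.01)([z_mean, z_log_var])   # instead of beta=1.0
#
# Optional refinement ("KL warm-up"): start at beta = 0 and ramp it up.
# This assumes the sampling layer keeps beta as a non-trainable variable:
#   self.beta = tf.Variable(beta, trainable=False, dtype=tf.float32)


class KLWarmUp(keras.callbacks.Callback):
    """Linearly increases beta from 0 to max_beta over the first warmup_epochs epochs."""

    def __init__(self, beta_variable, max_beta=0.1, warmup_epochs=10):
        super().__init__()
        self.beta_variable = beta_variable
        self.max_beta = max_beta
        self.warmup_epochs = warmup_epochs

    def on_epoch_begin(self, epoch, logs=None):
        fraction = min(1.0, epoch / self.warmup_epochs)
        self.beta_variable.assign(self.max_beta * fraction)


# Usage (illustrative): pass the layer's beta variable to the callback.
# sampling_layer = Sampling(beta=0.0)
# ...build the rest of the model...
# model.fit(x_train, {"reconstruction": x_train, "classifier": y_train},
#           epochs=50, callbacks=[KLWarmUp(sampling_layer.beta, max_beta=0.1)])
```

Either way, keep an eye on both terms during training: if the KL still collapses to ~0 while the reconstruction loss stops improving, beta is still too large.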