Training a VAE on data from a simple multivariate Gaussian leads to a collapsed reconstructed distribution


I'm very new to VAEs, and I'm trying to familiarise myself by first considering a simple data set sampled from a 3D Gaussian distribution centred at [1, 2, -3] with covariance matrix [[1, 0.5, 0.2], [0.5, 1, 0.3], [0.2, 0.3, 1]].

I adapted this VAE script written for MNIST into the one included below, and played around with all the hyperparameters I understand to be relevant (network architecture, choice of loss/activation functions, dimension of the latent space, batch size, size of the training set, number of epochs).

The reconstructed distribution, however, consistently collapsed almost to a point (although it was centred correctly) until I started to reduce the relative weight of the KLD loss compared to the reconstruction loss: with a factor of 0.4, the reconstructed distribution finally started to "spread out". (The collapse reoccurs, however, if I normalise the training data.) I am now stuck again, as I cannot seem to get the reconstructed distribution to fully capture the spread in the training data, even when reducing the KLD weight further (e.g. see one of the projections attached).

I was hoping to understand:

  • Why does the reconstructed distribution collapse, and why can the collapse only be prevented by reducing the weight of the KLD loss? Is it safe to reduce that weight below 1? And why does the collapse reoccur when I normalise the data, even with the reduced KLD weight? (See the warm-up sketch after this list for an alternative I have in mind.)
  • Is this related to why I cannot seem to improve the recovered spread of the data?
  • The loss seems to settle very early on, regardless of the hyperparameters I choose; why could this be happening?
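
On the first point, one alternative I am aware of (but have not tried) is annealing the KLD weight from 0 instead of fixing it; a rough, untested sketch of what I mean, with a hypothetical warmup_epochs parameter:

# Untested sketch: ramp the KLD weight linearly from 0 up to beta_max over the
# first `warmup_epochs` epochs, instead of using a fixed beta throughout.
def beta_schedule(epoch, warmup_epochs=10, beta_max=0.4):
    return beta_max * min(1.0, epoch / warmup_epochs)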

Many thanks in advance, any input would be really appreciated.
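
For reference, the KLD term in the code below is the standard closed form for a diagonal Gaussian posterior $q(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ against the $\mathcal{N}(0, I)$ prior,

$$D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,p(z)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right),$$

with $\log\sigma^2$ stored as log_var.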

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

device = 'cpu'

# Hyperparameters
frac_train = 1 # Fraction of data used for training
num_samples = 100000 # Number of samples
beta = 0.4
my_num_epochs = 50
my_learning_rate = 1e-3
my_batch_size = 128

N_outcomes = 3
mean = [1, 2, -3]
covariance_matrix = [[1, 0.5, 0.2],
                     [0.5, 1, 0.3],
                     [0.2, 0.3, 1]]

# Sample from the 3D Gaussian distribution
data = np.random.multivariate_normal(mean, covariance_matrix, num_samples)
#data = (data - np.min(data, axis=0)) / (np.max(data, axis=0) - np.min(data, axis=0))
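# (The commented-out min-max normalisation rescales each dimension to [0, 1],
# which also shrinks the per-dimension standard deviation well below 1.)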

# Divide available set into training and testing
Ndata = len(data)
Ntrain = round(Ndata*frac_train)

X_train = data[:Ntrain,[0,1,2]]
X_test = data[Ntrain:Ndata,[0,1,2]]

X_train = X_train.astype(np.float32)
X_test = X_test.astype(np.float32)

NNdim1 = 200
NNdim2 = 200
NNdim3 = 100
zdim = 12

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Set the number of hidden units
        self.num_hidden = zdim
        
        # Define the encoder part of the autoencoder
        self.encoder = nn.Sequential(
            nn.Linear(N_outcomes, NNdim1),
            nn.ReLU(),
            nn.Linear(NNdim1, NNdim2),
            nn.ReLU(),
            nn.Linear(NNdim2, NNdim3),
            nn.ReLU(),
            nn.Linear(NNdim3, self.num_hidden),
        )

        # Define the decoder part of the autoencoder
        self.decoder = nn.Sequential(
            nn.Linear(self.num_hidden, NNdim3),
            nn.ReLU(),
            nn.Linear(NNdim3, NNdim2),
            nn.ReLU(),
            nn.Linear(NNdim2, NNdim1),
            nn.ReLU(),
            nn.Linear(NNdim1, N_outcomes),
        )

    def forward(self, x):
        # Pass the input through the encoder
        encoded = self.encoder(x)
        # Pass the encoded representation through the decoder
        decoded = self.decoder(encoded)
        # Return both the encoded representation and the reconstructed output
        return encoded, decoded

class VAE(AutoEncoder):
    def __init__(self):
        super().__init__()
        # Add mu and log_var layers for reparameterization
        self.mu = nn.Sequential(
            nn.Linear(self.num_hidden, self.num_hidden),
            nn.ReLU()
        )
        self.log_var = nn.Sequential(
            nn.Linear(self.num_hidden, self.num_hidden),
            nn.ReLU()
        )
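        # NOTE: these ReLU activations clamp mu and log_var to be non-negative,
        # so log_var >= 0 and hence std = exp(0.5 * log_var) >= 1.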

    def reparameterize(self, mu, log_var):
        # Compute the standard deviation from the log variance
        std = torch.exp(0.5 * log_var)
        # Generate random noise using the same shape as std
        eps = torch.randn_like(std)
        # Return the reparameterized sample
        return mu + eps * std

    def forward(self, x):
        # Pass the input through the encoder
        encoded = self.encoder(x)
        # Compute the mean and log variance vectors
        mu = self.mu(encoded)
        log_var = self.log_var(encoded)
        # Reparameterize the latent variable
        z = self.reparameterize(mu, log_var)
        # Pass the latent variable through the decoder
        decoded = self.decoder(z)
        # Return the encoded output, decoded output, mean, and log variance
        return encoded, decoded, mu, log_var

    def sample(self, num_samples):
        with torch.no_grad():
            # Generate random noise
            z = torch.randn(num_samples, self.num_hidden).to(device)
            # Pass the noise through the decoder to generate samples
            samples = self.decoder(z)
        # Return the generated samples
        return samples

def train_vae(X_train, learning_rate=1e-3, num_epochs=200, batch_size=128):
    # Convert the training data to PyTorch tensors
    X_train = torch.from_numpy(X_train).to(device)

    # Create the autoencoder model and optimizer
    model = VAE()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Define the loss function
    criterion = nn.MSELoss(reduction="sum")

    # Move the model to the chosen device
    model.to(device)

    # Create a DataLoader to handle batching of the training data
    train_loader = torch.utils.data.DataLoader(
        X_train, batch_size=batch_size, shuffle=True
    )

    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_idx, data in enumerate(train_loader):
            # Get a batch of training data and move it to the device
            data = data.to(device)

            # Forward pass
            encoded, decoded, mu, log_var = model(data)

            # Compute the loss and perform backpropagation
            KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
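            # Note that the KLD is summed over the batch and all zdim latent
            # dimensions, while the reconstruction term is summed over only
            # N_outcomes = 3 dimensions, so their relative scale depends on zdim.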
            loss = criterion(decoded, data) + beta * KLD
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update the running loss (the sum-reduced loss is already
            # aggregated over the batch)
            total_loss += loss.item()

        # Print the epoch loss
        epoch_loss = total_loss / len(train_loader.dataset)
        print(
            "Epoch {}/{}: loss={:.4f}".format(epoch + 1, num_epochs, epoch_loss)
        )

    # Return the trained model
    return model

model = train_vae(X_train, learning_rate=my_learning_rate, num_epochs=my_num_epochs, batch_size=my_batch_size)

save_path = 'trained_vae_model.pth'
torch.save(model.state_dict(), save_path)

# Set the model to evaluation mode
model.eval()

latent_dim = model.num_hidden

# Draw latent vectors from the standard normal prior
num_generate = 100000
latent_vector = torch.randn(num_generate, latent_dim)

with torch.no_grad():
    data_generated = model.decoder(latent_vector)

# Convert the generated samples to a numpy array
data_generated_np = data_generated.numpy()
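
For completeness, this is the kind of check behind the "spread" I describe above, comparing the empirical moments of the training and generated samples:

# Compare empirical means and covariances of the training and generated data
print("train mean:     ", X_train.mean(axis=0))
print("generated mean: ", data_generated_np.mean(axis=0))
print("train cov:\n", np.cov(X_train, rowvar=False))
print("generated cov:\n", np.cov(data_generated_np, rowvar=False))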