I'm very new to VAEs and am trying to familiarise myself with them by first considering a simple data set sampled from a 3D Gaussian distribution with covariance [[1, 0.5, 0.2], [0.5, 1, 0.3], [0.2, 0.3, 1]] centred at [1, 2, -3].
I adapted this VAE script, originally written for MNIST, into the one included below and experimented with all the hyperparameters I understand to be relevant (network architecture, choice of loss/activation functions, latent-space dimension, batch size, training-set size, number of epochs).
However, the reconstructed distribution consistently collapsed almost to a point (although it was centred correctly) until I reduced the weight of the KLD loss relative to the reconstruction loss: with a factor of 0.4, the reconstructed distribution finally started to "spread out". (The collapse recurs, however, if I normalise the training data.) I am now stuck again, as I cannot get the reconstructed distribution to fully capture the spread of the training data, even when reducing the KLD weight further (see, e.g., one of the attached projections).
I was hoping to understand:
- Why does the reconstructed distribution collapse, and why can this only be prevented by reducing the weight of the KLD loss? Is it safe to reduce the weight of the KLD loss below 1? Why does the collapse recur when the data are normalised, even with a reduced KLD weight?
- Is this related to why I cannot seem to improve the recovered spread of the data?
- The loss stubbornly plateaus very early regardless of the chosen hyperparameters. Why could this be happening?
Many thanks in advance; any input would be much appreciated.
from urllib import request
import gzip
import numpy as np
import matplotlib.pyplot as plt
import torch
import math
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
device = 'cpu'

# ---- Training hyperparameters ------------------------------------------------
frac_train = 1          # share of the samples kept for training (the rest is test)
num_samples = 100000    # total number of points drawn from the Gaussian
beta = 0.4              # weight of the KL term relative to the reconstruction term
my_num_epochs = 50
my_learning_rate = 1e-3
my_batch_size = 128

# ---- Target distribution -----------------------------------------------------
N_outcomes = 3          # number of features per sample
mean = [1, 2, -3]
covariance_matrix = [
    [1, 0.5, 0.2],
    [0.5, 1, 0.3],
    [0.2, 0.3, 1],
]

data = np.random.multivariate_normal(mean, covariance_matrix, num_samples)
# Optional per-feature min-max scaling to [0, 1] (disabled). Note that with a
# summed-MSE reconstruction loss, shrinking the data scale also shrinks the
# reconstruction term relative to the KL term.
# data = (data - data.min(axis=0)) / (data.max(axis=0) - data.min(axis=0))

# ---- Train / test split --------------------------------------------------------
Ndata = len(data)
Ntrain = round(Ndata * frac_train)
X_train = data[:Ntrain, :3].astype(np.float32)
X_test = data[Ntrain:Ndata, :3].astype(np.float32)

# ---- Network sizes ---------------------------------------------------------------
NNdim1, NNdim2, NNdim3 = 200, 200, 100  # hidden-layer widths, encoder order
zdim = 12                               # latent dimension
class AutoEncoder(nn.Module):
    """Deterministic MLP autoencoder; also the backbone of ``VAE``.

    The encoder maps ``input_dim`` features to a ``latent_dim``-dimensional
    code through ReLU hidden layers. The decoder mirrors it back to
    ``input_dim`` features. Neither network has an activation on its output.

    Args:
        input_dim: Number of input/output features. Defaults to the
            module-level ``N_outcomes``.
        latent_dim: Size of the code, stored as ``self.num_hidden``.
            Defaults to the module-level ``zdim``.
        hidden_dims: Hidden-layer widths in encoder order; the decoder uses
            them reversed. Defaults to ``(NNdim1, NNdim2, NNdim3)``.
    """

    def __init__(self, input_dim=None, latent_dim=None, hidden_dims=None):
        super().__init__()
        # Fall back to the script-level hyperparameters so that
        # ``AutoEncoder()`` and ``VAE()`` (which calls ``super().__init__()``)
        # build exactly the same network as before.
        if input_dim is None:
            input_dim = N_outcomes
        if latent_dim is None:
            latent_dim = zdim
        if hidden_dims is None:
            hidden_dims = (NNdim1, NNdim2, NNdim3)
        hidden_dims = tuple(hidden_dims)

        # Set the number of hidden (latent) units
        self.num_hidden = latent_dim
        # Encoder first, then decoder: this keeps the parameter-initialisation
        # order (and hence the RNG consumption) and the Sequential indices in
        # the state_dict identical to the hand-written version.
        self.encoder = self._build_mlp([input_dim, *hidden_dims, latent_dim])
        self.decoder = self._build_mlp(
            [latent_dim, *reversed(hidden_dims), input_dim]
        )

    @staticmethod
    def _build_mlp(dims):
        """Chain Linear layers between consecutive ``dims``, with ReLU between them but not after the last."""
        layers = []
        last = len(dims) - 2
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(encoded, decoded)``: the latent code of ``x`` and its reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
class VAE(AutoEncoder):
    """Variational autoencoder built on the ``AutoEncoder`` MLPs.

    The encoder output is mapped to the mean and log-variance of a diagonal
    Gaussian q(z|x). A latent sample is drawn with the reparameterization
    trick and passed through the decoder.
    """

    def __init__(self):
        super().__init__()
        # Posterior heads. They must be unconstrained linear maps. A trailing
        # nn.ReLU() forces mu >= 0 and log_var >= 0, i.e. a posterior variance
        # of at least 1, never narrower than the N(0, I) prior. The code can
        # then carry very little information about x, and the decoder learns
        # to ignore z and output (close to) the data mean.
        # The one-element Sequential keeps the state_dict keys ("mu.0.weight",
        # "log_var.0.bias", ...) the same as with the former ReLU heads
        # (ReLU has no parameters).
        self.mu = nn.Sequential(
            nn.Linear(self.num_hidden, self.num_hidden),
        )
        self.log_var = nn.Sequential(
            nn.Linear(self.num_hidden, self.num_hidden),
        )

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * exp(0.5 * log_var), with eps ~ N(0, I), keeping gradients w.r.t. mu and log_var."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(encoded, decoded, mu, log_var)``: the raw encoder output,
            the decoder output for the sampled z, and the posterior mean and
            log-variance.
        """
        encoded = self.encoder(x)
        mu = self.mu(encoded)
        log_var = self.log_var(encoded)
        z = self.reparameterize(mu, log_var)
        decoded = self.decoder(z)
        return encoded, decoded, mu, log_var

    def sample(self, num_samples):
        """Decode ``num_samples`` latent vectors drawn from the N(0, I) prior.

        Runs under ``torch.no_grad()`` and returns the decoder output only;
        no observation noise is added on top of it.
        """
        with torch.no_grad():
            z = torch.randn(num_samples, self.num_hidden).to(device)
            samples = self.decoder(z)
        return samples
def train_vae(X_train, learning_rate=1e-3, num_epochs=200, batch_size=128, kl_weight=None):
    """Train a ``VAE`` on ``X_train`` with Adam and return the trained model.

    The per-batch objective is ``sum((decoded - x) ** 2) + kl_weight * KLD``.
    KLD is the closed-form KL divergence between the diagonal Gaussian
    posterior and N(0, I), summed over the batch and latent dimensions.
    Read as a Gaussian likelihood, this objective assumes observation noise
    with variance ``kl_weight / 2`` per feature, so lowering ``kl_weight``
    assumes a less noisy decoder.

    Args:
        X_train: NumPy array with one sample per row. It must be float32 to
            match the model weights, since ``torch.from_numpy`` keeps the dtype.
        learning_rate: Adam learning rate.
        num_epochs: Number of passes over ``X_train``.
        batch_size: Mini-batch size.
        kl_weight: Weight of the KLD term. ``None`` uses the module-level
            ``beta``.

    Returns:
        The trained ``VAE`` instance.
    """
    if kl_weight is None:
        kl_weight = beta

    X_train = torch.from_numpy(X_train).to(device)

    model = VAE()
    # Move the parameters to the configured device (``device``; this script
    # sets it explicitly, it is not auto-detected).
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss(reduction="sum")

    train_loader = torch.utils.data.DataLoader(
        X_train, batch_size=batch_size, shuffle=True
    )
    n_samples = len(train_loader.dataset)

    for epoch in range(num_epochs):
        total_loss = 0.0
        total_recon = 0.0
        total_kld = 0.0
        for data in train_loader:
            data = data.to(device)

            _, decoded, mu, log_var = model(data)
            recon = criterion(decoded, data)
            KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            loss = recon + kl_weight * KLD

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ``loss`` is already a sum over the batch (reduction="sum" and
            # torch.sum in KLD). Accumulate it as-is: multiplying by the batch
            # size double-counts and inflates the epoch loss ~batch_size times.
            total_loss += loss.item()
            total_recon += recon.item()
            total_kld += KLD.item()

        # Per-sample averages. Reporting recon and KLD separately shows
        # posterior collapse directly: KLD stays near 0.
        print(
            "Epoch {}/{}: loss={:.4f} recon={:.4f} kld={:.4f}".format(
                epoch + 1,
                num_epochs,
                total_loss / n_samples,
                total_recon / n_samples,
                total_kld / n_samples,
            )
        )

    return model
model = train_vae(X_train, learning_rate=my_learning_rate, num_epochs=my_num_epochs, batch_size=my_batch_size)
save_path = 'trained_vae_model.pth'
torch.save(model.state_dict(), save_path)

# Evaluation mode. The network has only Linear/ReLU layers, so this is a
# precaution rather than a behavioural change.
model.eval()

latent_dim = model.num_hidden
num_generate = 100000
# Draw z from the N(0, I) prior directly on the model's device, so that
# generation still works if ``device`` is switched to a GPU.
latent_vector = torch.randn(num_generate, latent_dim, device=device)
with torch.no_grad():
    data_generated = model.decoder(latent_vector)

# Generated samples as a NumPy array (.cpu() is required before .numpy() for
# non-CPU tensors).
# NOTE: these are decoder means only. Read as a Gaussian likelihood, the
# training objective (sum-MSE + beta * KLD) implies extra observation noise
# with variance beta / 2 per feature. That noise is not added here, so the
# generated spread is expected to be narrower than the training data's.
data_generated_np = data_generated.cpu().numpy()