VAE for Motion Sequence Generation - Convergence Issue with Scheduled Sampling


I have implemented a Variational Autoencoder (VAE) in PyTorch for motion sequence generation, using human pose data (joint angles and angular velocities, in radians) from the CMU dataset. The encoder and decoder each have two layers, where each layer is a Conv1d followed by a ReLU activation (as in the code below).


During training, I input a sequence of 121 poses (the 60 previous poses, the current pose p(n), and the 60 next poses from the dataset), and the VAE generates the next pose p_hat(n+1).
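For context, a minimal sketch of how such a window could be sliced from the dataset (`poses`, `make_window`, and `radius` are illustrative names, not from my actual code):

import torch

def make_window(poses, n, radius=60):
    # poses: (T, featDim) tensor of joint angles/velocities.
    # Returns the 121-frame window [p(n-60), ..., p(n), ..., p(n+60)]
    # as (featDim, 121), matching Conv1d's (channels, length) layout.
    return poses[n - radius : n + radius + 1].transpose(0, 1)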

I have also tried normalizing the joint angles and angular velocities, but that worsens convergence.
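(For reference, one common normalization scheme for this kind of data is per-feature standardization; a minimal sketch, with `train_data` as a placeholder for an (N, featDim) tensor:)

# Per-feature z-score normalization (placeholder names; one common scheme).
feat_mean = train_data.mean(dim=0, keepdim=True)
feat_std = train_data.std(dim=0, keepdim=True).clamp_min(1e-8)
train_data = (train_data - feat_mean) / feat_std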

Here's an overview of my training process:

Loss Function:

For the first 30 epochs, I trained with only the Mean Squared Error (MSE) loss, comparing the generated next pose with the ground truth from the CMU dataset: loss = MSE(p(n+1), p_hat(n+1))

From epoch 31 to 60, I added the KL divergence term to the loss: loss = MSE(p(n+1), p_hat(n+1)) + KL (see the sketch below).
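In code, this staged loss looks roughly as follows (a sketch; `staged_loss` is an illustrative helper, not taken from my training script):

import torch.nn.functional as F

def staged_loss(pred_next_pose, gt_next_pose, kl, epoch):
    # Epochs 1-30: MSE only; from epoch 31 onward: MSE + KL.
    recon = F.mse_loss(pred_next_pose, gt_next_pose)
    return recon + kl if epoch >= 31 else recon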

Scheduled Sampling:

Starting from epoch 61, I apply scheduled sampling, gradually increasing the sampling probability p from 0.0 to 1.0 over 20 epochs (61 to 80). From epoch 81 onward, p is fixed at 1, meaning the generated pose is always fed back into the model as the current pose to generate the following one. The rollout length is 8: I autoregressively generate the next 8 poses, feeding each generated pose of the VAE back in as input (a sketch of this rollout follows below).
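Roughly, the rollout looks like this (a sketch with illustrative helper names; advancing the 121-frame conditioning window each step is omitted):

import random

import torch
import torch.nn.functional as F

def scheduled_sampling_rollout(vae, seq, gt_poses, p, rollout_len=8):
    # gt_poses: ground-truth poses [p(n), p(n+1), ..., p(n+rollout_len)],
    # each of shape (batch, poseFeatDim).
    losses = []
    cur_pose = gt_poses[0]
    for step in range(rollout_len):
        pred, _, kl = vae(seq, cur_pose)
        pred = pred.squeeze(2)                      # (batch, poseFeatDim)
        losses.append(F.mse_loss(pred, gt_poses[step + 1]) + kl)
        if random.random() < p:
            # feed the model's own prediction back in; detaching here is one
            # design choice, backpropagating through the rollout is another
            cur_pose = pred.detach()
        else:
            cur_pose = gt_poses[step + 1]
        # NOTE: the conditioning window `seq` would also need to advance by
        # one frame per step; that bookkeeping is omitted for brevity.
    return torch.stack(losses).mean()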

The Issue: The network converges nicely with the MSE loss alone, somewhat more slowly with MSE + KL, but it fails to converge once scheduled sampling is applied.

My Questions:

1. Is there a potential reason why the model doesn't converge during the scheduled sampling phase?
2. Are there any adjustments or insights regarding the VAE structure or training parameters that could help resolve this issue and improve convergence during scheduled sampling?

VAE Structure and Parameters:

- Encoder and decoder: two layers each (Conv1d + ReLU activation)
- Loss: MSE initially, then MSE + KL
- Scheduled sampling: sampling probability p increased gradually from 0.0 to 1.0 over epochs 61 to 80, then p = 1 from epoch 81 onward


import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, latentDim, inputFeatDim, frameSequence, intermediate_channels):
        super(Encoder, self).__init__()

        # layer 1: pointwise Conv1d over the pose-feature channels
        self.convLayer1 = nn.Conv1d(in_channels=inputFeatDim,
                                    out_channels=intermediate_channels,
                                    kernel_size=1,
                                    padding=0,
                                    bias=True)
        

        # layer 2: skip connection, so input channels are intermediate + input features
        self.convLayer2 = nn.Conv1d(in_channels=intermediate_channels + inputFeatDim,
                                    out_channels=intermediate_channels,
                                    kernel_size=1,
                                    padding=0,
                                    bias=True)

        # collapses the frame axis (121 -> 1)
        self.downSampleLayer = nn.Linear(in_features=frameSequence, out_features=1, bias=True)

        self.muLayer = nn.Conv1d(in_channels=intermediate_channels, out_channels=latentDim, kernel_size=1, padding=0)
        self.logVarLayer = nn.Conv1d(in_channels=intermediate_channels, out_channels=latentDim, kernel_size=1, padding=0)

        # standard normal for the reparameterization trick (kept on the GPU)
        self.normalDist = torch.distributions.Normal(0, 1)
        self.normalDist.loc = self.normalDist.loc.cuda()
        self.normalDist.scale = self.normalDist.scale.cuda()
        self.kullbackLeibler = 0
        self.latent = torch.zeros(1).cuda()

        
        

    def forward(self, x):
        residual = x
        x = torch.relu(self.convLayer1(x))

        # skip connection: concatenate the raw input with the layer-1 features
        x = torch.relu(self.convLayer2(torch.cat((residual, x), dim=1)))

        # collapse the 121-frame axis to a single frame
        x = self.downSampleLayer(x)

        mu = self.muLayer(x)         # (batch, latentDim, 1)
        logVar = self.logVarLayer(x)

        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, 1)
        self.latent = mu + torch.exp(0.5 * logVar) * self.normalDist.sample(mu.shape)
        # KL(q(z|x) || N(0, I)), summed over latent dims, averaged over the batch
        self.kullbackLeibler = ((torch.exp(logVar) + mu ** 2) / 2 - 0.5 * logVar - 0.5).sum() / logVar.size(0)
        return self.latent, self.kullbackLeibler

class Decoder(nn.Module):
    def __init__(self, latentDim, inputFeatDim, poseFeatDim, frameSequence, intermediate_channels):
        super(Decoder, self).__init__()
        # NOTE: defined but never used in forward()
        self.LatentExpander = nn.Linear(in_features=latentDim, out_features=poseFeatDim)

        # entry layer: latent code concatenated with the current pose
        entry_in_channels = latentDim + poseFeatDim
        self.entryLayer = nn.Conv1d(in_channels=entry_in_channels,
                                    out_channels=intermediate_channels,
                                    kernel_size=1,
                                    padding=0,
                                    bias=True)

        # hidden layer 1 (skip connection from the entry input)
        self.convLayer1 = nn.Conv1d(in_channels=intermediate_channels + entry_in_channels,
                                    out_channels=intermediate_channels,
                                    kernel_size=1,
                                    padding=0,
                                    bias=True)

        # output layer (missing from the snippet as originally posted; its shape
        # is inferred from forward()): maps hidden features back to pose features
        self.finalLayer = nn.Conv1d(in_channels=intermediate_channels,
                                    out_channels=poseFeatDim,
                                    kernel_size=1,
                                    padding=0,
                                    bias=True)

    def forward(self, latent, cur_pose):
        # cur_pose: (batch, poseFeatDim) -> (batch, poseFeatDim, 1)
        cur_pose = cur_pose.unsqueeze(2)

        x = torch.cat([latent, cur_pose], dim=1)
        residual = x

        x = torch.relu(self.entryLayer(x))

        # skip connection from the concatenated input
        x = torch.relu(self.convLayer1(torch.cat((residual, x), dim=1)))

        x = self.finalLayer(x)  # (batch, poseFeatDim, 1): predicted p_hat(n+1)
        return x

class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, seq, cur_pose):
        latent, kullbackLeibler = self.encoder(seq)
        X_hat = self.decoder(latent, cur_pose)
        return X_hat, latent, kullbackLeibler
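For completeness, a minimal sketch of one teacher-forced training step with these classes (batch size, feature and latent dimensions are illustrative):

import torch
import torch.nn.functional as F

# Illustrative sizes: 121-frame windows, placeholder feature/latent dims.
latentDim, featDim, frames, hidden = 32, 96, 121, 256

encoder = Encoder(latentDim, featDim, frames, hidden).cuda()
decoder = Decoder(latentDim, featDim, featDim, frames, hidden).cuda()
vae = VAE(encoder, decoder).cuda()
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

seq = torch.randn(8, featDim, frames).cuda()   # batch of 121-pose windows
cur_pose = torch.randn(8, featDim).cuda()      # current poses p(n)
next_pose = torch.randn(8, featDim).cuda()     # ground-truth p(n+1)

pred, latent, kl = vae(seq, cur_pose)
loss = F.mse_loss(pred.squeeze(2), next_pose) + kl
optimizer.zero_grad()
loss.backward()
optimizer.step()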