I am writing a VAE that uses a PID algorithm to tune the KL-divergence term of the loss (see Shao et al., 2020). In short, before the total loss is calculated, the KL-divergence term is multiplied by a coefficient beta. This coefficient changes over the course of training according to the following equation:

beta(t) = K_p / (1 + exp(e(t))) - K_i * sum_{j=0}^{t} e(j) + beta_min

where e(t) is the difference between the desired KL-divergence and the current KL-divergence.
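To make the schedule concrete, the whole computation looks essentially like this in plain NumPy (just a sketch; pid_beta and its argument names are made up for illustration, and beta_min plays the role of the constant derivative_kl offset in my code below):

import numpy as np

def pid_beta(errors, k_p, k_i, beta_min):
    """Compute beta(t) from the error history e(0), ..., e(t)."""
    proportional = k_p / (1.0 + np.exp(errors[-1]))  # first term: squashes the current error
    integral = k_i * np.sum(errors)                  # second term: running sum of all errors
    return proportional - integral + beta_min        # constant offset instead of a true D term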
I don't think the first term on the right-hand side is a problem, but the second term is, because it needs the running history of errors. My implementation of this in the test_step() function causes the following error:
The tensor <tf.Tensor 'add_2:0' shape=() dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=train_function, id=139749247929504), which is out of scope.
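I believe the underlying issue is that a TensorArray written inside one traced function (Keras's train_function) is later read inside another (the test function). A stripped-down example of what I think is the same situation (hypothetical and untested, just to illustrate the pattern):

import tensorflow as tf

ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)

@tf.function
def write_step():
    global ta
    ta = ta.write(0, 1.0)  # the returned TensorArray is backed by a tensor in this graph

@tf.function
def read_step():
    return ta.read(0)  # tries to capture a tensor owned by write_step's graph

write_step()
read_step()  # fails with the same kind of "out of scope" error, as far as I can tell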
Here is my current implementation:
In the __init__ of my model, I initialize an empty TensorArray to keep track of all e(t) from 0 to t so that I can sum them later (a sketch of that initialization is shown after the train_step() code below). Here is what the train_step() function looks like:
def train_step(self, data: npt.ArrayLike) -> dict:
    # Set gradient context manager
    with tf.GradientTape() as tape:
        # Get latent values
        mean, log_variance, sample = self.encoder(data)
        # Reconstruct from the sample
        reconstruction = self.decoder(sample)
        # Calculate reconstruction loss
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.categorical_crossentropy(data, reconstruction), axis=0
            )
        )
        # Calculate KL loss
        kl_loss = self.kullback_leibler_loss(mean=mean, log_variance=log_variance)
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        # Get error vs desired KL
        error = self.desired_kl - kl_loss
        # Write the new error to the TensorArray at the current iteration index
        self.beta_errors = self.beta_errors.write(
            self.beta_iteration_counter, error
        )
        # Calculate proportional term
        proportional_term = self.proportional_kl / (1 + tf.exp(error))
        # Calculate integral term
        integral_term = self.integral_kl * tf.reduce_sum(
            self.beta_errors.stack()
        )
        # Get control score, i.e. beta(t)
        control_score = proportional_term - integral_term + self.derivative_kl
        # Calculate total loss
        total_loss = reconstruction_loss + control_score * kl_loss
    # Apply gradients
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    # Update loss trackers
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    self.kl_beta_tracker.update_state(control_score)
    # Return dictionary of losses
    return {
        "loss": self.total_loss_tracker.result(),
        "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        "kl_loss": self.kl_loss_tracker.result(),
        "beta_score": self.kl_beta_tracker.result(),
    }
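For completeness, the relevant parts of my __init__ look roughly like this (the gain values here are placeholders, not the ones I actually use):

def __init__(self, encoder, decoder, desired_kl, **kwargs):
    super().__init__(**kwargs)
    self.encoder = encoder
    self.decoder = decoder
    # Target KL-divergence and PID coefficients (placeholder values)
    self.desired_kl = desired_kl
    self.proportional_kl = 0.01  # K_p
    self.integral_kl = 0.0001    # K_i
    self.derivative_kl = 1.0     # constant offset (beta_min above)
    # History of e(t), written to on every training step
    self.beta_errors = tf.TensorArray(
        tf.float32, size=0, dynamic_size=True, clear_after_read=False
    )
    self.beta_iteration_counter = 0
    # Loss trackers
    self.total_loss_tracker = keras.metrics.Mean(name="loss")
    self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
    self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
    self.kl_beta_tracker = keras.metrics.Mean(name="beta_score")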
For reference, here is the call() function:
def call(self, inputs):
    samples = self.encoder(inputs)
    self.beta_iteration_counter += 1
    return self.decoder(samples[2])
The way I wanted to include beta(t) in the test_step() function is to do the following in __init__:
def __init__(self):
    ...
    # ^^^ all other init stuff
    self.betas = tf.TensorArray(
        tf.float32, size=0, dynamic_size=True, clear_after_read=False
    )
Then near the end of train_step():
def train_step(self, data):
    ...
    # ^^^ all the other train_step stuff
    self.betas = self.betas.write(self.beta_iteration_counter, control_score)
    ...
Lastly:
def test_step(self, data: npt.ArrayLike):
    validation_data, _ = data
    mean, log_variance, sample = self.encoder(validation_data)
    reconstruction = self.decoder(sample)
    reconstruction_loss = tf.reduce_mean(
        tf.reduce_sum(
            keras.losses.categorical_crossentropy(validation_data, reconstruction),
            axis=0,
        )
    )
    # Get control score: read the most recent beta written during training
    control_score = self.betas.read(self.beta_iteration_counter)
    # Calculate KL loss
    kl_loss = self.kullback_leibler_loss(mean=mean, log_variance=log_variance)
    kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
    # Calculate total loss
    total_loss = reconstruction_loss + control_score * kl_loss
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    return {
        "total_loss": self.total_loss_tracker.result(),
        "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        "kl_loss": self.kl_loss_tracker.result(),
    }
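For context, the error appears as soon as Keras runs the first validation pass, e.g. in a call like this (x_train and x_val are placeholder names):

model.compile(optimizer=keras.optimizers.Adam())
model.fit(
    x_train,
    validation_data=(x_val, x_val),  # test_step unpacks this tuple
    epochs=10,
    batch_size=32,
)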
Adding these betas lines causes the error seen above; without them, training works just fine. I have also tried putting the test_step() function under a @tf.function decorator. So at this point I am not sure how to get the correct validation loss calculated during training.
