I am trying to reproduce the training of ESRGAN (or Real-ESRGAN: https://github.com/xinntao/ESRGAN) with code simplified from the original, which is quite complex. I have to train it on my own dataset, a biological one. However:
- I have seen code where a single `tf.GradientTape(persistent=True)` (the `tape` below) is shared by both the discriminator and the generator;
- in general, though, it seems that each network should have its own tape (see the sketch after this list).
So I do not know if ESRGAN is an isolated case.
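For reference, here is roughly what the separate-tape pattern I keep seeing looks like (a minimal sketch, reusing the same `generator`, `discriminator`, losses and optimizers as in my code below; it is not taken from any particular repository):

```python
import tensorflow as tf

def train_step_two_tapes(lr, hr):
    # Two independent, non-persistent tapes: both record the same
    # forward pass, but each is queried only once and then released.
    with tf.GradientTape() as tape_G, tf.GradientTape() as tape_D:
        sr = generator(lr, training=True)
        hr_output = discriminator(hr, training=True)
        sr_output = discriminator(sr, training=True)

        total_loss_G = (tf.reduce_sum(generator.losses)
                        + 1e-2 * pixel_loss_fn(hr, sr)
                        + 1.0 * fea_loss_fn(hr, sr)
                        + 5e-3 * gen_loss_fn(hr_output, sr_output))
        total_loss_D = (tf.reduce_sum(discriminator.losses)
                        + dis_loss_fn(hr_output, sr_output))

    # Each tape yields gradients only for the variables it is asked about.
    grads_G = tape_G.gradient(total_loss_G, generator.trainable_variables)
    grads_D = tape_D.gradient(total_loss_D, discriminator.trainable_variables)
    optimizer_G.apply_gradients(zip(grads_G, generator.trainable_variables))
    optimizer_D.apply_gradients(zip(grads_D, discriminator.trainable_variables))
    return total_loss_G, total_loss_D
```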
I wanted to know whether the following code is adequate or completely wrong. Just in case, here is the link to the full repository on GitHub: https://github.com/SalomePx/ESRGAN2
```python
def train_step(lr, hr):
    # One persistent tape records both forward passes, so it can be
    # queried twice: once for the generator, once for the discriminator.
    with tf.GradientTape(persistent=True) as tape:
        sr = generator(lr, training=True)
        hr_output = discriminator(hr, training=True)
        sr_output = discriminator(sr, training=True)

        losses_D = {}
        losses_D['reg'] = tf.reduce_sum(discriminator.losses)
        losses_D['gan'] = dis_loss_fn(hr_output, sr_output)

        losses_G = {}
        losses_G['reg'] = tf.reduce_sum(generator.losses)
        losses_G['pixel'] = 1e-2 * pixel_loss_fn(hr, sr)
        losses_G['feature'] = 1.0 * fea_loss_fn(hr, sr)
        losses_G['gan'] = 5e-3 * gen_loss_fn(hr_output, sr_output)

        total_loss_G = tf.add_n(list(losses_G.values()))
        total_loss_D = tf.add_n(list(losses_D.values()))

    # Each .gradient() call only differentiates with respect to the
    # variables it is given, so the two updates do not interfere.
    grads_G = tape.gradient(total_loss_G, generator.trainable_variables)
    grads_D = tape.gradient(total_loss_D, discriminator.trainable_variables)
    optimizer_G.apply_gradients(zip(grads_G, generator.trainable_variables))
    optimizer_D.apply_gradients(zip(grads_D, discriminator.trainable_variables))

    # A persistent tape is not freed automatically; release it explicitly.
    del tape
    return total_loss_G, total_loss_D, losses_G, losses_D
```
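For completeness, this is roughly how the objects used above are wired together (a simplified sketch: `build_generator`/`build_discriminator` and `dataset` are placeholders for the actual models and data pipeline in my repo, the learning rate is an assumption, and the GAN losses follow the relativistic average formulation that ESRGAN uses):

```python
import tensorflow as tf

generator = build_generator()          # placeholder for the RRDB generator
discriminator = build_discriminator()  # placeholder for the discriminator

optimizer_G = tf.keras.optimizers.Adam(learning_rate=1e-4)  # assumed value
optimizer_D = tf.keras.optimizers.Adam(learning_rate=1e-4)

pixel_loss_fn = tf.keras.losses.MeanAbsoluteError()  # L1 pixel loss, as in ESRGAN
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def dis_loss_fn(hr_output, sr_output):
    # Relativistic average discriminator loss (the formulation ESRGAN uses).
    real = bce(tf.ones_like(hr_output), hr_output - tf.reduce_mean(sr_output))
    fake = bce(tf.zeros_like(sr_output), sr_output - tf.reduce_mean(hr_output))
    return real + fake

def gen_loss_fn(hr_output, sr_output):
    # Symmetric relativistic average loss for the generator.
    real = bce(tf.zeros_like(hr_output), hr_output - tf.reduce_mean(sr_output))
    fake = bce(tf.ones_like(sr_output), sr_output - tf.reduce_mean(hr_output))
    return real + fake

# fea_loss_fn (the VGG feature/perceptual loss) is omitted here for brevity.

for lr_batch, hr_batch in dataset:  # assumed to yield (LR, HR) patch pairs
    total_G, total_D, parts_G, parts_D = train_step(lr_batch, hr_batch)
```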