How can I reinitialize a BatchNormalization layer in Keras?


I have a CNN with BatchNormalization layers. I train the CNN for a few epochs, and then I want to reset the batchnorm moving statistics (moving_mean and moving_variance) while preserving the learned CNN weights.

Is there a way of doing this?

I thought of using build_from_config, but in Keras the BatchNormalization layer doesn't store its input shape in its configuration dictionary (as its source code shows).
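For context, a freshly built BatchNormalization layer starts with moving_mean at zeros and moving_variance at ones (the defaults of its moving_mean_initializer and moving_variance_initializer), so "resetting" here means restoring those values. A quick check, assuming Keras 3 or tf.keras 2.x is installed as the keras package:

```python
import numpy as np
import keras  # assumes Keras 3, or Keras 2 installed as the `keras` package

# A freshly built BatchNormalization layer normalizing a 4-channel last axis
bn = keras.layers.BatchNormalization()
bn.build((None, 4))

# With the default initializers, moving_mean is zeros and moving_variance is ones
initial_mean = bn.moving_mean.numpy()
initial_variance = bn.moving_variance.numpy()
print(initial_mean, initial_variance)
```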

Answer by Liam F-A:

I think I found a way to do this, but it's probably a bit unorthodox because it relies on a private Keras attribute (_build_input_shape) defined in base_layer.py.

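    from tensorflow import keras  # assumes tf.keras 2.x, where layers record _build_input_shape

    for layer in model.layers:  # find the BatchNormalization layers in the model
        if isinstance(layer, keras.layers.BatchNormalization):
            # Rebuild from the recorded input shape (private attribute) to recreate the variables
            layer.build(layer._build_input_shape)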

I'll leave the question open in case a better (more "Pythonic") solution exists.
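A possible alternative that avoids private attributes, sketched here under the assumption of Keras 3 or tf.keras 2.x (this is not part of the original answer): assign the default initial values (zeros for moving_mean, ones for moving_variance) directly into the existing variables. Calling build() again recreates every variable the layer owns, including the learned gamma and beta, whereas assigning in place touches only the moving statistics and keeps the same variable objects. The helper name reset_batchnorm_statistics and the small demo model are hypothetical, and the sketch assumes the layers use the default moving_mean_initializer and moving_variance_initializer.

```python
import numpy as np
import keras  # assumes Keras 3, or tf.keras 2.x installed as `keras`


def reset_batchnorm_statistics(model):
    """Hypothetical helper: set moving_mean to zeros and moving_variance to
    ones for every BatchNormalization layer in `model`, in place.

    gamma, beta and the weights of all other layers are left untouched.
    Nested models are visited recursively through their `layers` attribute.
    """
    for layer in model.layers:
        if isinstance(layer, keras.layers.BatchNormalization):
            # Assign into the existing variables instead of recreating them
            layer.moving_mean.assign(np.zeros_like(layer.moving_mean.numpy()))
            layer.moving_variance.assign(np.ones_like(layer.moving_variance.numpy()))
        elif hasattr(layer, "layers"):
            reset_batchnorm_statistics(layer)


# Hypothetical demo model: Conv2D followed by BatchNormalization
model = keras.Sequential([
    keras.Input(shape=(8, 8, 3)),
    keras.layers.Conv2D(4, 3),
    keras.layers.BatchNormalization(),
])
x = np.random.rand(16, 8, 8, 3).astype("float32")
model(x, training=True)  # a training-mode call updates the moving statistics

conv_kernel_before = model.layers[0].kernel.numpy().copy()
reset_batchnorm_statistics(model)
```

After the call, the BatchNormalization layer's moving statistics are back at their initial values while the Conv2D kernel is unchanged.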