Why doesn't the loss calculated by Flux `withgradient` match what I have calculated?


I'm trying to train a simple CNN with Flux and I'm running into a weird issue. During training the loss appears to go down, suggesting that training is working, but the "trained" model's output is very bad. When I calculated the loss by hand it differed from what training reported; the model was acting as if it hadn't been trained at all.

I then started comparing the loss returned inside the gradient call vs. outside it, and after a lot of digging I think the problem is related to the BatchNorm layer. Consider the following minimal example:

using Flux
x = rand(100,100,1,1) #say a greyscale image 100x100 with 1 channel (greyscale) and 1 batch
y = @. 5*x + 3 #output image, some relationship to the input values (doesn't matter for this)
m = Chain(BatchNorm(1),Conv((1,1),1=>1)) #very simple model (doesn't really do anything but illustrates the problem)
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m) #loss calculated by gradient
l_final = Flux.mse(m(x),y) #loss calculated again using the model (no parameters have been updated)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")

All of the losses above come out different, sometimes pretty drastically so (running it just now I got 22.6, 30.7, and 23.0), when I would expect them all to be the same.

Interestingly, if I remove the BatchNorm layer the outputs are all the same, i.e. running:

using Flux
x = rand(100,100,1,1) #say a greyscale image 100x100 with 1 channel (greyscale) and 1 batch
y = @. 5*x + 3 #output image
m = Chain(Conv((1,1),1=>1))
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m)
l_final = Flux.mse(m(x),y)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")

produces the same number for all three loss calculations.

Why does including the BatchNorm layer change the value of the loss like this?

My (limited) understanding was that BatchNorm just normalizes the input values. I can see how that would make the loss differ between the unnormalized and normalized cases, but I don't understand why it would produce different loss values for the same input on the same model when none of the model's parameters have been updated.

max xilian (accepted answer):

Look at the documentation of BatchNorm

BatchNorm(channels::Integer, λ=identity;
            initβ=zeros32, initγ=ones32,
            affine=true, track_stats=true, active=nothing,
            eps=1f-5, momentum= 0.1f0)

  Batch Normalization (https://arxiv.org/abs/1502.03167) layer. channels should
  be the size of the channel dimension in your data (see below).

  Given an array with N dimensions, call the N-1th the channel dimension. For a
  batch of feature vectors this is just the data dimension, for WHCN images it's
  the usual channel dimension.

  BatchNorm computes the mean and variance for each D_1×...×D_{N-2}×1×D_N input
  slice and normalises the input accordingly.

  If affine=true, it also applies a shift and a rescale to the input through to
  learnable per-channel bias β and scale γ parameters.

  After normalisation, elementwise activation λ is applied.

  If track_stats=true, accumulates mean and var statistics in training phase that
  will be used to renormalize the input in test phase.

  Use testmode! during inference.

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  julia> using Statistics
  
  julia> xs = rand(3, 3, 3, 2);  # a batch of 2 images, each having 3 channels
  
  julia> m = BatchNorm(3);
  
  julia> Flux.trainmode!(m);
  
  julia> isapprox(std(m(xs)), 1, atol=0.1) && std(xs) != std(m(xs))
  true

The key bit here is that by default track_stats=true. With that setting the layer accumulates running mean and variance statistics during every training-mode forward pass, so the forward passes inside and outside of withgradient don't see the same statistics, and the loss changes even though no parameters have been updated. If you don't want this behaviour, initialise your model with

m = Chain(BatchNorm(1, track_stats=false),Conv((1,1),1=>1)) #same model, but without tracking running statistics

and you'll get identical outputs, as in your second example.
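
Alternatively, as the docstring suggests, you can keep track_stats=true and call testmode! whenever you want reproducible evaluations; that freezes the layer on its stored statistics. A minimal sketch reusing the x, y, and m from the question (as far as I can tell, Flux.testmode! also forces the stored statistics to be used inside withgradient, so all three values should agree):

using Flux
x = rand(100,100,1,1)
y = @. 5*x + 3
m = Chain(BatchNorm(1),Conv((1,1),1=>1))
Flux.testmode!(m) #freeze the BatchNorm on its stored (initial) statistics
l_init = Flux.mse(m(x),y)
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m)
l_final = Flux.mse(m(x),y)
println("initial: $l_init, withgradient: $l_grad, final: $l_final") #all three should match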

The BatchNorm layer is initialised with zero mean and unit standard deviation, and your input data doesn't have those statistics; with track_stats=true each training-mode pass nudges the stored mean and variance towards the batch statistics, which is why you get a changing output even for repeated, identical inputs, as far as I can see from a quick look.
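
You can watch this happen by inspecting the layer's running statistics before and after the withgradient call. A minimal sketch below peeks at the μ and σ² fields of the BatchNorm layer; note these are internal field names of Flux's BatchNorm and may differ between versions:

using Flux
x = rand(100,100,1,1)
y = @. 5*x + 3
m = Chain(BatchNorm(1),Conv((1,1),1=>1))
println(m[1].μ, " ", m[1].σ²) #stored statistics start at zero mean, unit variance
Flux.withgradient(m -> Flux.mse(m(x),y), m) #training-mode forward pass inside the gradient
println(m[1].μ, " ", m[1].σ²) #running statistics have moved towards the batch statistics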