I have a dataset consisting of input images x_train, an objective value y_train corresponding to each image, and the derivatives of y_train with respect to x_train, dc_train. I would like to train a CNN on this data, computing the first derivative of the output with respect to the input during training and using it in the loss function alongside the usual loss term.
I implemented this as follows using ForwardDiff:
using Flux
using ForwardDiff
x_train = Float32.(rand(10,30,1,50))
norm_y_train = Float32.(rand(1,50))
dc_train = Float32.(rand(10,30,50))
mean_y_train = 300.0
sd_y_train = 1000.0
model = Chain(
    Conv((3, 3), 1 => 6, pad=(1, 1), relu),
    MaxPool((3, 3)),
    Conv((3, 3), 6 => 16, pad=(1, 1), relu),
    MaxPool((3, 3)),
    Flux.flatten,
    Dense(48 => 1)
)
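As a sanity check on Dense(48 => 1): the pad-1 3×3 convolutions preserve the 10×30 spatial size, and each MaxPool((3,3)) (default stride equal to the window) floor-divides each spatial dimension by 3, leaving 1×3×16 = 48 flattened features. The arithmetic, as a quick sketch:

```julia
# Spatial size is unchanged by a pad-1 3x3 conv; MaxPool((3,3)) with its
# default stride floor-divides each spatial dimension by 3.
h, w = 10, 30
h, w = h ÷ 3, w ÷ 3        # after first MaxPool  -> (3, 10)
h, w = h ÷ 3, w ÷ 3        # after second MaxPool -> (1, 3)
channels = 16
flat_features = h * w * channels   # 1 * 3 * 16 = 48, matching Dense(48 => 1)
```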
loader = Flux.DataLoader((x_train, norm_y_train, dc_train))  # default batchsize = 1, so model(x) has a single element
opt = Flux.setup(Flux.RAdam(0.01), model)
struct Model_struct{MT,T}
    model::MT
    sd::T
    mean::T
end

# Callable struct: denormalizes the model output and extracts the scalar
function (ma::Model_struct)(x::Array)
    return (ma.model(x) .* ma.sd .+ ma.mean)[]
end
function my_loss(model, x, y, dc)
    nn = Model_struct(model, sd_y_train, mean_y_train)
    dc_hat = ForwardDiff.gradient(nn, x)   # derivative of the denormalized output wrt the input
    y_hat = model(x)
    return Flux.mse(y_hat, y) + 0.5 * Flux.mse(dc_hat, dc)
end
epochs = []
train_loss = []
for epoch in 1:100
    for (x, y, dc) in loader
        val, grads = Flux.withgradient(model) do m
            my_loss(m, x, y, dc)
        end
        Flux.update!(opt, model, grads[1])
    end
    push!(train_loss, Flux.mse(model(x_train), norm_y_train))
    push!(epochs, epoch)
end
I have used random initializations above for illustration. Note that y_train is normalized for training using mean mean_y_train and standard deviation sd_y_train.
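For completeness, the normalization applied to y_train (which the Model_struct functor undoes at prediction time) is ordinary standardization; a minimal sketch using the same constants (normalize/denormalize are illustrative names, not part of the code above):

```julia
mean_y_train = 300.0f0
sd_y_train = 1000.0f0

# Forward normalization applied to the targets before training ...
normalize(y) = (y .- mean_y_train) ./ sd_y_train
# ... and the inverse applied by the Model_struct functor at prediction time.
denormalize(ŷ) = ŷ .* sd_y_train .+ mean_y_train

y = Float32[250.0, 800.0, 1300.0]
y_roundtrip = denormalize(normalize(y))   # recovers y
```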
This implementation produces the following warning:
┌ Warning: ForwardDiff.gradient(f, x) within Zygote cannot track gradients with respect to f,
│ and f appears to be a closure, or a struct with fields (according to issingletontype(typeof(f))).
│ typeof(f) = Model_struct{Chain{Tuple{Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Float64}
└ @ Zygote C:\Users\User\.julia\packages\Zygote\SuKWp\src\lib\forward.jl:142
I’m not able to figure out why this warning appears.
Is there a more elegant way to implement this?