How to guarantee convergence when training a neural differential equation?

I'm currently working through the SciML tutorials workshop exercises for the Julia language (https://tutorials.sciml.ai/html/exercises/01-workshop_exercises.html). Specifically, I'm stuck on exercise 6, part 3, which involves training a neural network to approximate the following system of equations:

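```julia
# Lotka-Volterra right-hand side in the in-place form f(du, u, p, t) used by DifferentialEquations.jl
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y   # prey
    du[2] = dy = -δ*y + γ*x*y  # predator
    return nothing
end
```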

The goal is to replace the equation for `du[2]` with a neural network, `du[2] = NN(u, p)`, where `NN` is a neural network with parameters `p` and input `u`.
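A minimal sketch of that hybrid right-hand side, written in plain Julia without Flux so it runs on its own. The hand-rolled 2→30→1 network, the packed parameter vector `θ`, the helper names `unpack`, `nn`, and `hybrid_lv!`, and the values α = 1.5, β = 1.0 are all assumptions for illustration, not the tutorial's code. It uses a `tanh` hidden layer; note that Flux's `Dense` defaults to the identity activation, so `Chain(Dense(2,30), Dense(30, 1))` without an activation argument composes two affine maps into a single affine map of `u`.

```julia
using Random

const H = 30  # hidden width, mirroring Dense(2,30) -> Dense(30,1)

# Number of scalars in the packed vector θ: W1 (H×2), b1 (H), W2 (1×H), b2 (1).
nparams() = 4H + 1

# Hypothetical helper: slice the flat vector θ into the two layers' weights and biases.
function unpack(θ)
    W1 = reshape(θ[1:2H], H, 2)
    b1 = θ[2H+1:3H]
    W2 = reshape(θ[3H+1:4H], 1, H)
    b2 = θ[4H+1]
    return W1, b1, W2, b2
end

# 2 -> 30 (tanh) -> 1 network; returns a scalar.
function nn(u, θ)
    W1, b1, W2, b2 = unpack(θ)
    h = tanh.(W1 * u .+ b1)
    return (W2 * h)[1] + b2
end

# Hybrid Lotka-Volterra: known physics for du[1], the network replaces du[2].
# α and β are treated as known; their values here are assumed for illustration.
function hybrid_lv!(du, u, θ, t; α = 1.5, β = 1.0)
    x, y = u
    du[1] = α*x - β*x*y
    du[2] = nn(u, θ)
    return nothing
end

# Usage with a small random θ (the 0.1 scale is an arbitrary choice).
Random.seed!(0)
θ = 0.1 .* randn(nparams())
du = zeros(2)
hybrid_lv!(du, [1.0, 1.0], θ, 0.0)
println(du)
```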

I have a set of sample data that the model should match. The loss function is the sum of squared differences between the hybrid model's output and that sample data.
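Written out, assuming observations $d_i$ at times $t_i$ ($i = 1, \dots, N$), an initial condition $u_0$, and the hybrid right-hand side $f_\theta$ (notation introduced here for illustration, with $\theta$ standing for the network parameters called `p` above), this whole-trajectory loss is

$$
L(\theta) = \sum_{i=1}^{N} \left\lVert u_\theta(t_i) - d_i \right\rVert_2^2,
\qquad
u_\theta(t) = u_0 + \int_{t_0}^{t} f_\theta\big(u_\theta(s)\big)\, ds .
$$

Every prediction $u_\theta(t_i)$ depends on integrating all the way from $t_0$, so a poor initial $\theta$ can make the solution, and with it the loss, grow very large.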

I defined my network as `NN = Chain(Dense(2,30), Dense(30, 1))`. I can get `Flux.train!` to run, but sometimes the random initial parameters of the network produce a loss on the order of 10^20, so training never converges. My best attempt brought the loss down from about 2000 to about 20 using the ADAM optimizer over roughly 1000 iterations, but I can't get it any lower.

How can I make sure my network is consistently trainable, and is there a way to get better convergence?

Answer

> How can I make sure my network is consistently trainable, and is there a way to get better convergence?

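See the FAQ page on techniques for improving convergence. In a nutshell, the single-shooting approach used in most ML papers, which integrates the model over the entire time span from the initial condition and only then compares against the data, is very unstable and does not work on most practical problems, but there is a litany of techniques to help. One of the best is multiple shooting, which splits the time series into short segments, each starting from its own initial state, and fits those short bursts separately (and in parallel).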
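A minimal, dependency-free sketch of a multiple-shooting loss. It assumes a hand-written fixed-step RK4 integrator (`rk4_solve`, not a library call) and hypothetical names `multiple_shooting_loss`, `segstarts`, and `λ`; it is not any library's implementation. Each segment gets its own free initial state (a column of `U0`) that would be optimized along with the network parameters, and a penalty weighted by `λ` pushes the end of each segment to meet the start of the next. The Lotka-Volterra parameters (1.5, 1.0, 3.0, 1.0) and `u0 = [1, 1]` in the usage are assumed values.

```julia
# Plain fixed-step RK4 for an out-of-place f(u, t); one output column per entry of ts.
function rk4_solve(f, u0, ts)
    U = zeros(length(u0), length(ts))
    u = copy(u0)
    U[:, 1] = u
    for k in 1:length(ts)-1
        t, h = ts[k], ts[k+1] - ts[k]
        k1 = f(u, t)
        k2 = f(u .+ (h/2) .* k1, t + h/2)
        k3 = f(u .+ (h/2) .* k2, t + h/2)
        k4 = f(u .+ h .* k3, t + h)
        u = u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        U[:, k+1] = u
    end
    return U
end

# Multiple shooting: segment s covers columns segstarts[s] .. segstarts[s+1]
# (the last one runs to the end of ts) and starts from its own state U0[:, s].
function multiple_shooting_loss(f, U0, data, ts, segstarts; λ = 100.0)
    nseg = length(segstarts)
    misfit = 0.0
    continuity = 0.0
    for s in 1:nseg  # segments are independent given U0, so this loop could run in parallel
        i = segstarts[s]
        j = s < nseg ? segstarts[s+1] : length(ts)
        pred = rk4_solve(f, U0[:, s], ts[i:j])
        misfit += sum(abs2, pred .- data[:, i:j])
        if s < nseg
            # mismatch between this segment's end and the next segment's start
            continuity += sum(abs2, pred[:, end] .- U0[:, s+1])
        end
    end
    return misfit + λ * continuity
end

# Usage with the (assumed) true Lotka-Volterra model standing in for the hybrid one.
lv(u, t) = [1.5*u[1] - 1.0*u[1]*u[2], -3.0*u[2] + 1.0*u[1]*u[2]]
ts = collect(0.0:0.1:10.0)
data = rk4_solve(lv, [1.0, 1.0], ts)
segstarts = collect(1:20:length(ts)-1)  # five segments of 20 steps each
U0 = data[:, segstarts]                 # exact segment starts, so the loss is ~0
loss_exact = multiple_shooting_loss(lv, U0, data, ts, segstarts)
loss_perturbed = multiple_shooting_loss(lv, U0 .+ 0.1, data, ts, segstarts)
```

In an actual training loop, `U0` would typically be initialized from the observed data at `segstarts` and updated by the optimizer together with the network parameters.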

Training on a short interval and then growing the interval also works, and so can switching to a more stable optimizer such as BFGS. You can also weight the loss function so that earlier time points count more. Lastly, you can minibatch in a way similar to multiple shooting, i.e. start from a data point and solve only to the next one. (In fact, if you look at the NumPy code released with the original neural ODE paper, it does not implement the algorithm as explained in the paper; instead it uses this form of sampling to stabilize training on the spiral ODE.)
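The remaining loss constructions can be sketched the same way. This is a self-contained illustration under assumptions: `rk4_solve` is a hand-written fixed-step RK4 (not a library call), the exponential weight `exp(-decay*(t - t₁))` is just one possible way to make earlier times count more, and `horizon_loss`, `one_step_batch_loss`, the batch size, and the model parameters are hypothetical. The optimizer step (ADAM, BFGS, ...) is omitted: a training loop would minimize `horizon_loss` on a short horizon first and warm-start each longer horizon from the previous parameters, or minimize `one_step_batch_loss` on a freshly sampled batch at each iteration.

```julia
using Random

# Plain fixed-step RK4 for an out-of-place f(u, t); one output column per entry of ts.
function rk4_solve(f, u0, ts)
    U = zeros(length(u0), length(ts))
    u = copy(u0)
    U[:, 1] = u
    for k in 1:length(ts)-1
        t, h = ts[k], ts[k+1] - ts[k]
        k1 = f(u, t)
        k2 = f(u .+ (h/2) .* k1, t + h/2)
        k3 = f(u .+ (h/2) .* k2, t + h/2)
        k4 = f(u .+ h .* k3, t + h)
        u = u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        U[:, k+1] = u
    end
    return U
end

# Time-weighted squared error over the first n points (the "growing interval" loss).
# Weight exp(-decay*(t - t_1)) makes early points count more; decay = 0 is unweighted.
function horizon_loss(f, u0, data, ts, n; decay = 0.0)
    pred = rk4_solve(f, u0, ts[1:n])
    w = exp.(-decay .* (ts[1:n] .- ts[1]))
    return sum(w' .* abs2.(pred .- data[:, 1:n]))
end

# One-step minibatch loss: start at observed point k and integrate only to point k + 1.
function one_step_batch_loss(f, data, ts, batch)
    total = 0.0
    for k in batch
        pred = rk4_solve(f, data[:, k], ts[k:k+1])
        total += sum(abs2, pred[:, end] .- data[:, k+1])
    end
    return total / length(batch)
end

# Usage: data from an assumed true model, losses evaluated for a slightly wrong model.
lv_true(u, t) = [1.5*u[1] - 1.0*u[1]*u[2], -3.0*u[2] + 1.0*u[1]*u[2]]
lv_wrong(u, t) = [1.5*u[1] - 1.0*u[1]*u[2], -3.0*u[2] + 1.2*u[1]*u[2]]
ts = collect(0.0:0.1:10.0)
u0 = [1.0, 1.0]
data = rk4_solve(lv_true, u0, ts)

short_loss = horizon_loss(lv_wrong, u0, data, ts, 11)         # t from 0 to 1
full_loss = horizon_loss(lv_wrong, u0, data, ts, length(ts))  # t from 0 to 10
weighted_full = horizon_loss(lv_wrong, u0, data, ts, length(ts); decay = 0.5)

rng = MersenneTwister(1)
batch = rand(rng, 1:length(ts)-1, 16)  # 16 random starting indices
batch_loss = one_step_batch_loss(lv_wrong, data, ts, batch)
```

Because each minibatch sample integrates over a single step from an observed state, a bad parameter guess cannot compound its error across the whole time span the way it can in the full-trajectory loss.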