Ways to Improve Universal Differential Equation Training with sciml_train


About a month ago I asked a question about strategies for better convergence when training a neural differential equation. I've since gotten that example to work using the advice I was given, but when I applied the same advice to a more difficult model, I got stuck again. All of my code is in Julia, primarily using the DiffEqFlux library. To keep this post as brief as possible, I won't share the code for everything I've tried, but if anyone wants access to it for troubleshooting, I can provide it.

What I'm Trying to Do

The data I'm trying to learn comes from an SIRx model:

function SIRx!(du, u, p, t)
    β, μ, γ, a, b = Float32.([280, 1/50, 365/22, 100, 0.05])
    S, I, x = u
    du[1] = μ*(1-x) - β*S*I - μ*S
    du[2] = β*S*I - (μ+γ)*I
    du[3] = a*I - b*x
    nothing
end;

The initial condition I used was u0 = Float32.([0.062047128, 1.3126149f-7, 0.9486445]). I generated data from t = 0 to 25, sampled every 0.02. (In training I only use every 20th point or so for speed; using more points doesn't improve the results.) The data looks like this: [Figure: Training Data]
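A minimal sketch of that data-generation step, assuming a hand-written fixed-step RK4 integrator as a stand-in for DifferentialEquations.jl (the question does not say which solver or tolerances were used). The helper names and the `substeps` parameter are hypothetical, and the sketch uses Float64 instead of Float32 for simplicity. It saves every 0.02 on [0, 25] and keeps every 20th sample for training:

```julia
# Hypothetical sketch: SIRx data via a hand-written RK4 integrator
# (a stand-in for an adaptive DifferentialEquations.jl solve).

# Out-of-place SIRx right-hand side, same equations as SIRx! above.
function sirx_rhs(u, p)
    β, μ, γ, a, b = p
    S, I, x = u
    return [μ*(1 - x) - β*S*I - μ*S,
            β*S*I - (μ + γ)*I,
            a*I - b*x]
end

# One classical fourth-order Runge-Kutta step of size h.
function rk4_step(f, u, p, h)
    k1 = f(u, p)
    k2 = f(u .+ (h/2) .* k1, p)
    k3 = f(u .+ (h/2) .* k2, p)
    k4 = f(u .+ h .* k3, p)
    return u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate on [0, tend], saving every `saveat`; `substeps` internal RK4
# steps are taken between consecutive saved points.
function generate_sirx_data(u0, p; tend = 25.0, saveat = 0.02, substeps = 10)
    nsave = round(Int, tend / saveat) + 1     # 1251 saved points for the defaults
    data = zeros(length(u0), nsave)
    data[:, 1] = u0
    u = copy(u0)
    h = saveat / substeps
    for k in 2:nsave
        for _ in 1:substeps
            u = rk4_step(sirx_rhs, u, p, h)
        end
        data[:, k] = u
    end
    ts = range(0.0, tend; length = nsave)
    return ts, data
end

p_true = (280.0, 1/50, 365/22, 100.0, 0.05)   # β, μ, γ, a, b as in SIRx!
u0 = [0.062047128, 1.3126149e-7, 0.9486445]
ts, data = generate_sirx_data(u0, p_true)

train_idx = 1:20:length(ts)                    # keep every 20th sample
t_train, data_train = ts[train_idx], data[:, train_idx]
```

With these defaults the full data set has 1251 columns and the training subset 63. In practice an adaptive DifferentialEquations.jl solver with `saveat = 0.02` would replace the hand-written loop.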

The UDE I'm training is

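function SIRx_ude!(du, u, p, t)
    μ, γ = Float32.([1/50, 365/22])   # known rates, same values as in SIRx!
    S, I, x = u
    # p concatenates the parameters of the three networks: [θ_dS; θ_dI; θ_dx]
    du[1] = μ*(1-x) - μ*S + ann_dS(u, @view p[1:lenS])[1]
    du[2] = -(μ+γ)*I + ann_dI(u, @view p[lenS+1:lenS+lenI])[1]
    du[3] = ann_dx(u, @view p[lenS+lenI+1:end])[1]   # θ_dx starts after both θ_dS and θ_dI
    nothing
end;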
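Because all three networks read from one flat parameter vector `p`, each network needs its own disjoint index range. A small sketch computing those ranges from the per-network parameter counts, assuming each network is the 3 → 20 → 1 dense architecture described in the question with one bias per neuron (101 parameters each); `param_ranges` is a hypothetical helper. It also shows that a range starting at `lenI + 1` begins inside the dI block whenever lenS == lenI:

```julia
# Hypothetical helper: split a flat parameter vector of total length
# sum(lens) into consecutive, non-overlapping index ranges, one per network.
function param_ranges(lens::Vector{Int})
    stops = cumsum(lens)              # last index of each block
    starts = stops .- lens .+ 1       # first index of each block
    return [starts[i]:stops[i] for i in eachindex(lens)]
end

# A 3 → 20 → 1 dense network with biases has 3*20 + 20 hidden-layer
# parameters and 20*1 + 1 output-layer parameters, i.e. 101 in total.
lenS = lenI = lenx = 3*20 + 20 + 20*1 + 1
rS, rI, rx = param_ranges([lenS, lenI, lenx])   # 1:101, 102:202, 203:303

# Because lenS == lenI, a range that starts at lenI + 1 begins exactly
# where the dI block begins, so it overlaps rI completely.
overlap = intersect(lenI+1:lenS+lenI+lenx, rI)  # 102:202
```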

Each of the neural networks (ann_dS, ann_dI, ann_dx) is defined as FastChain(FastDense(3, 20, tanh), FastDense(20, 1)). I tried using a single neural network with 3 inputs and 3 outputs, but it was slower and didn't perform any better. I also tried normalizing the network inputs first, but it makes no significant difference beyond slowing things down.
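To make concrete what each network computes, and how input normalization could be wired in, here is a plain-Julia stand-in for `FastChain(FastDense(3, 20, tanh), FastDense(20, 1))`. This is a hypothetical sketch: the ordering of parameters inside `θ` is chosen for illustration and is not claimed to match FastDense's internal layout, and `input_scaler` simply standardizes each state variable with the mean and standard deviation of the training data:

```julia
using Statistics

# Hypothetical stand-in for a 3 → nh → 1 dense network evaluated from a
# flat parameter vector θ (layout chosen for illustration only).
function mlp(u::AbstractVector, θ::AbstractVector; act = tanh, nh = 20)
    nin = length(u)
    nW1 = nh * nin
    W1 = reshape(view(θ, 1:nW1), nh, nin)       # hidden weights
    b1 = view(θ, nW1+1:nW1+nh)                   # hidden biases
    w2 = view(θ, nW1+nh+1:nW1+2*nh)              # output weights
    b2 = θ[nW1+2*nh+1]                           # output bias
    h = act.(W1 * u .+ b1)                       # hidden activations
    return sum(w2 .* h) + b2                     # single scalar output
end

# Per-variable standardization estimated from a d × N data matrix.
function input_scaler(data::AbstractMatrix)
    m = vec(mean(data; dims = 2))
    s = vec(std(data; dims = 2, corrected = false))
    s = max.(s, eps())                           # guard against constant rows
    return u -> (u .- m) ./ s
end

# Network applied to standardized inputs.
normalized_mlp(scale, u, θ) = mlp(scale(u), θ)
```

Standardizing puts I, which starts near 1e-7, on a similar numerical footing to S and x, which start near 0.06 and 0.95.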

What I've Tried

  • Single shooting: The network just fits a straight line through the middle of the data. This happens even when I weight the earlier data points more heavily in the loss function. [Figure: Single-shot Training]
  • Multiple shooting: The best result I had was with multiple shooting. It's not simply fitting a straight line, but it's not really fitting the data either. [Figure: Multiple Shooting Result] I've tried continuity-term weights ranging from 0.1 to 100 and group sizes from 3 to 30, and neither makes a significant difference.
  • Various other strategies: I've also tried iteratively growing the fit, two-stage training with collocation, and mini-batching, as outlined here: https://diffeqflux.sciml.ai/dev/examples/local_minima, https://diffeqflux.sciml.ai/dev/examples/collocation/, https://diffeqflux.sciml.ai/dev/examples/minibatch/. Iteratively growing the fit works well for the first couple of iterations, but as the time span grows it goes back to fitting a straight line. Two-stage collocation training works really well in stage 1, but stage 2 doesn't improve on it (I've tried both single and multiple shooting for the second stage). Finally, mini-batching worked about as well as single shooting (which is to say, not very well), but much more quickly.
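For the single-shooting attempt, the whole trajectory is simulated once from u0 and compared against every data point. A sketch of a time-weighted version of that loss, assuming exponentially decaying weights (the question does not say which weighting was used) and an optional per-variable `scale` (hypothetical) that lets small-magnitude variables such as I receive comparable weight:

```julia
# Weighted single-shooting loss: `pred` is a 3 × N matrix from ONE simulation
# started at u0, `data` is the matching 3 × N observation matrix, and w[k]
# weights time point k. `scale` optionally rescales each variable's residual.
function weighted_sse(pred::AbstractMatrix, data::AbstractMatrix, w::AbstractVector;
                      scale = ones(size(data, 1)))
    size(pred) == size(data) || throw(DimensionMismatch("pred and data differ in size"))
    length(w) == size(data, 2) || throw(DimensionMismatch("need one weight per time point"))
    r = (pred .- data) ./ scale                 # row i divided by scale[i]
    return sum(w[k] * sum(abs2, view(r, :, k)) for k in axes(r, 2))
end

# One possible weighting (an assumption): exponential decay with time scale τ,
# so that early points contribute most to the loss.
decay_weights(ts, τ) = exp.(-(ts .- first(ts)) ./ τ)
```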
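The multiple-shooting idea from the list above can be sketched without any library: split the time points into groups that share endpoints, restart each group from the observed state, fit each group separately, and add a penalty for jumps between consecutive groups. This is a simplified illustration, not DiffEqFlux's implementation; `simulate` stands in for an ODE solve of the UDE and is hypothetical, and `continuity_term` is named after the continuity-term weight mentioned in the question:

```julia
# Split time indices 1:N into consecutive groups of `group_size` points that
# share their endpoints, e.g. shooting_groups(7, 3) == [1:3, 3:5, 5:7].
function shooting_groups(N::Int, group_size::Int)
    group_size ≥ 2 || throw(ArgumentError("group_size must be at least 2"))
    groups = UnitRange{Int}[]
    start = 1
    while start < N
        stop = min(start + group_size - 1, N)
        push!(groups, start:stop)
        start = stop                      # next group begins where this one ends
    end
    return groups
end

# Simplified multiple-shooting loss. `simulate(u0, ts)` (hypothetical) returns
# a 3 × length(ts) matrix whose first column is u0. Each group restarts from
# the observed state at its first time point; `continuity_term` weights the
# mismatch between the end of one group's prediction and the next group's start.
function multiple_shooting_loss(simulate, data::AbstractMatrix, ts::AbstractVector,
                                group_size::Int; continuity_term = 100.0)
    groups = shooting_groups(size(data, 2), group_size)
    preds = [simulate(data[:, first(g)], ts[g]) for g in groups]
    fit = sum(sum(abs2, preds[i] .- data[:, g]) for (i, g) in enumerate(groups))
    jumps = sum((sum(abs2, preds[i][:, end] .- preds[i+1][:, 1])
                 for i in 1:length(groups)-1); init = 0.0)
    return fit + continuity_term * jumps
end
```

With `group_size` equal to the number of data points there is a single group and the loss reduces to single shooting from the first data point; a larger `continuity_term` pushes the pieces to join into one trajectory.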
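Stage 1 of the two-stage collocation approach mentioned in the list fits the right-hand side directly to derivatives estimated from the data, so no ODE solves are needed during that stage. The linked example uses a smoothing-based derivative estimate; the sketch below substitutes plain finite differences (an assumption, and cruder, since finite differences amplify noise). Both function names are hypothetical:

```julia
# Estimate du/dt at every sample of a d × N data matrix with finite
# differences: central in the interior, one-sided at the two ends.
# Assumes ts is strictly increasing.
function fd_derivatives(data::AbstractMatrix, ts::AbstractVector)
    N = size(data, 2)
    du = similar(data, Float64)
    for k in 1:N
        lo, hi = max(k - 1, 1), min(k + 1, N)
        du[:, k] = (data[:, hi] .- data[:, lo]) ./ (ts[hi] - ts[lo])
    end
    return du
end

# Stage-1 collocation loss: compare the model right-hand side f(u, p) at each
# observed state with the estimated derivative at that point.
function collocation_loss(f, p, data::AbstractMatrix, du_est::AbstractMatrix)
    return sum(sum(abs2, f(data[:, k], p) .- du_est[:, k]) for k in axes(data, 2))
end
```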

My Question

In summary, I have no idea what to try. There are so many strategies, each with so many parameters that can be tweaked. I need a way to diagnose the problem more precisely so I can better decide how to proceed. If anyone has experience with this sort of problem, I'd appreciate any advice or guidance I can get.

Answer

This isn't a great SO question because it's more exploratory. Did you lower your ODE tolerances? That would improve your gradient calculation which could help. What activation function are you using? I would use something like softplus instead of tanh so that you don't have the saturating behavior. Did you scale the eigenvalues and take into account the issues explored in the stiff neural ODE paper? Larger neural networks? Different learning rates? ADAM? Etc.
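The suggestion to use softplus instead of tanh is about gradients: tanh'(z) = 1 - tanh(z)^2 vanishes quickly as |z| grows, while softplus'(z) is the logistic sigmoid and stays near 1 for large positive z. A small self-contained comparison (NNlib also provides a `softplus`; the hand-written versions here are for illustration):

```julia
# Derivative of tanh.
tanh_grad(z) = 1 - tanh(z)^2

# Numerically stable softplus(z) = log(1 + exp(z)).
softplus(z) = z > 0 ? z + log1p(exp(-z)) : log1p(exp(z))

# Derivative of softplus: the logistic sigmoid.
softplus_grad(z) = 1 / (1 + exp(-z))

for z in (0.0, 2.0, 5.0, 10.0)
    println("z = $z: tanh' = $(tanh_grad(z)), softplus' = $(softplus_grad(z))")
end
```

Softplus still saturates on the negative side: its gradient tends to 0 as z → -∞.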

This is much better suited for a forum for discussion like the JuliaLang Discourse. We can continue there since walking through this will not be fruitful without some back and forth.