Predicting ODE parameters with DiffEqFlux

I'm trying to build a neural network that takes in the solutions to a system of ODEs and predicts the parameters of the system. I'm using Julia, and in particular the DiffEqFlux package. The network is a few simple Dense layers chained together that predict some intermediate parameters (in this case, chemical reaction free energies), which then feed into some deterministic (non-trained) layers that convert those parameters into the ones that actually enter the system of equations (in this case, reaction rate constants). From there, I've tried two different approaches:

  1. Chain the ODE solve on directly as the last layer of the network. In this case, the loss function just compares the outputs to the inputs.

  2. Put the ODE solve in the loss function, so the network output is just the parameters.

However, in neither case can I get Flux.train! to actually run.

Below is a silly little example of the first option that reproduces the error I'm getting. I've kept as many things parallel to my actual case as possible (e.g. the same solver), though I omitted the intermediate deterministic layers since they don't seem to make a difference.

using Flux, DiffEqFlux, DifferentialEquations

# let's use Chris' favorite example, Lotka-Volterra
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
u0 = [1.0,1.0]
tspan = (0.0,10.0)

# generate a couple sets of solutions to train on
training_params = [[1.5,1.0,3.0,1.0], [1.4,1.1,3.1,0.9]]
training_sols = [solve(ODEProblem(lotka_volterra, u0, tspan, tp)).u[end] for tp in training_params]

# option 1: chain the ODE solve on as the last layer
model = Chain(Dense(2,3), Dense(3,4),
              p -> diffeq_adjoint(p, ODEProblem(lotka_volterra, u0, tspan, p), Rodas4())[:,end])

# in this case we just want outputs to match inputs
# (actual parameters we're after are outputs of next-to-last layer)
training_data = zip(training_sols, training_sols)

# mean squared error loss
loss(x,y) = Flux.mse(model(x), y)

# only the Dense layers have trainable parameters
p = Flux.params(model[1:2])

Flux.train!(loss, p, training_data, ADAM(0.001))
# gives TypeError: in typeassert, expected Float64, got ForwardDiff.Dual{Nothing, Float64, 8}

I've tried all three solver layers (diffeq_adjoint, diffeq_rd, and diffeq_fd); none of them works, and each gives a different error that I'm having trouble parsing.
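For reference, my reverse-mode attempt looked roughly like the sketch below; it's just a drop-in swap of the call in the last layer (diffeq_fd takes a different signature, so it isn't a one-line swap, and I've left it out here):

# roughly what my reverse-mode (Tracker-based) attempt looked like:
# swap diffeq_adjoint for diffeq_rd in the last layer
model_rd = Chain(Dense(2,3), Dense(3,4),
                 p -> diffeq_rd(p, ODEProblem(lotka_volterra, u0, tspan, p), Rodas4())[:,end])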

For the other option (which I'd actually prefer, though either would work), just replace the model and loss function definitions with:

# option 2: the network outputs just the parameters
model = Chain(Dense(2,3), Dense(3,4))

function loss(x,y)
    p = model(x)
    sol = diffeq_adjoint(p, ODEProblem(lotka_volterra, u0, tspan, p), Rodas4())[:,end]
    Flux.mse(sol, y)
end

This throws the same error as above.

I've been hacking at this for over a week now and am completely stumped; any ideas?

1 Answer

You're running into https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/31, i.e. forward-mode AD for the Jacobian doesn't play nice with Flux.jl right now. To get around this, use Rodas4(autodiff=false) instead.
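For example, applied to the first model above, the only change needed is the solver argument (a minimal sketch; everything else stays as in the question):

# same model as before, but with ForwardDiff disabled inside the
# stiff solver's internal Jacobian computation
model = Chain(Dense(2,3), Dense(3,4),
              p -> diffeq_adjoint(p, ODEProblem(lotka_volterra, u0, tspan, p),
                                  Rodas4(autodiff=false))[:,end])

Flux.train!(loss, Flux.params(model[1:2]), training_data, ADAM(0.001))

The same substitution works in the loss-function version: pass Rodas4(autodiff=false) to the diffeq_adjoint call there. With autodiff=false the solver computes its internal Jacobian by finite differencing instead of ForwardDiff, so no Dual numbers leak into code that expects Float64.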