I have a general question about how neural ODE nets are trained in Julia. Are data points sampled from the tspan on which the neural ODE is defined, with the parameter updates computed on them? In other words, is there some shuffling and batching happening during training, or is the loss computed over all data points in the tspan?
How does DiffEqFlux.sciml_train work for Neural ODEs in Julia?
Asked by SimonAda

The parameters are optimized by minimizing the loss function, so it's up to you to define how the sampling occurs, inside the loss function. Typically you compare the solver output to discrete data points, in which case the loss is evaluated at exactly those time points.
The neural ODE itself isn't handling this; you are. It's all in the loss function.
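A minimal plain-Julia sketch of this point, under stated assumptions: it does not use DiffEqFlux at all. The one-parameter ODE du/dt = p*u is a hypothetical stand-in for a neural network, a fixed-step Euler loop stands in for a DiffEq solver, and a central finite difference stands in for automatic differentiation; `f`, `solve_euler`, `loss_all`, `grad`, and `train_full` are illustrative names only. The loss compares the trajectory with every saved point in the tspan, so each update uses all data points and nothing is shuffled or batched.

```julia
# Plain-Julia sketch (no DiffEqFlux): the loss decides which points are used.
# Hypothetical stand-in for a neural ODE: du/dt = p * u with one parameter p.
f(u, p) = p * u

# Fixed-step explicit Euler (stand-in for a real ODE solver);
# returns the state at every time in `ts`.
function solve_euler(u0, p, ts; substeps = 100)
    us = zeros(length(ts))
    u = u0
    us[1] = u
    for i in 2:length(ts)
        h = (ts[i] - ts[i-1]) / substeps
        for _ in 1:substeps
            u += h * f(u, p)
        end
        us[i] = u
    end
    return us
end

# Synthetic observations on tspan = (0, 2), generated with the "true" p = -0.5.
ts = collect(range(0.0, 2.0; length = 21))
data = exp.(-0.5 .* ts)

# Full-batch loss: every saved point in the tspan enters every evaluation.
loss_all(p) = sum(abs2, solve_euler(1.0, p, ts) .- data)

# Central finite difference, used here instead of automatic differentiation.
grad(L, p; ϵ = 1e-6) = (L(p + ϵ) - L(p - ϵ)) / (2ϵ)

# Plain gradient descent: no shuffling, no batching, same points every step.
function train_full(p; lr = 0.01, iters = 200)
    for _ in 1:iters
        p -= lr * grad(loss_all, p)
    end
    return p
end

p_fit = train_full(0.0)
println("fitted p = ", p_fit)
```

The fitted p should end up close to the -0.5 used to generate the data, though not exactly equal, because the Euler discretization shifts the optimum slightly. Changing `loss_all` (for example, summing only over some indices) is the only thing that changes which points drive the fit.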

I found the answer to what DiffEqFlux is doing during training here: https://github.com/JuliaDiffEq/DiffEqFlux.jl/blob/master/src/train.jl
I think that to do batching, one needs to sample data points and then run Flux.train! on them in a loop, passing each batch of data points as input.
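A sketch of that batching loop, under the same assumptions as a plain-Julia illustration: a hypothetical scalar ODE du/dt = p*u instead of a neural network, a hand-written Euler solver, a finite-difference gradient, and illustrative names such as `loss_batch` and `train_minibatch`. Rather than calling Flux.train! or sciml_train, it writes the loop by hand: each epoch shuffles the saved time indices with `randperm` and updates the parameter once per batch, using only the residuals at those indices.

```julia
using Random

# Hypothetical stand-in model, repeated so this block runs on its own.
f(u, p) = p * u

# Fixed-step explicit Euler; returns the state at every time in `ts`.
function solve_euler(u0, p, ts; substeps = 100)
    us = zeros(length(ts))
    u = u0
    us[1] = u
    for i in 2:length(ts)
        h = (ts[i] - ts[i-1]) / substeps
        for _ in 1:substeps
            u += h * f(u, p)
        end
        us[i] = u
    end
    return us
end

# Synthetic observations on tspan = (0, 2), generated with the "true" p = -0.5.
ts = collect(range(0.0, 2.0; length = 21))
data = exp.(-0.5 .* ts)

# Minibatch loss: only the residuals at the sampled indices `idx` are used,
# but the ODE is still integrated from t0 across the whole tspan.
loss_batch(p, idx) = sum(abs2, solve_euler(1.0, p, ts)[idx] .- data[idx])

# Central finite difference, used here instead of automatic differentiation.
grad(L, p; ϵ = 1e-6) = (L(p + ϵ) - L(p - ϵ)) / (2ϵ)

# Shuffle the saved time points each epoch, then update once per batch.
function train_minibatch(p; lr = 0.01, epochs = 100, batchsize = 5,
                         rng = MersenneTwister(1))
    n = length(ts)
    for _ in 1:epochs
        perm = randperm(rng, n)
        for start in 1:batchsize:n
            idx = perm[start:min(start + batchsize - 1, n)]
            p -= lr * grad(q -> loss_batch(q, idx), p)
        end
    end
    return p
end

p_fit = train_minibatch(0.0)
println("fitted p = ", p_fit)
```

Sampling time points only changes which residuals enter each update. For an initial value problem the solver still has to integrate from the start of the tspan up to the latest sampled time, so batching over time points does not shorten each solve the way batching over independent trajectories would.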