What is a model in Julia Flux 0.13 and higher?


(Also posted on Julia Discourse: https://discourse.julialang.org/t/what-is-a-model-in-julia-flux-0-13-and-higher/100653)

I want to use Julia Flux for machine learning with custom models (not neural networks, so I won't be using or combining the layers Flux provides). I want to do it with Flux because I want access to its various advanced gradient descent algorithms.

To do so, I intend to use the training API (https://fluxml.ai/Flux.jl/stable/training/reference/#Training-API-Reference).

The problem is that the documentation is not very detailed. For example, there are functions like

```julia
Flux.setup(rule, model)
```

and

```julia
Flux.train!(loss, model, data, opt_state)
```

However, nowhere in the API reference is there a description of what `model` is and what form it should take.

As a test problem, consider matrix factorization. That is, given a matrix `A`

```julia
using Random

dim = 2
A = rand(dim, dim)
```

find `x` and `y` such that

```julia
x * y' ≈ A
```
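Reading "≈" in the least-squares sense (an assumption; the question does not fix the error measure), the test problem is

$$
\min_{x,\,y \in \mathbb{R}^{\mathrm{dim}}} \; \lVert x y^{\mathsf T} - A \rVert_F^2 \;=\; \min_{x,\,y} \sum_{i,j} \left( x_i y_j - A_{ij} \right)^2
$$

Since $x y^{\mathsf T}$ has rank at most 1 while a random 2×2 `A` almost surely has rank 2, an exact match is generally impossible. By the Eckart–Young theorem the minimum is the best rank-1 approximation of `A`, and the remaining error is $\sigma_2^2$, the square of the second singular value of `A`.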

I would guess that the model should be defined as `model(x, y) = x * y'`, but then `Flux.setup(AdaGrad(), model)` produces the warning

```
Warning: setup found no trainable parameters in this model
```
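The warning follows from what `Flux.setup` does: it walks the model's fields (via Functors.jl) looking for arrays to attach optimiser state to, and a plain function such as `model(x, y) = x * y'` has no such fields. The parameters being trained must live inside the model object rather than be passed in as arguments. A minimal sketch of the difference, assuming a Flux version that provides `Flux.setup` (0.13.x or later, as in the question); the names `f` and `params_model` are only illustrative:

```julia
using Flux

dim = 2

# A plain function has no fields holding arrays, so setup finds nothing to train
# (this reproduces the warning from the question)
f(x, y) = x * y'
Flux.setup(AdaGrad(), f)

# A NamedTuple of arrays is a structure Functors.jl can walk:
# each array becomes a trainable parameter with its own optimiser state
params_model = (x = rand(dim), y = rand(dim))
opt_state = Flux.setup(AdaGrad(), params_model)
```

Here `opt_state` mirrors the model: it is a NamedTuple with one entry per array. This is also why the first argument of the loss passed to `train!` is the model itself rather than separate `x` and `y`.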

Answer by Nils Gudat:

The page you link to has an example of what a model is, in the documentation for `Flux.Train.setup`:

```julia
julia> model = Dense(2=>1, leakyrelu; init=ones);

julia> opt_state = Flux.setup(Momentum(0.1), model)  # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ())

julia> x1, y1 = [0.2, -0.3], [0.4];  # use the same data for two steps:

julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
         sum(abs.(m(x) .- y)) * 100
       end
```
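For the matrix-factorization problem in the question, the same pattern means putting `x` and `y` inside a struct and letting Flux look inside it. A sketch under these assumptions: the struct name `Factorization` is hypothetical, the loss is the squared Frobenius norm, `Flux.@functor` is used (as in Flux 0.13/0.14; later versions recommend `Flux.@layer`), and because there is no per-step data, each element of `data` is an empty tuple so that `train!` calls `loss(model)`:

```julia
using Flux

# Hypothetical container for the two factors; any struct holding arrays works
struct Factorization
    x::Vector{Float64}
    y::Vector{Float64}
end

# Let Functors.jl (used by Flux.setup) recurse into the struct's fields
Flux.@functor Factorization

# Calling the model returns the rank-1 reconstruction x * y'
(m::Factorization)() = m.x * m.y'

dim = 2
A = rand(dim, dim)
model = Factorization(rand(dim), rand(dim))

# Squared Frobenius-norm error; train! calls loss(model, d...) for each d in data
loss(m) = sum(abs2, m() .- A)

opt_state = Flux.setup(AdaGrad(), model)
initial_loss = loss(model)

# 1000 empty tuples = 1000 optimisation steps with no extra data arguments
data = Iterators.repeated((), 1000)
Flux.train!(loss, model, data, opt_state)

println("loss: ", initial_loss, " -> ", loss(model))
```

Because `train!` splats each tuple in `data` into the loss after the model, a data element such as `(x1, y1)` in the docs example arrives as `loss(m, x1, y1)`; here the tuple is empty, so only the model is passed.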

It might be easier to look at the "Fitting a straight line" part of the docs:

https://fluxml.ai/Flux.jl/stable/models/overview/
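Instead of `train!`, the training can also be written as an explicit loop over `Flux.gradient` and `Flux.update!`, the pattern used in Flux's training guide. This makes it visible that the gradient has the same shape as the model. A sketch assuming Flux 0.14-style explicit gradients, with a NamedTuple as the model and `Adam(0.01)` and 2000 steps chosen arbitrarily:

```julia
using Flux

dim = 2
A = rand(dim, dim)

# The model is just a NamedTuple holding the two factors
model = (x = rand(dim), y = rand(dim))

# Squared Frobenius-norm error of the rank-1 reconstruction
loss(m) = sum(abs2, m.x * m.y' .- A)

opt_state = Flux.setup(Adam(0.01), model)
initial_loss = loss(model)

for _ in 1:2000
    # gradient returns one entry per argument; grads[1] is a NamedTuple (x = ..., y = ...)
    grads = Flux.gradient(loss, model)
    # update! adjusts the arrays inside `model` (and the optimiser state) in place
    Flux.update!(opt_state, model, grads[1])
end

println("loss: ", initial_loss, " -> ", loss(model))
```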