How to train Flux.jl to learn a sequence conditional on some initial "seeds"?


I am trying to write an RNN model that, given an initial "seed" sequence, reproduces the continuation of the sequence. In the code below, dummy sequences are generated as a function of these initial seed points and an RNN approach is attempted, but when I plot the generated sequences they match the "true" ones very poorly, and at best my model learns a sequence unconditional on the seed (i.e. the expected value of the unconditional sequence).

Setting the environment...

# Setting the environment...
cd(@__DIR__)    
using Pkg      
Pkg.activate(".")  
# Pkg.add(["Plots","Flux"])
# Pkg.resolve()   
# Pkg.instantiate()
using Random
Random.seed!(123)
using LinearAlgebra, Plots, Flux

Generating simulated data

The idea is to have a sequence that depends on its first 5 values. The first 5 values are random, but the rest of the sequence depends deterministically on these first 5 values, and the objective is to recreate this second part of the sequence knowing the first 5 values.

nSeeds    = 5
seqLength = 5
nTrains   = 1000  
nVal      = 100
nTot = nTrains+nVal
makeSeeds(nSeeds) = 2 .* (rand(nSeeds) .- 0.5) # [-1,+1]
function makeSequence(seeds,seqLength)
  seq = Vector{Float32}(undef,seqLength+nSeeds) # Flux works with Float32 for performance reasons
  seq[1:nSeeds] = seeds                         # the first nSeeds values are the seeds themselves
  for i in nSeeds+1:(seqLength+nSeeds)
     seq[i] = seq[i-1] + (seeds[4]*0.5) # the only seed that matters is the 4th. Let's see if the RNN learns it!
  end
  return seq
end
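
For example, a quick check of the generation rule (with the default nSeeds = 5; values shown approximately):

# Illustrative only: each new element adds seeds[4]*0.5 to the previous one
makeSequence([0.1, 0.2, 0.3, 0.4, 0.5], 3)
# 8-element Vector{Float32}: 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1.1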

x0   = [makeSeeds(nSeeds) for i in 1:nTot]
seqs = makeSequence.(x0,seqLength)
seqs_vectors = [[[e] for e in seq] for seq in seqs]
y    = [s[2:end] for s in seqs_vectors] # y here is the value of the sequence itself at next step

xtrain = seqs_vectors[1:nTrains]
xval   = seqs_vectors[nTrains+1:end]
ytrain = y[1:nTrains]
yval   = y[nTrains+1:end]

# Flux wants a vector of sequences of individual items, where these in turn are (feature) vectors
allData   = xtrain;
aSequence = allData[1]
anElement = aSequence[1]
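
To make that layout explicit, this is roughly what the nesting looks like (a quick shape check, sizes follow from the constants above):

length(allData)   # 1000 training sequences
length(aSequence) # 10 timesteps each (nSeeds + seqLength)
anElement         # a 1-element Vector{Float32}, e.g. Float32[-0.32]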

Some utility functions

function predictSequence(m,seeds,seqLength)
    seq = Vector{Vector{Float32}}(undef,seqLength+length(seeds)-1)
    Flux.reset!(m) # Reset the state (not the weights!)
    [seq[i] = [convert(Float32, seeds[i])] for i in 1:nSeeds]
    [seq[i] = m(seq[i-1]) for i in nSeeds+1:nSeeds+seqLength-1]
    [s[1] for s in seq]
end
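
For reference, calling it with a throwaway model (m_demo below is purely illustrative, not the model trained later) returns a plain Vector{Float32} of length nSeeds + seqLength - 1: the seeds followed by the model's own continuation.

# Throwaway model just to show the shape of the output
m_demo = Chain(Dense(1,2), RNN(2,2), Dense(2,1))
predictSequence(m_demo, makeSeeds(nSeeds), seqLength)
# 9-element Vector{Float32}: the 5 seeds, then 4 predicted steps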

function myloss(x, y)
    Flux.reset!(m)                 # Reset the state (not the weights!)
    [m(x[i]) for i in 1:nSeeds-1]  # Ignores the output but updates the hidden states
    # y_i is x_(i+1), i.e. next element
    sum(Flux.mse(m(xi), yi) for (xi, yi) in zip(x[nSeeds:(end-1)], y[nSeeds:end]))
end
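
Once m is defined further down, the loss can be sanity-checked on a single raw (unbatched) sequence before worrying about batching:

# Quick sanity check (requires m from the "Defining the model" section below)
myloss(xtrain[1], ytrain[1]) # returns a non-negative scalar
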
"""
   batchSequences(x,batchSize)

Transform a vector of sequences of elements represented as feature vectors into a vector of sequences of elements represented as (features × batched records) matrices
"""
function batchSequences(x,batchSize)
    nRecords  = length(x)
    nItems    = length(x[1])
    nDims     = size(x[1][1],1) 
    nBatches  = Int(floor(nRecords/batchSize))

    emptyBatchedElement = Matrix{Float32}(undef,nDims,batchSize)
    emptySeq = [similar(emptyBatchedElement) for i in 1:nItems]
    outx = [similar(emptySeq) for i in 1:nBatches]
    for b in 1:nBatches
        xmin = (b-1)*batchSize + 1
        xmax = b*batchSize
        for e in 1:nItems
            outx[b][e] = hcat([x[i][e][:,1] for i in xmin:xmax]... )
        end
    end  
    return outx
end
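
For example, batching the training set with batches of 16 should give, for every timestep, a matrix of 1 feature × 16 records (a quick shape check, not part of the training loop itself):

xb = batchSequences(xtrain, 16)
length(xb)     # 62 batches (floor(1000/16))
length(xb[1])  # 10 timesteps per batch
size(xb[1][1]) # (1, 16): 1 feature × 16 records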

Defining the model

m   = Chain(Dense(1,3), LSTM(3,3), Dense(3,5,relu), Dense(5,1)) # 1 input feature -> LSTM core -> 1 output (the next value)
ps  = params(m)   # trainable parameters
opt = Flux.ADAM() # optimiser
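
Because of the LSTM layer the model is stateful: calling it twice on the same input generally gives different outputs until the state is reset, which is why both predictSequence and myloss call Flux.reset! first. A quick illustration (x1 is just a dummy input):

x1 = Float32[0.3]
m(x1)          # first call, starting from the initial state
m(x1)          # same input, but the hidden state has changed, so usually a different output
Flux.reset!(m) # clears the state (not the weights)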

Plotting a random sequence and its prediction from the untrained model...

seq1True = makeSequence(x0[1],seqLength)
seq1Est0 = predictSequence(m,x0[1],seqLength)
plot(seq1True)
plot!(seq1Est0)

Actual training

trainMSE  = Float64[]
valMSE    = Float64[]
epochs    = 20 
batchSize = 16
for e in 1:epochs
    print("Epoch $e ")
    # Shuffling at each epoch
    ids = shuffle(1:length(xtrain))
    x0e      = x0[ids]
    xtraine  = xtrain[ids]
    ytraine  = ytrain[ids]

    xtraine = batchSequences(xtraine,batchSize)
    ytraine = batchSequences(ytraine,batchSize)
    trainxy = zip(xtraine,ytraine)

    # Actual training
    Flux.train!(myloss, ps, trainxy, opt)
    # Making prediction on the trained model and computing accuracies
    global trainMSE, valMSE
    ŷtrain  = [predictSequence(m,x0[i],seqLength) for i in 1:nTrains]
    ŷval    = [predictSequence(m,x0[i],seqLength) for i in (nTrains+1):nTot]
    ytrainfull = [makeSequence(x0[i],seqLength) for i in  1:nTrains]        # renamed: avoid overwriting the training targets ytrain
    yvalfull   = [makeSequence(x0[i],seqLength) for i in  (nTrains+1):nTot]

    trainmse =  sum(norm(ŷtrain[i][nSeeds+1:end] - ytrainfull[i][nSeeds+1:end-1])^2 for i in 1:nTrains)/nTrains
    valmse   =  sum(norm(ŷval[i][nSeeds+1:end] - yvalfull[i][nSeeds+1:end-1])^2 for i in 1:nVal)/nVal
    push!(trainMSE,trainmse)
    push!(valMSE,valmse)
    println("MEan Sq Error: $trainmse - $valmse")
end

Plotting some random sequences

for i = rand(1:nTot,5)
    trueseq = makeSequence(x0[i],seqLength)
    estseq  = predictSequence(m,x0[i],seqLength)
    seqPlot = plot(trueseq[1:end-1],label="true", title = "Seq $i")
    plot!(seqPlot, estseq, label="est")
    display(seqPlot)
end

Plotting the error

Strange, the validation error is always lower than the training error...

plot(trainMSE,label="Train MSE")
plot!(valMSE,label="Validation MSE")

The error changes depending on the parameters, but it always gets stuck in some local minimum, typically predicting the unconditional expected value of the sequence (i.e. a horizontal line):

[Plot: training and validation MSE per epoch]

And some true/predicted sequences look like:

[Plots of a few true vs predicted sequences]

(Note that the estimate doesn't change. Sometimes I can make it change, but the outputs are always very far from the intended sequence.)

I have tried other sequence structures, where the value doesn't depend on a "fixed" position (e.g. seq[i] = seq[i-1] + 0.2*seq[i-2]), but I always get the same result: the loss may decrease by a factor of 10, but in practice the estimate remains "constant", independent of the initial seeds.
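
For reference, a minimal sketch of one such variant (makeSequenceAR is just an illustrative name, not used elsewhere):

function makeSequenceAR(seeds, seqLength)
    seq = Vector{Float32}(undef, seqLength + nSeeds)
    seq[1:nSeeds] = seeds
    for i in nSeeds+1:(seqLength + nSeeds)
        seq[i] = seq[i-1] + 0.2 * seq[i-2] # depends on the two previous values, not on a fixed seed
    end
    return seq
end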
