How do I correctly define a custom STE gradient in Flux?


I am trying to write a custom straight-through estimator (STE) gradient using Flux. The activation is basically just the sign() function, and its gradient is the incoming gradient passed through unchanged iff the absolute value of the input is <= 1, and zero otherwise. The implementation I currently have does not seem to work correctly:
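The rule described above can be written as two small array functions. This is a dependency-free sketch in plain Julia (no Flux or Tracker involved). The names `binarize_forward` and `binarize_backward` are hypothetical, and mapping an input of exactly 0 to +1 is an assumption chosen to match the `x >= 0` test in the question's code. Julia's own `sign(0.0)` returns `0.0`, so `sign` alone is not a strict two-valued binarizer.

```julia
# Forward pass: binarize to ±1. Assumption: 0 maps to +1, like `x >= 0`,
# because Base's `sign` would return 0 for a zero input.
binarize_forward(x::AbstractArray{<:Real}) = map(v -> v >= 0 ? one(v) : -one(v), x)

# Backward pass (straight-through estimator): pass the incoming gradient Δ
# through unchanged where |x| <= 1 and zero it everywhere else.
binarize_backward(x::AbstractArray{<:Real}, Δ::AbstractArray{<:Real}) =
    Δ .* (abs.(x) .<= 1)

x = [-0.5, 0.0, 0.3, 2.0]
y = binarize_forward(x)              # [-1.0, 1.0, 1.0, 1.0]
g = binarize_backward(x, ones(4))    # [1.0, 1.0, 1.0, 0.0]
```

The backward function only reads `x` to build the mask; the values it returns come from `Δ`.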

binarize(x) = x>=0 ? true : false

binarize(x::Flux.Tracker.TrackedReal) = Flux.Tracker.track(binarize, x)

@grad function binarize(x)
    return binarize.(Flux.Tracker.data(x)), Δ -> (abs(x) <= 1 ? x : 0, )
end
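In Tracker, a `@grad` definition returns the forward value together with a pullback closure that maps the incoming gradient `Δ` to a tuple holding one gradient per argument. Two things in the code above stand out against that contract: the closure returns `x` (or 0) rather than anything derived from `Δ`, and the method is defined for a single `TrackedReal` while `binarize.(a)` broadcasts over a `TrackedArray`. The plain `BitArray` printed for `c` suggests the custom gradient most likely never made it onto the tape. The dependency-free sketch below only illustrates the shape of a whole-array pullback; `binarize_with_pullback` is a hypothetical name, not a Flux function. In a Tracker version, the forward line would presumably act on `data(x)` inside a `@grad` method for `TrackedArray`, and the output would be a numeric array rather than `Bool`s.

```julia
# Hypothetical helper mirroring the contract of Tracker's `@grad`: return the
# forward value plus a pullback mapping Δ to a tuple of input gradients.
function binarize_with_pullback(x::AbstractVector{<:Real})
    xd = float.(x)                            # plain (untracked) copy of the input
    y = map(v -> v >= 0 ? 1.0 : -1.0, xd)     # forward: ±1, with 0 mapped to +1
    # Pullback: the STE passes Δ through where |x| <= 1 and zeroes it elsewhere.
    # The returned gradient is built from Δ; x only decides the mask.
    pullback(Δ) = (Δ .* (abs.(xd) .<= 1),)
    return y, pullback
end

y, back = binarize_with_pullback([0.2, -1.5])
(gx,) = back([3.0, 3.0])    # gx == [3.0, 0.0]
```

Seeding with something other than all ones makes the dependence on `Δ` visible: a closure that returns `x` instead would give the same values whatever seed is passed in.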

So for a random 5-element vector a, I get:

>> a = param(randn(5))
Tracked 5-element Array{Float64,1}:
 -0.3605564089879154
 -0.7853512499733902
  0.8102988051980005
 -0.9715952052917924
 -1.276343849200165
>> c = binarize.(a)
5-element BitArray{1}:
 false
 false
 true
 false
 false
>> Tracker.back!(c, [1,1,1,1,1])
>> a.grad
5-element Array{Float64,1}:
 0.0
 0.0
 0.0
 0.0
 0.0

I would expect the gradient of a to be like a, except for the last element which would be 0.
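Under the rule stated at the top of the question (pass the incoming gradient through where the input's absolute value is <= 1), the seed `[1,1,1,1,1]` would be expected to give all ones except the last entry, rather than a copy of `a`. Values that look like `a` are what a pullback returning `x` instead of `Δ` would produce. The plain-Julia check below uses the printed values of `a` and involves no Flux code.

```julia
# Values of `a` as printed in the question.
a = [-0.3605564089879154, -0.7853512499733902, 0.8102988051980005,
     -0.9715952052917924, -1.276343849200165]
seed = ones(5)    # the question passes [1,1,1,1,1] to back!

# STE gradient: the seed passes through where |a| <= 1 and is zero elsewhere.
ste_grad = seed .* (abs.(a) .<= 1)                 # [1.0, 1.0, 1.0, 1.0, 0.0]

# What a pullback returning x instead of Δ would give: a, last entry zeroed.
x_based = map(v -> abs(v) <= 1 ? v : zero(v), a)
```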

What am I doing wrong?