How to create a custom conditional activation function


I want to create a custom activation function in TF2. The math is as follows:

def sqrt_activation(x):
    if x >= 0:
        return tf.math.sqrt(x)
    else:
        return -tf.math.sqrt(-x)

The problem is that I can't compare x with 0 in a Python if statement, because x is a tensor and the condition has to be applied to each element. How can I achieve this?


Accepted answer, by Vijay Mariappan:

You can skip the comparison altogether, because sign(x) * sqrt(|x|) equals sqrt(x) when x >= 0 and -sqrt(-x) when x < 0:

import tensorflow as tf
def sqrt_activation(x):
    # element-wise: sqrt(x) for x >= 0, -sqrt(-x) for x < 0
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))
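
A minimal usage sketch, assuming TensorFlow 2.x with tf.keras: the sign-based function can be called on a tensor directly or passed as the activation argument of a Keras layer. The layer width and input shape below are arbitrary placeholders chosen for illustration, not part of the original answer.

```python
import tensorflow as tf


def sqrt_activation(x):
    # sign(x) * sqrt(|x|) is sqrt(x) for x >= 0 and -sqrt(-x) for x < 0,
    # computed element-wise, so no Python `if` on a tensor is needed.
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))


# Calling it directly on a tensor.
values = sqrt_activation(tf.constant([-4.0, -1.0, 0.0, 1.0, 9.0]))
print(values.numpy())  # [-2. -1.  0.  1.  3.]

# Keras accepts any tensor-to-tensor callable as an activation.
# The layer width (3) and the input shape (2, 4) are arbitrary for this sketch.
layer = tf.keras.layers.Dense(3, activation=sqrt_activation)
out = layer(tf.ones((2, 4)))
print(out.shape)  # (2, 3)
```

One caveat: the derivative of this function is unbounded at 0, so the gradient TensorFlow computes at exactly x = 0 is not finite; nonzero inputs get the expected 1 / (2 * sqrt(|x|)).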

Answer by ahmet hamza emra:

You need to use TensorFlow functions instead of Python control flow, and convert your code as follows:

import tensorflow as tf
@tf.function
def sqrt_activation(x):
    zeros = tf.zeros_like(x)
    pos = tf.where(x >= 0, tf.math.sqrt(x), zeros)
    neg = tf.where(x < 0, -tf.math.sqrt(-x), zeros)
    return pos + neg

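Note that this function checks the conditions for every element of the tensor: pos holds sqrt(x) where x >= 0 and zeros elsewhere, neg holds -sqrt(-x) where x < 0 and zeros elsewhere, so pos + neg gives the complete result.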