I'm working on a Keras layer based on the `_Merge` layer. The idea is to perform an element-wise comparison between 3 variables and change the value of `x` based on a condition that is a function of `a_11` and `a_22`. So far I've used `tf.tensor_scatter_nd_min`/`tf.tensor_scatter_nd_max` to handle the multiple cases where `x` must be limited, clamping it to the range `[-limit, limit]`.
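In formula form, the layer's code clamps each element of `x` to a symmetric bound built from `a_11` and `a_22`. This only restates the `limit` expression from the code; the symbol $\ell$ is a shorthand introduced here:

$$
\ell = \sqrt{\left(|a_{11}| + \tfrac{1}{3}\right)\left(|a_{22}| + \tfrac{1}{3}\right)}, \qquad
x \leftarrow \min\bigl(\max(x, -\ell),\ \ell\bigr)
$$

Both operations are applied element-wise, so each output element depends only on the matching elements of `a_11`, `x` and `a_22`, not on how many rows the batch has.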
My issue lies with the `batch_size` variable. The way I set it up (very poorly), it fails when the last batch is reached, because that batch is slightly smaller than the others, while the code always uses the initial `batch_size` passed to the constructor.
So what I'm really asking is: how can I access the batch size of the current step? And is there a better way to tackle this problem without having to pass all indices to `tf.tensor_scatter_nd_min`/`tf.tensor_scatter_nd_max`?
P.S. I also tried sparse tensors, but had no luck there either.
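For the first part of the question, here is a minimal sketch (assuming TensorFlow 2.x; `current_batch_size` and `traced_static_batch` are hypothetical names) of the difference between the static and the dynamic batch dimension. When a function is traced with an unknown batch dimension, `x.shape[0]` is `None` at trace time, while `tf.shape(x)[0]` is a tensor that holds the batch size of each actual call, so it also covers a smaller final batch.

```python
import tensorflow as tf

# Filled only while tracing (Python side effect), to show the static shape
traced_static_batch = []

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def current_batch_size(x):
    # Static shape: the batch dimension is unknown (None) at trace time
    traced_static_batch.append(x.shape[0])
    # Dynamic shape: evaluated on every call, so it reflects the real batch
    return tf.shape(x)[0]
```

With the single input signature above, the function is traced once, and the same graph then reports 32 for a batch of 32 rows and 7 for a batch of 7 rows.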
```python
import tensorflow as tf
from keras.layers.merging.base_merge import _Merge

class _realize_12(_Merge):
    def __init__(self, BATCH):
        super(_realize_12, self).__init__()
        self.batch_size = BATCH  # fixed at construction; breaks on the smaller last batch

    def _merge_function(self, inputs):
        a_11 = inputs[0]
        x = inputs[1]
        a_22 = inputs[2]
        # Element-wise bound on |x|
        limit = tf.math.sqrt((tf.math.abs(a_11) + 1/3) * (tf.math.abs(a_22) + 1/3))
        # One index per row: [[0], [1], ..., [BATCH - 1]]
        _all = [[i] for i in range(self.batch_size)]
        # Raise x to at least -limit, then cap it at +limit
        return tf.tensor_scatter_nd_min(tf.tensor_scatter_nd_max(x, _all, -limit),
                                        _all, limit, name='constrain_a_12')
```