Accessing the Current Batch Size (per batch, per epoch)


I'm working on a Keras layer based on the _Merge layer. The idea is to perform an element-wise comparison between 3 variables and change the value of x based on a condition that is a function of a_11 and a_22. So far I've been using tf.tensor_scatter_nd_min/max to cover the various scenarios where the value of x must be limited and to force it to that limit.

My issue lies in the batch_size variable: the way I set it up (very poorly) fails when the last batch is reached, since it is slightly smaller than the others, but the way the code is structured the layer only ever receives the initial batch_size...

So what I am really asking is: how can I access the batch size of the current step, and is there a better way to tackle this problem without having to pass all indices to tf.tensor_scatter_nd_min/max?

P.S. I did try with sparse tensors but had no luck either.

import tensorflow as tf
from keras.layers.merging.base_merge import _Merge

class _realize_12(_Merge):    
    def __init__(self, BATCH):
        super(_realize_12, self).__init__()
        self.batch_size = BATCH  # fixed at construction time -> wrong for the last, smaller batch
    
    def _merge_function(self, inputs):        
        a_11 = inputs[0]
        x = inputs[1]
        a_22 = inputs[2]
        
        # element-wise limit derived from a_11 and a_22
        limit = tf.math.sqrt((tf.math.abs(a_11)+1/3)*(tf.math.abs(a_22)+1/3)) 
        
        # one scatter index per sample in the batch; relies on the fixed self.batch_size
        _all =  [[i] for i in range(self.batch_size)]
                              
        # clamp x into [-limit, limit]
        return tf.tensor_scatter_nd_min(tf.tensor_scatter_nd_max(x, _all, -limit), _all, limit, name='constrain_a_12')
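
Update: a minimal sketch of the two directions I can think of, not yet verified inside the layer. It assumes the runtime batch size can be read from tf.shape(x)[0] and that limit has the same shape as x (so tf.clip_by_value can broadcast against it); clamp_to_limit is just a hypothetical standalone helper for illustration.

import tensorflow as tf

def clamp_to_limit(a_11, x, a_22):
    # same limit as in _merge_function
    limit = tf.math.sqrt((tf.math.abs(a_11) + 1/3) * (tf.math.abs(a_22) + 1/3))

    # Idea 1: build the scatter indices from the runtime (dynamic) batch size,
    # so the last, smaller batch would be handled as well.
    batch_size = tf.shape(x)[0]                          # batch size of the current step
    indices = tf.expand_dims(tf.range(batch_size), -1)   # shape (batch, 1)
    clamped = tf.tensor_scatter_nd_min(
        tf.tensor_scatter_nd_max(x, indices, -limit), indices, limit)

    # Idea 2: drop the scatter entirely and clamp element-wise.
    clamped_alt = tf.clip_by_value(x, -limit, limit)
    return clamped, clamped_alt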