How to update tf.variable based on the samples within a batch, rather than the loss function


I wrote a simple clustering algorithm with the TensorFlow Estimator API. In the algorithm, the cluster centers scene_cluster are updated from the samples in every batch. The updated centers are only used to decide which cluster each sample belongs to, so the tf.Variable holding the centers is independent of the loss function. However, scene_cluster does not update during training and remains unchanged.

Here is the code:

    with tf.variable_scope("cluster", reuse=tf.AUTO_REUSE):
        # Initialize the cluster centers: [cluster_num, feature_dim]
        scene_cluster = tf.get_variable("scene_cluster", [cluster_num, input.get_shape().as_list()[-1]],
                                        initializer=tf.glorot_uniform_initializer())

        # Compute the similarity between the current batch of samples and the cluster centers
        cluster_similarity = tf.tensordot(input, tf.transpose(scene_cluster), axes=1)  # [batch, cluster_num]
        cluster_similarity = tf.nn.softmax(cluster_similarity, axis=-1)

        # Update the cluster centers as an exponential moving average of the
        # similarity-weighted batch samples (axes=1 contracts the batch dimension)
        scene_cluster_update = tf.tensordot(tf.transpose(cluster_similarity), input, axes=1)  # [cluster_num, feature_dim]
        scene_cluster_op = tf.assign(scene_cluster, beta * scene_cluster + (1 - beta) * scene_cluster_update)

        # Assign each sample to the cluster with the highest similarity
        cluster_candidates = tf.argmax(cluster_similarity, axis=-1)
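
The shapes in this update can be checked with a small NumPy sketch of the same math (the batch size, cluster count, feature dimension, and beta below are illustrative, not from the question):

```python
import numpy as np

batch, cluster_num, dim, beta = 16, 4, 8, 0.9
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, dim))             # a batch of samples
centers = rng.standard_normal((cluster_num, dim)) # current cluster centers

# Soft assignment: softmax over clusters of sample-center similarities
logits = x @ centers.T                            # [batch, cluster_num]
sim = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Similarity-weighted sum of samples per cluster -> same shape as centers
update = sim.T @ x                                # [cluster_num, dim]
new_centers = beta * centers + (1 - beta) * update

print(update.shape, new_centers.shape)            # (4, 8) (4, 8)
hard = logits.argmax(axis=-1)                     # cluster index per sample, shape (16,)
```

Contracting only the batch axis (axes=1 in tf.tensordot terms) is what keeps the update the same shape as the centers.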

So my question is: how do I update a variable that is independent of the loss function in a TensorFlow Estimator?
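
For what it's worth, in TF 1.x graph mode an op like scene_cluster_op only runs if something actually executes it; since it is not part of the loss, the optimizer's train step never touches it. One common pattern is to group the assign op into the train_op. Below is a minimal self-contained sketch of that pattern outside an Estimator (it uses tf.compat.v1 so it runs on TF 2.x; the stand-in loss, the dummy weight w, the optimizer, and all dimensions are illustrative assumptions, not from the question):

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

beta, cluster_num, dim = 0.9, 4, 8
inputs = tf.placeholder(tf.float32, [None, dim])

with tf.variable_scope("cluster", reuse=tf.AUTO_REUSE):
    scene_cluster = tf.get_variable("scene_cluster", [cluster_num, dim],
                                    initializer=tf.glorot_uniform_initializer())
    # Soft assignment of each sample to each cluster
    cluster_similarity = tf.nn.softmax(
        tf.matmul(inputs, scene_cluster, transpose_b=True), axis=-1)
    # EMA update of the centers, as in the question
    scene_cluster_update = tf.matmul(cluster_similarity, inputs, transpose_a=True)
    scene_cluster_op = tf.assign(
        scene_cluster, beta * scene_cluster + (1 - beta) * scene_cluster_update)

# Stand-in loss that does NOT involve scene_cluster (hypothetical weight w)
w = tf.get_variable("w", [dim, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(inputs, w)))
opt_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[w])

# The key line: group the assign op with the optimizer step, so running
# train_op also runs the cluster update every batch.
train_op = tf.group(opt_step, scene_cluster_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(scene_cluster)
    sess.run(train_op,
             feed_dict={inputs: np.random.rand(16, dim).astype(np.float32)})
    after = sess.run(scene_cluster)
    print(np.allclose(before, after))  # False: the centers moved
```

In an Estimator model_fn the same idea would presumably mean returning the grouped op, e.g. tf.estimator.EstimatorSpec(mode, loss=loss, train_op=tf.group(opt_step, scene_cluster_op)), or wrapping the optimizer step in tf.control_dependencies([scene_cluster_op]).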
