How do I properly convert tf.compat.v1.nn.ctc_loss to tf.nn.ctc_loss? (TF2)


I am trying to convert parts of my code from tf.compat.v1 to native TF2 functions. Here is the original loss definition, which works and gives me good performance:

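# v1 API: the blank is always the last class (num_classes - 1)
self.loss = tf.reduce_mean(
    input_tensor=tf.compat.v1.nn.ctc_loss(
        labels=self.gt_texts,           # SparseTensor labels
        inputs=self.ctc_in_3d_tbc,      # time-major logits [max_time, batch, num_classes]
        sequence_length=self.seq_len,
        ctc_merge_repeated=True,
    )
)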

This is my attempt to rewrite it using tf.nn.ctc_loss:

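self.loss = tf.reduce_mean(
    input_tensor=tf.nn.ctc_loss(
        labels=self.gt_texts,           # sparse tensor
        logits=self.ctc_in_3d_tbc,      # time-major by default (logits_time_major=True)
        label_length=None,              # allowed when labels is a SparseTensor
        logit_length=self.seq_len,
        blank_index=-1,                 # -1 means num_classes - 1, the v1 blank
    )
)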
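A quick way to check whether the two calls compute the same thing is to evaluate both on identical toy inputs. The sketch below uses made-up shapes and random logits (max_time, batch, num_classes and the label values are illustrative, not taken from the original code). The sparse labels are built by hand because tf.sparse.from_dense would silently drop every label equal to 0, and both calls are pinned to the CPU in case tf.nn.ctc_loss picks a different implementation on GPU. Note that tf.nn.ctc_loss has no ctc_merge_repeated argument; as far as I can tell it behaves like the v1 default of True.

```python
import numpy as np
import tensorflow as tf

# Illustrative toy shapes: time-major logits [max_time, batch, num_classes].
max_time, batch, num_classes = 10, 2, 5
rng = np.random.default_rng(0)
logits = tf.constant(rng.normal(size=(max_time, batch, num_classes)), dtype=tf.float32)
seq_len = tf.constant([10, 8], dtype=tf.int32)

# Real classes are 0..3; class 4 (num_classes - 1) is the blank in the v1 convention.
# Built by hand because tf.sparse.from_dense would drop every label equal to 0.
labels = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
    values=tf.constant([1, 0, 3, 2, 2], dtype=tf.int32),
    dense_shape=[batch, 3],
)

with tf.device("/CPU:0"):
    # Original v1 call: blank is implicitly the last class.
    loss_v1 = tf.compat.v1.nn.ctc_loss(
        labels=labels, inputs=logits, sequence_length=seq_len, ctc_merge_repeated=True
    )
    # TF2 call with the blank explicitly set to the last class.
    loss_v2 = tf.nn.ctc_loss(
        labels=labels, logits=logits, label_length=None,
        logit_length=seq_len, blank_index=-1,
    )

print(loss_v1.numpy(), loss_v2.numpy())
```

If the two per-example losses agree in your TensorFlow version, the accuracy drop is more likely caused by something outside this line than by the loss arguments; repeating the same check on the device used for training is a cheap next step.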

After changing only this, performance degraded considerably (accuracy dropped from 63% to around 43%). What am I doing wrong in the conversion? Using blank_index=0 does not work either.
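Regarding blank_index=0: tf.compat.v1.nn.ctc_loss always treats the last class (num_classes - 1) as the blank, so a model trained with it emits the blank in the last logit column. blank_index=0 tells tf.nn.ctc_loss that column 0 is the blank instead, which does not match that layout unless the inputs are rearranged. The sketch below (same made-up toy data as a standalone example) shows what would have to change for blank_index=0 to reproduce the v1 loss: the blank column moved to the front with tf.roll and every label shifted up by one. Any decoder used at inference time would then also need to treat index 0 as the blank; that part is not shown.

```python
import numpy as np
import tensorflow as tf

# Illustrative toy data; the LAST class (index 4) is the blank, as in the v1 API.
max_time, batch, num_classes = 10, 2, 5
rng = np.random.default_rng(1)
logits = tf.constant(rng.normal(size=(max_time, batch, num_classes)), dtype=tf.float32)
seq_len = tf.constant([10, 8], dtype=tf.int32)
labels = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
    values=tf.constant([1, 0, 3, 2, 2], dtype=tf.int32),
    dense_shape=[batch, 3],
)

with tf.device("/CPU:0"):
    # Reference: v1 loss with the blank in the last column.
    reference = tf.compat.v1.nn.ctc_loss(labels=labels, inputs=logits, sequence_length=seq_len)

    # Move the blank column from the last position to position 0 (class axis is 2)...
    logits_blank_first = tf.roll(logits, shift=1, axis=2)
    # ...and shift every real label up by one so class k becomes k + 1.
    labels_shifted = tf.SparseTensor(labels.indices, labels.values + 1, labels.dense_shape)

    loss_blank0 = tf.nn.ctc_loss(
        labels=labels_shifted, logits=logits_blank_first, label_length=None,
        logit_length=seq_len, blank_index=0,
    )

print(reference.numpy(), loss_blank0.numpy())
```

In other words, blank_index=0 only matches the v1 loss when the model's output layout and the labels both follow the "blank first" convention; with the original layout, blank_index=-1 is the matching setting.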
