I wrote a custom metric in tensorflow to implement Balanced Accuracy, but it gives nan after the first training epoch. Here is my code:
class BalancedSparseCategoricalAccuracy(keras.metrics.Metric):
def __init__(self, name='balanced_sparse_categorical_accuracy', **kwargs):
super(BalancedSparseCategoricalAccuracy, self).__init__(**kwargs)
self.correct_count = tf.Variable([0,0,0,0,0,0,0])
self.total_count = tf.Variable([0,0,0,0,0,0,0])
self.accuracy = 0.0
self.quotient = tf.Variable([0.,0.,0.,0.,0.,0.,0.])
def reset_state(self):
self.correct_count = tf.Variable([0,0,0,0,0,0,0])
self.total_count = tf.Variable([0,0,0,0,0,0,0])
self.quotient = tf.Variable([0.,0.,0.,0.,0.,0.,0.])
self.accuracy = 0.0
def update_state(self, y_true, y_pred, sample_weight=None):
pred_idx = tf.argmax(y_pred, axis=1, output_type=tf.int32)
y_true=tf.reshape(y_true, tf.shape(pred_idx))
y_true=tf.cast(y_true,tf.int32)
def count_all(y_true):
return tf.cast([ tf.reduce_sum(tf.cast(tf.equal(y_true,i),tf.float32)) for i in range(0,7)],tf.int32)
def count_match(y_true, y_pred):
return tf.cast([ tf.reduce_sum(tf.cast(tf.equal(y_true,i)&tf.equal(y_true,pred_idx),tf.float32)) for i in range(0,7) ],tf.int32)
total_counts = count_all(y_true)
tf.compat.v1.assign_add(self.total_count, total_counts)
correct_counts = count_match(y_true, y_pred)
tf.compat.v1.assign_add(self.correct_count, correct_counts)
self.quotient = tf.cast(self.correct_count/tf.maximum(self.total_count,tf.constant(1)), tf.float32)
print(self.quotient)
#self.accuracy = tf.reduce_mean(self.correct_count/tf.maximum(self.total_count,1))
#return self.accuracy
def result(self):
#self.accuracy = tf.reduce_mean(self.correct_count/tf.maximum(self.total_count,1))
self.accuracy = tf.reduce_sum(tf.reduce_sum(self.quotient) / tf.cast(tf.math.count_nonzero(self.total_count),tf.float32))
return self.accuracy
Epoch 1/40
31/31 [==============================] - 21s 354ms/step - loss: 13.7365 - balanced_sparse_categorical_accuracy_1: 0.4473 - accuracy: 0.5518 - val_loss: 17.3646 - val_balanced_sparse_categorical_accuracy_1: 0.3649 - val_accuracy: 0.4486
Epoch 2/40
31/31 [==============================] - 16s 333ms/step - loss: 13.7993 - balanced_sparse_categorical_accuracy_1: nan - accuracy: 0.5590 - val_loss: 16.8144 - val_balanced_sparse_categorical_accuracy_1: nan - val_accuracy: 0.4486
Epoch 3/40
31/31 [==============================] - 18s 361ms/step - loss: 13.2094 - balanced_sparse_categorical_accuracy_1: nan - accuracy: 0.5745 - val_loss: 16.6154 - val_balanced_sparse_categorical_accuracy_1: nan - val_accuracy: 0.4579
I tried to check if there was a division by zero, but that was not the case. I also checked on model predictions, and it gave the correct result:
predictions=model.predict(datav)
grd_truth=[]
for batch in datav:
val=np.asarray(batch[1])
grd_truth.extend(val)
acc = BalancedSparseCategoricalAccuracy()
acc.update_state(grd_truth, predictions)
acc.result()
Output:
<tf.Tensor: shape=(), dtype=float32, numpy=0.23611432>