TensorFlow Keras custom metric error on update_state()

Context

I'm using TF 1.15 with the tf.estimator interface to train and evaluate models, and I'm trying to write a custom metric by subclassing tf.keras.metrics.Metric.

Problem

I wrote a custom metric and included it in eval_metric_ops (example below). When I define an estimator with these metrics, I get the following error:

ValueError: Please call update_state(...) on the "<metric_name>" metric 

The error message seems clear (I have to call update_state()), but I am not sure where I should call update_state() on the metric, or whether I should call it at all. This is not a minimal example, but here is a demo metric I wrote:

import tensorflow as tf

class MyMetric(tf.keras.metrics.Metric):
    def __init__(self, name="my_metric", **kwargs):
        super(MyMetric, self).__init__(name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.true_samples = tf.reduce_sum(y_true)
    
    def result(self):
        return self.true_samples
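A Keras Metric is normally expected to keep its running state in variables created with add_weight() and to update those variables in update_state(). A plain Python attribute such as self.true_samples is overwritten on every call and gives the graph no update op to run. In TF 1.x graph mode, the op returned by update_state() is what the metric records as its update, so the sketch below returns the assign op. This is a minimal sketch under those assumptions; the class name TrueSampleCount and the variable name are hypothetical.

```python
import tensorflow as tf


class TrueSampleCount(tf.keras.metrics.Metric):
    """Hypothetical example metric: sums the (optionally weighted) labels seen so far."""

    def __init__(self, name="true_sample_count", **kwargs):
        super(TrueSampleCount, self).__init__(name=name, **kwargs)
        # State lives in a metric variable, so it accumulates across batches
        # and can be reset between evaluations.
        self.true_samples = self.add_weight(
            name="true_samples", shape=(), initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        values = tf.cast(y_true, tf.float32)
        if sample_weight is not None:
            values = values * tf.cast(sample_weight, tf.float32)
        # Return the assign op: in graph mode this is the update op the
        # metric records for this call.
        return self.true_samples.assign_add(tf.reduce_sum(values))

    def result(self):
        # Read the accumulated variable instead of a per-batch tensor.
        return tf.identity(self.true_samples)
```

Unlike MyMetric above, repeated update_state() calls add to the stored total instead of replacing it, and result() reads that variable.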

I create a dict with the metric name as the key and the metric instance as the value, following the documentation on how to build eval_metric_ops:

metrics_ops = {"my_metric": MyMetric()}  # The TensorFlow 1.15 documentation does not say we have to call update_state(...) anywhere.
estimator_spec = tf.estimator.EstimatorSpec(mode, model.loss, eval_metric_ops=metrics_ops)

Any idea how I can get rid of that error?
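In TF 1.15, the EstimatorSpec check that raises this message appears to look at the metric's recorded updates, which only exist once update_state() has been called on the metric inside model_fn. A sketch of that pattern follows, assuming a binary classifier with a hypothetical float feature "x" and float labels shaped like the logits. It uses the built-in tf.keras.metrics.Sum so the block stands alone; a custom Metric subclass would be used the same way. The metric is also created inside model_fn, because the Estimator builds a new graph when it calls model_fn and the metric's variables need to live in that graph.

```python
import tensorflow as tf


def model_fn(features, labels, mode):
    """Hypothetical model_fn for a binary classifier on a feature "x"."""
    logits = tf.compat.v1.layers.dense(features["x"], units=1)
    probabilities = tf.sigmoid(logits)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode, predictions={"probabilities": probabilities})

    # Assumes labels are floats with the same shape as logits, e.g. (batch, 1).
    loss = tf.compat.v1.losses.sigmoid_cross_entropy(labels, logits)

    if mode == tf.estimator.ModeKeys.EVAL:
        # Create the metric in the graph the Estimator is currently building...
        true_count = tf.keras.metrics.Sum(name="true_count")
        # ...and call update_state() so the metric records an update op.
        true_count.update_state(labels)
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops={"true_count": true_count})

    optimizer = tf.compat.v1.train.AdamOptimizer()
    train_op = optimizer.minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```

eval_metric_ops also accepts the older (metric_tensor, update_op) tuple form, such as the pair returned by tf.compat.v1.metrics.mean, if wrapping a Keras metric is not a requirement.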