Context
I'm using TF 1.15 with the `tf.estimator.Estimator` interface to train and evaluate models. I'm trying to write a custom metric by subclassing `tf.keras.metrics.Metric`.
Problem
I wrote a custom metric and included it in `eval_metric_ops` (example below). If I define an estimator with this metric, I get the following error:

```
ValueError: Please call update_state(...) on the "<metric_name>" metric
```
The wording of the error seems clear (I have to call `update_state()`), but I'm not sure where I should call `update_state()` on the metric, or whether I should call it at all. This is not a minimal example, but here is a demo metric I wrote:
```python
import tensorflow as tf


class MyMetric(tf.keras.metrics.Metric):
    def __init__(self, name="my_metric", **kwargs):
        super(MyMetric, self).__init__(name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Stores the sum of y_true for the current batch only.
        self.true_samples = tf.reduce_sum(y_true)

    def result(self):
        return self.true_samples
```
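A second likely problem with the metric above: it stores a plain tensor on `self` instead of a metric variable, so there is no state for an update op to modify, and each batch overwrites the previous value. A minimal sketch of a stateful version, assuming the intent is to accumulate the sum of `y_true` over all evaluation batches, creates the state with `add_weight` and returns the `assign_add` op from `update_state` (returning the op is what lets Keras record it as an update when running in graph mode):

```python
import tensorflow as tf


class MyMetric(tf.keras.metrics.Metric):
    """Accumulates the sum of y_true over every batch seen (sketch)."""

    def __init__(self, name="my_metric", **kwargs):
        super(MyMetric, self).__init__(name=name, **kwargs)
        # Metric state lives in a variable created with add_weight, so it
        # persists across batches and can be reset between evaluations.
        self.true_samples = self.add_weight(name="true_samples", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Accumulate instead of overwrite; y_pred and sample_weight are unused here.
        batch_sum = tf.reduce_sum(tf.cast(y_true, tf.float32))
        # Return the assign op so it can be recorded as the metric's update.
        return self.true_samples.assign_add(batch_sum)

    def result(self):
        # Running total across all batches passed to update_state.
        return tf.identity(self.true_samples)
```

With `assign_add`, `result()` reports the total over the whole evaluation run rather than only the last batch.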
I then create a dict where the metric name is the key and the metric instance is the value, which is how the documentation describes building the dict for `eval_metric_ops`:
```python
metrics_ops = {"my_metric": MyMetric()}
estimator_spec = tf.estimator.EstimatorSpec(mode, loss=model.loss, eval_metric_ops=metrics_ops)
```

The TensorFlow 1.15 documentation does not say we have to call `update_state(...)` anywhere.
Any idea how I can get rid of that error?
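A minimal sketch of where the `update_state(...)` call could go, assuming the usual `model_fn(features, labels, mode)` signature: create the metric inside `model_fn`, call `update_state(labels, predictions)` on it, and only then put it into `eval_metric_ops`. The single dense layer, `features["x"]`, the learning rate, and the helper `build_eval_metric_ops` are hypothetical placeholders, not part of the original setup. Note also that `loss` is passed by keyword: the second positional parameter of `EstimatorSpec` is `predictions`, not `loss`.

```python
import tensorflow as tf


class MyMetric(tf.keras.metrics.Metric):
    """Stateful version of the demo metric: running sum of y_true."""

    def __init__(self, name="my_metric", **kwargs):
        super(MyMetric, self).__init__(name=name, **kwargs)
        self.true_samples = self.add_weight(name="true_samples", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        return self.true_samples.assign_add(tf.reduce_sum(tf.cast(y_true, tf.float32)))

    def result(self):
        return tf.identity(self.true_samples)


def build_eval_metric_ops(labels, predictions):
    """Hypothetical helper: create each metric AND call update_state on it."""
    my_metric = MyMetric()
    # This is the call the error asks for: it records the metric's update op
    # before the metric object is handed to EstimatorSpec.
    my_metric.update_state(labels, predictions)
    return {"my_metric": my_metric}


def model_fn(features, labels, mode):
    """Hypothetical TF 1.15 model_fn: one dense layer on features["x"]."""
    predictions = tf.compat.v1.layers.dense(features["x"], units=1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)
    if mode == tf.estimator.ModeKeys.EVAL:
        # Metrics are built (and updated) inside model_fn, in the estimator's graph.
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=build_eval_metric_ops(labels, predictions))

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```

Constructing the metric inside `model_fn`, rather than once at module level, is deliberate: the estimator builds a fresh graph for each `evaluate()` call, and the metric's variables and update op need to live in that graph.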