I have been trying to find documentation or examples showing how to save the predictions generated by the TFX BulkInferrer component to 1) BigQuery, 2) a dataframe, or 3) at least read them back. I couldn't find any, and now I am stuck because I cannot use the results. I had assumed that saving the results of this component would be a straightforward task, but there seem to be no examples. The component generates a "prediction_logs-00000-of-00001.gz" file, but I was not able to read from it. Could someone help? I am using a TFX 1.12.0 pipeline with the Kubeflow orchestrator on GCP.
[Update]
I found and tried this utility: https://github.com/tensorflow/tfx/blob/master/tfx/components/bulk_inferrer/prediction_to_example_utils.py
from tfx.components import BulkInferrer
from tfx.proto import bulk_inferrer_pb2
from prediction_to_example_utils import convert  # local copy of the TFX utility

bulk_inferrer = BulkInferrer(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    data_spec=bulk_inferrer_pb2.DataSpec(),
    model_spec=bulk_inferrer_pb2.ModelSpec(),
)
context.run(bulk_inferrer, enable_cache=False)

predictions = bulk_inferrer.outputs['inference_result'].get()[0]
convert(predictions, output_example_spec=bulk_inferrer_pb2.OutputExampleSpec())
but I get this error:
----> 5 convert(predictions, output_example_spec = bulk_inferrer_pb2.OutputExampleSpec())
      6
      7

~/TFX/NEW_TFX_2/prediction_to_example_utils.py in convert(prediction_log, output_example_spec)
     47   """
     48   specs = output_example_spec.output_columns_spec
---> 49   if prediction_log.HasField('multi_inference_log'):
     50     example, output_features = _parse_multi_inference_log(
     51         prediction_log.multi_inference_log, output_example_spec)

~/tfx_env/lib/python3.7/site-packages/tfx/types/artifact.py in __getattr__(self, name)
    252         raise AttributeError()
    253       if name not in self._artifact_type.properties:
--> 254         raise AttributeError('Artifact has no property %r.' % name)
    255       property_mlmd_type = self._artifact_type.properties[name]
    256       if property_mlmd_type == metadata_store_pb2.STRING:

AttributeError: Artifact has no property 'HasField'.
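From the traceback my guess is that `.get()[0]` returns an MLMD Artifact object, while convert() expects an already parsed prediction_log_pb2.PredictionLog message. A minimal sketch of the call pattern I think the utility wants (untested; serialized_record is a placeholder for one raw record read from the .gz file):

from tensorflow_serving.apis import prediction_log_pb2

# Parse one raw record into a PredictionLog proto before calling convert()
prediction_log = prediction_log_pb2.PredictionLog()
prediction_log.ParseFromString(serialized_record)  # serialized_record: raw TFRecord payload

example = convert(
    prediction_log,
    output_example_spec=bulk_inferrer_pb2.OutputExampleSpec(),
)

This is what led me to try reading the raw records directly, see Update #2 below.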
[Update #2] I tried to read the output with tf.data.TFRecordDataset, but the parsed records come out empty (or I did not read them the right way). Here is the code:
from tfx.extensions.google_cloud_big_query.example_gen.component import BigQueryExampleGen
from tfx.proto import example_gen_pb2

example_gen_infer = BigQueryExampleGen(
    query=QUERY,
    output_config=example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[
                example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=1)
            ]
        )
    )
)

context.run(
    example_gen_infer,
    beam_pipeline_args=[
        '--project', GOOGLE_CLOUD_PROJECT,
        '--temp_location', 'gs://' + GCS_BUCKET_NAME + '/tmp',
        '--region', GOOGLE_CLOUD_REGION,
    ],
)
bulk_inferrer = BulkInferrer(
    examples=example_gen_infer.outputs['examples'],
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    data_spec=bulk_inferrer_pb2.DataSpec(),
    model_spec=bulk_inferrer_pb2.ModelSpec(),
)
context.run(bulk_inferrer, enable_cache=False)
bulk_artifact_dir = bulk_inferrer.outputs['inference_result'].get()[0].uri
pp.pprint(os.listdir(bulk_artifact_dir))

prediction_uri = os.path.join(bulk_artifact_dir, 'prediction_logs-00000-of-00001.gz')
pp.pprint(prediction_uri)

tfrecord_filenames = prediction_uri

# Create a TFRecordDataset to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    pp.pprint(example)
==> the printed examples come out empty, although the saved file itself does contain binary content.
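Given the AttributeError above and the empty printout here, my current guess is that the records in prediction_logs-00000-of-00001.gz are not tf.train.Example protos but serialized prediction_log_pb2.PredictionLog messages (the proto the convert() utility takes), so parsing them as Example yields nothing. Here is a sketch of what I plan to try next, assuming the model has a regular 'predict' signature and reusing prediction_uri from above (untested):

import pandas as pd
import tensorflow as tf
from tensorflow_serving.apis import prediction_log_pb2

dataset = tf.data.TFRecordDataset(prediction_uri, compression_type="GZIP")

rows = []
for raw_record in dataset:
    log = prediction_log_pb2.PredictionLog()
    log.ParseFromString(raw_record.numpy())
    # For a 'predict' signature, the outputs should live in
    # log.predict_log.response.outputs, a map of output name -> TensorProto.
    outputs = log.predict_log.response.outputs
    rows.append({name: tf.make_ndarray(proto).flatten().tolist()
                 for name, proto in outputs.items()})

df = pd.DataFrame(rows)
print(df.head())

If this works, I assume I could then push the dataframe to BigQuery with pandas-gbq (df.to_gbq(...)), which would cover points 1) and 2) from the original question, but I have not verified any of this yet.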