I am trying to make a model that takes two different GraphTensors as input. Unfortunately, all the documentation is either about single GraphTensors or datasets containing one type of GraphTensor. My datasets are prepared as follows:
# Build the training TFRecords: one file per graph type. Both graphs of each
# pair are written in the SAME loop, so record i of the api file and record i
# of the solvent file always come from the same GenTwoPairedGraphs(...) call.
# A reader that pairs them by position (e.g. tf.data.Dataset.zip) relies on
# the two files staying in lockstep, which a separate loop would not
# guarantee.
# NOTE(review): assumes sol_train_path is defined alongside api_train_path.
# NOTE(review): to keep everything in ONE file instead, both graphs could be
# stored in a single tf.train.Example using the `prefix` argument of
# tfgnn.write_example / tfgnn.parse_single_example. Confirm that argument
# exists in the installed TF-GNN version before relying on it.
with tf.io.TFRecordWriter(api_train_path) as api_writer:
    with tf.io.TFRecordWriter(sol_train_path) as sol_writer:
        for index in range(train_len):
            api_graph, sol_graph = GenTwoPairedGraphs(
                chemicals_df, dataset_df, atoms_df, bonds_df, index)
            api_writer.write(
                tfgnn.write_example(api_graph).SerializeToString())
            sol_writer.write(
                tfgnn.write_example(sol_graph).SerializeToString())
# Decode functions
def decode_fn_api(record_bytes):
    """Parse one serialized api record into a ``(graph, label)`` pair.

    The record is parsed with ``api_graph_tensor_spec``. Its ``'label'``
    context feature is then split off, so the returned graph carries only
    model inputs and the label becomes the training target. As a result,
    the returned graph's spec is ``api_graph_tensor_spec`` minus that
    ``'label'`` feature.
    """
    parsed = tfgnn.parse_single_example(
        api_graph_tensor_spec, record_bytes, validate=True)
    remaining_context = parsed.context.get_features_dict()
    target = remaining_context.pop('label')
    inputs_only = parsed.replace_features(context=remaining_context)
    return inputs_only, target
def decode_fn_sol(record_bytes):
    """Parse one serialized solvent record into a GraphTensor.

    Uses ``solvent_graph_tensor_spec`` unchanged. No label is extracted here;
    the label travels with the api graph.
    """
    return tfgnn.parse_single_example(
        solvent_graph_tensor_spec, record_bytes, validate=True)
def ImportGraphDataset(tfrecord_dir="TFRecords"):
    """Open the api and solvent train/val TFRecord files as decoded datasets.

    Args:
        tfrecord_dir: Directory that holds the four ``*.tfrecords`` files.
            Defaults to ``"TFRecords"``, the previously hard-coded location.

    Returns:
        ``(api_train_ds, api_val_ds, sol_train_ds, sol_val_ds)``.
        - The api datasets are mapped through ``decode_fn_api`` and yield
          ``(graph, label)`` pairs.
        - The solvent datasets are mapped through ``decode_fn_sol`` and yield
          bare graphs.
        Api and solvent datasets of the same split are presumably consumed
        together record-by-record, so their files must be written in the
        same order.
    """
    def _load(file_name, decode_fn):
        # Read a single TFRecord file and decode each record with decode_fn.
        path = f"{tfrecord_dir}/{file_name}"
        return tf.data.TFRecordDataset([path]).map(decode_fn)

    api_train_ds = _load("api_train_dataset.tfrecords", decode_fn_api)
    api_val_ds = _load("api_val_dataset.tfrecords", decode_fn_api)
    sol_train_ds = _load("sol_train_dataset.tfrecords", decode_fn_sol)
    sol_val_ds = _load("sol_val_dataset.tfrecords", decode_fn_sol)
    return api_train_ds, api_val_ds, sol_train_ds, sol_val_ds
api_train_ds, api_val_ds, sol_train_ds, sol_val_ds = ImportGraphDataset()


def _pair_inputs(api_example, sol_graph):
    """Rearrange one zipped element into the ``(inputs, label)`` form Keras expects.

    Args:
        api_example: The ``(graph, label)`` tuple produced by ``decode_fn_api``.
        sol_graph: The graph produced by ``decode_fn_sol``.

    Returns:
        ``({"api_graph": graph, "solvent_graph": sol_graph}, label)``. The dict
        keys equal the ``name=`` given to the two ``tf.keras.layers.Input``
        layers, which is how Keras routes each graph to its model input.
    """
    api_graph, label = api_example
    return {"api_graph": api_graph, "solvent_graph": sol_graph}, label


# model.fit accepts ONE dataset. Zip the api and solvent datasets, then move
# the label out of the api element so each element is (inputs_dict, label).
# Zip pairs records by position only and stops at the shorter dataset, so
# both TFRecord files of a split must hold the same samples in the same
# order.
# The model calls merge_batch_to_components(), which works on a batch of
# graphs, so the paired datasets are batched here.
# NOTE(review): batch_size is a placeholder; tune it.
# Usage: model.fit(train_ds, validation_data=val_ds, ...)
batch_size = 32
train_ds = (
    tf.data.Dataset.zip((api_train_ds, sol_train_ds))
    .map(_pair_inputs)
    .batch(batch_size)
)
val_ds = (
    tf.data.Dataset.zip((api_val_ds, sol_val_ds))
    .map(_pair_inputs)
    .batch(batch_size)
)
My model is built as follows:
# One part of the model with api_graph as input.
# The Input spec is taken from the decoded dataset, not from
# api_graph_tensor_spec. decode_fn_api strips the 'label' context feature,
# so the parsing spec describes a graph the model never actually receives.
# element_spec[0] is the graph part of the (graph, label) element.
# NOTE(review): this is the per-example (unbatched) spec, as in the TF-GNN
# examples that call merge_batch_to_components() on batched data; confirm
# against your TF-GNN/Keras versions.
api_input_graph = tf.keras.layers.Input(
    type_spec=api_train_ds.element_spec[0], name="api_graph")
api_graph = api_input_graph.merge_batch_to_components()
(...)
api_readout_features = tfgnn.keras.layers.Pool(
    tfgnn.CONTEXT, "mean", node_set_name="atoms")(api_graph)

# Second part of the model with sol_graph as input. The solvent dataset
# yields bare graphs, so its element_spec is the graph spec itself.
solvent_input_graph = tf.keras.layers.Input(
    type_spec=sol_train_ds.element_spec, name="solvent_graph")
solvent_graph = solvent_input_graph.merge_batch_to_components()
(...)
solvent_readout_features = tfgnn.keras.layers.Pool(
    tfgnn.CONTEXT, "mean", node_set_name="atoms")(solvent_graph)

# Final layers that calculate the output. Use a Keras layer rather than a
# raw tf.concat on symbolic Keras tensors, which not every Keras version
# accepts.
feat = tf.keras.layers.Concatenate(axis=1)(
    [api_readout_features, solvent_readout_features])
final_dense = tf.keras.layers.Dense(32, activation="relu")(feat)
logits = tf.keras.layers.Dense(1, name="label")(final_dense)

# Keep a reference to the model; the original discarded it, leaving nothing
# to compile or fit. The Input names "api_graph" / "solvent_graph" are the
# keys a dict-structured dataset must use to feed these two inputs.
model = tf.keras.Model(
    inputs=[api_input_graph, solvent_input_graph], outputs=[logits])
The problem is that model.fit does not accept more than one dataset as input, and I don't know how to store and decode a TFRecord that contains more than one type of GraphTensor.
So far I have tried using tf.data.Dataset.zip (input_ds = tf.data.Dataset.zip({"api_graph": api_train_ds, "solvent_graph": sol_train_ds}, label)) with each set stored separately, but I have not been able to run the model this way.