I'm writing a data-import pipeline to train a model in TensorFlow. It requires that examples from three different datasets (each consisting of an image and a label) are yielded together but kept separate, in a structure like: ((img_ds1, label_ds1), (img_ds2, label_ds2), (img_ds3, label_ds3)). The data are stored in TFRecords and are currently imported with the following function, which should be pretty straightforward:

import os
from tensorflow.data import TFRecordDataset

def loadDataset(root, ds_type, SEED=1994, shuffle_size=1000, batch_size=32, f_list=[], f_kwargs=[]):
    # extract the list of tfrecord files and import them
    file_list = [os.path.join(root, f) for f in os.listdir(root)]
    ds = TFRecordDataset(file_list)
    print(f'parsing {ds_type} dataset')
    if ds_type == "a":
        ds = ds.map(parseExampleA)
    elif ds_type == "b":
        ds = ds.map(parseExampleB)
    elif ds_type == "c":
        ds = ds.map(parseExampleC)

    # apply the preprocessing functions to the image component only
    for f, kwargs in zip(f_list, f_kwargs):
        ds = ds.map(lambda *x: (f(x[0], **kwargs), x[1]))

    return ds.shuffle(shuffle_size, SEED).batch(batch_size).repeat()
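For context, the element structure I'm after can be sketched with toy stand-in datasets (the shapes and values here are made up, purely to show the nesting):

```python
import tensorflow as tf

# Toy stand-ins for the three parsed datasets; each yields (image, label)
# pairs. The shapes are invented purely for illustration.
ds1 = tf.data.Dataset.from_tensor_slices((tf.zeros([4, 8, 8, 1]), tf.zeros([4])))
ds2 = tf.data.Dataset.from_tensor_slices((tf.ones([4, 8, 8, 1]), tf.ones([4])))
ds3 = tf.data.Dataset.from_tensor_slices((tf.zeros([4, 8, 8, 1]), tf.zeros([4])))

# zip keeps the three pairs separate: each element is
# ((img_ds1, label_ds1), (img_ds2, label_ds2), (img_ds3, label_ds3))
zipped = tf.data.Dataset.zip((ds1, ds2, ds3))

(img1, lbl1), (img2, lbl2), (img3, lbl3) = next(iter(zipped))
print(img1.shape)  # (8, 8, 1)
```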

In the training script, the three datasets are created like this:

from myLibrary import loadDataset

preproc_fun = [fun1, fun2, fun3]
preproc_kwargs = [{"k1":var1, "k2":var2}, {etc.}, {etc.}]
ds_a = loadDataset(root, 'a', f_list=preproc_fun, f_kwargs=preproc_kwargs)
ds_b = loadDataset(root, 'b', f_list=preproc_fun, f_kwargs=preproc_kwargs)
ds_c = loadDataset(root, 'c', f_list=preproc_fun, f_kwargs=preproc_kwargs)

Now, here is what happens: if I call ds.take(n) inside the loadDataset function, it runs and apparently does what I intend. However, whenever I call .take(n) outside the function, I get the error in the title, which frankly I cannot understand, because .take(n) should return another dataset, and anyway I have used it plenty of times like that without any similar problem. The full error output is the following:

Traceback (most recent call last):
  File "/Users/leo/Documents/repos/ChestXRayEnsembling/ChestXRAY/multidecoder/train_script.py", line 156, in <module>
    x=tf.data.Dataset.zip((chex_ds.take(2).repeat(), nih_ds.take(2).repeat(), vin_ds.take(2).repeat())),
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1697, in take
    return TakeDataset(self, count, name=name)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5145, in __init__
    variant_tensor = gen_dataset_ops.take_dataset(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 7711, in take_dataset
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/framework/op_def_library.py", line 777, in _apply_op_helper
    _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/framework/op_def_library.py", line 550, in _ExtractInputsAndAttrs
    values = ops.convert_to_tensor(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/profiler/trace.py", line 183, in wrapped
    return func(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf-metal/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 1586, in convert_to_tensor
    raise RuntimeError(
RuntimeError: input_dataset: Attempting to capture an EagerTensor without building a function.

I'm working in a conda environment with Python 3.10.8 and tensorflow-metal 0.7, on an M1 Mac Pro. The reason I'm taking only a small portion of each dataset is that I want to test that everything works before actually training my model.
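A minimal toy version of the failing line from train_script.py (with made-up range datasets standing in for my real ones) looks like this, and this is the pattern I would expect to work:

```python
import tensorflow as tf

# Made-up (value, label)-style datasets standing in for ds_a / ds_b / ds_c.
ds_a = tf.data.Dataset.range(10).map(lambda x: (x, x * 2))
ds_b = tf.data.Dataset.range(10).map(lambda x: (x, x + 1))
ds_c = tf.data.Dataset.range(10).map(lambda x: (x, -x))

# Take two elements from each dataset and zip them, as in my failing line.
x = tf.data.Dataset.zip(
    (ds_a.take(2).repeat(), ds_b.take(2).repeat(), ds_c.take(2).repeat())
)

# Iterate a couple of zipped elements to check the structure.
for (a, b, c) in x.take(2):
    print(a, b, c)
```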
