Sampling from a custom dataset that is derived from tf.data.dataset in tensorflow

37 Views Asked by At

I'm new to tensorflow and trying to write a custom dataset class derived from tf.data.dataset like this example code:

import tensorflow as tf

class CustomDataset(tf.data.Dataset):

    def __init__(self, num_of_images: int):
        self.num_of_images = num_of_images
    
    def generator(self):
       return tf.ones(shape=(5, 5, 3), dtype=tf.int32)* self.num_of_images

    def __len__(self) -> int:
        return self.num_of_images
    
    def __call__(self):
        for _ in range(self.__len__()):
            yield self.generator()
    
    def _inputs(self):
        return ()

    def element_spec(self):
        return tf.TensorSpec(shape=(5, 5, 3), dtype=tf.int32)


if __name__ == "__main__":
    custom_dataset1 = CustomDataset(3)
    custom_dataset2 = CustomDataset(4)
    all_ds=[]
    all_ds.append(custom_dataset1)
    all_ds.append(custom_dataset2)
    sampled_ds = tf.data.Dataset.sample_from_datasets(all_ds, seed=1)
    #rest of the code
    
  1. I'm not quite sure what methods exactly to override from tf.data.Dataset
  2. When trying to sample it gives me this error:
 result = a.most_specific_common_supertype([b])
AttributeError: 'function' object has no attribute 'most_specific_common_supertype'

any help?

I've tried this, and it worked

    ds1 =  tf.data.Dataset.from_generator(generator= CustomDataset(1),
                                          output_signature=tf.TensorSpec(shape=(5, 5, 3), dtype=tf.int32))
    ds2 =  tf.data.Dataset.from_generator(generator= CustomDataset(2),
                                          output_signature=tf.TensorSpec(shape=(5, 5, 3), dtype=tf.int32))
    all_ds=[]
    all_ds.append(ds1)
    all_ds.append(ds2)
    sampled_ds = tf.data.Dataset.sample_from_datasets(all_ds, seed=1)

but is there any other way rather than using from_generator?

0

There are 0 best solutions below