TensorFlow: accessing tensor.numpy() in a .map function requires tf.py_function, which slows down iterator creation

I want to one-hot encode a tensor with my own one-hot encoder. For this, I have to call tf.keras.backend.get_value() inside .map, which is only possible when using tf.py_function:

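import numpy as np
import tensorflow as tf

def one_hot_encode(categories, input):
  # Inside tf.py_function both arguments are eager tensors, so their values can be read
  data = tf.keras.backend.get_value(input)
  encoded_input = []
  for category in tf.keras.backend.get_value(categories):
    encoded_input.append(data == category)  # True where the label equals this category
  # .map declares Tout=tf.float32, so return float32 instead of booleans
  return np.array(encoded_input, dtype=np.float32)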
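Because the loop only compares every label with every category, the same encoding can also be expressed with TensorFlow ops alone, so no concrete value has to be read and tf.py_function is not needed. The following is a minimal sketch, assuming `categories` is a 1-D list or tensor of integers and the labels are integers; the names `one_hot_encode_tf` and `one_hot_encode_np` are hypothetical and not part of the question.

```python
import numpy as np
import tensorflow as tf

def one_hot_encode_tf(categories, labels):
    # Hypothetical graph-friendly equivalent of one_hot_encode: no .numpy() or get_value().
    categories = tf.convert_to_tensor(categories)
    labels = tf.cast(labels, categories.dtype)
    # Reshape categories to (num_categories, 1, ..., 1) so it broadcasts against labels
    # and the category axis comes first, like np.array(encoded_input) in the loop version.
    cat_shape = tf.concat([tf.shape(categories), tf.ones_like(tf.shape(labels))], axis=0)
    expanded = tf.reshape(categories, cat_shape)
    return tf.cast(tf.equal(expanded, labels), tf.float32)

def one_hot_encode_np(categories, labels):
    # NumPy reference that mirrors the loop in the question.
    data = np.asarray(labels)
    return np.array([data == c for c in categories], dtype=np.float32)
```

Since it only uses shape-agnostic TensorFlow ops, `one_hot_encode_tf` can be called directly inside `ds.map`, for example `ds.map(lambda input, target: (input, one_hot_encode_tf([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], target)))`.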

The problem arises when I map the dataset and call one_hot_encode:

ds = ds.map(lambda input, target: (input, tf.py_function(one_hot_encode, inp=[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], target], Tout=tf.float32)))
# tf.py_function output has no static shape, so restore it to (10,)
ds = ds.map(lambda input, target: (input, tf.reshape(target, (10,))))
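The second .map is needed because tf.py_function cannot infer the static shape of whatever the Python function returns. The sketch below makes this visible through `element_spec` on hypothetical toy data and uses `set_shape` as one way to declare the known shape instead of adding a reshape; `restore_shape`, `ds_raw`, and the toy tensors are illustrative names, not from the question.

```python
import numpy as np
import tensorflow as tf

CATEGORIES = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

def one_hot_encode(categories, target):
    # Runs eagerly inside tf.py_function, so .numpy() is available here.
    data = target.numpy()
    return np.array([data == c for c in categories.numpy()], dtype=np.float32)

# Hypothetical toy data: 4 feature vectors with scalar integer labels in 1..10.
features = tf.zeros([4, 3])
labels = tf.constant([1, 5, 10, 3])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds_raw = ds.map(lambda x, y: (x, tf.py_function(one_hot_encode, inp=[CATEGORIES, y], Tout=tf.float32)))
print(ds_raw.element_spec[1].shape)  # static shape is unknown after tf.py_function

def restore_shape(x, y):
    # Only records the static shape we know the encoder returns; no runtime reshape.
    y.set_shape((10,))
    return x, y

ds = ds_raw.map(restore_shape)
print(ds.element_spec[1].shape)
```

Either `tf.reshape` or `set_shape` restores the shape information; neither removes the Python call made for every element by tf.py_function.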

TensorFlow takes a very long time to create an iterator for this dataset, e.g. when I try to access the data in a for loop:

for (input, target) in ds:
  ...

But if I use TensorFlow's built-in one-hot encoder, tf.one_hot, everything works fine and TensorFlow is fast:

ds = ds.map(lambda input, target: (input, tf.one_hot(target,10)))
ds = ds.map(lambda input, target: (input, tf.reshape(target, (10,))))
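The two pipelines agree in shape, but not necessarily in values: tf.one_hot(target, 10) treats labels as indices 0 to 9, while the custom encoder compares against the categories 1 to 10. Assuming the labels really lie in 1 to 10, as the category list suggests, subtracting one before tf.one_hot reproduces the custom encoding. This small check sketches that on made-up labels:

```python
import numpy as np
import tensorflow as tf

categories = np.arange(1, 11)  # the [1, ..., 10] list passed to one_hot_encode
labels = np.array([1, 4, 10])  # hypothetical labels in 1..10

# Custom encoding: one row per category, transposed to (batch, 10) for comparison.
custom = np.array([labels == c for c in categories], dtype=np.float32).T

as_is = tf.one_hot(labels, 10).numpy()        # label 10 is out of range, so its row is all zeros
shifted = tf.one_hot(labels - 1, 10).numpy()  # labels shifted to indices 0..9

print(np.array_equal(custom, as_is))    # False
print(np.array_equal(custom, shifted))  # True
```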

In both approaches, the dataset and all tensors have the same shape. Does anyone know of another way to access the value of a tensor in .map, or why TensorFlow becomes so slow?
