How do I expand the dimensions of a ragged tensor in Keras?


My goal: using Keras, I want to add a (non-ragged) dimension to a ragged tensor:

  • My initial ragged tensor has shape [batch size, (# time steps), # features].
  • I want a final ragged tensor with shape [batch size, 1, (# time steps), # features].
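A minimal sketch, assuming TensorFlow 2.x, where `tf.expand_dims` dispatches to a ragged-aware implementation when given a `tf.RaggedTensor`: calling it directly on the ragged tensor with `axis=1`, instead of going through `tf.map_fn`, should insert the size-1 dimension right after the batch dimension. The exact static shape that gets printed may differ between TensorFlow versions, so the nested values are the more reliable thing to check.

```python
import tensorflow as tf

num_features = 4
x = tf.ragged.constant([
    [[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]],
    dtype=tf.float32,
    inner_shape=(num_features,))

# Insert a size-1 axis right after the batch axis:
# [batch, (time steps), features] -> [batch, 1, (time steps), features]
y = tf.expand_dims(x, axis=1)

print(y.shape)             # static shape; exact form may vary by version
print(y.bounding_shape())  # smallest dense shape that can hold y
```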

(My motivation is to run a temporal convolution over each sample, so if anyone knows how to do that directly, I would appreciate it.)
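If the end goal is a per-sample temporal convolution, one option is to skip the extra dimension and convolve each sample inside `tf.map_fn`. This is a sketch, not a Keras layer: `num_filters`, `kernel_size`, the function `conv_one_sample` and the random `kernel` are hypothetical stand-ins (a real model would use a trainable weight), and it uses the low-level `tf.nn.conv1d` with `"VALID"` padding, so each output sequence is `kernel_size - 1` steps shorter than its input. The `tf.RaggedTensorSpec` with `ragged_rank=0` tells `tf.map_fn` that each call returns a dense tensor whose first dimension varies between samples, which it then stacks into a ragged result.

```python
import tensorflow as tf

num_features = 4
num_filters = 8   # hypothetical
kernel_size = 3   # hypothetical

x = tf.ragged.constant([
    [[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]],
    dtype=tf.float32,
    inner_shape=(num_features,))

# Hypothetical fixed kernel: [width, in_channels, out_channels].
kernel = tf.random.normal([kernel_size, num_features, num_filters])

def conv_one_sample(sample):
    # `sample` is one dense sequence of shape [n_i, num_features];
    # tf.nn.conv1d expects [batch, width, channels], so add a batch axis.
    out = tf.nn.conv1d(sample[tf.newaxis], kernel, stride=1, padding="VALID")
    # Drop the temporary batch axis: [n_i - kernel_size + 1, num_filters].
    return tf.squeeze(out, axis=0)

y = tf.map_fn(
    conv_one_sample,
    x,
    fn_output_signature=tf.RaggedTensorSpec(
        shape=[None, num_filters], dtype=tf.float32, ragged_rank=0))
```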

What I have tried: I have tried using tf.map_fn with a Lambda layer that calls tf.expand_dims, but it throws an InvalidArgumentError saying that the flat_values have incompatible shapes (full output below). Passing a tf.TensorSpec as the fn_output_signature did not solve the problem. Passing a tf.RaggedTensorSpec as the fn_output_signature did not seem to make sense (and also did not work), since the elements of the ragged tensor are dense tensors, not ragged tensors.
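The error message itself points at a possible fix: give `tf.map_fn` an explicit `tf.RaggedTensorSpec` with an appropriate `ragged_rank`, and have the mapped function return a `RaggedTensor`. A minimal sketch under that assumption, using a plain function (the name `expand_sample` is hypothetical) instead of the `Lambda` layer: each dense sample is expanded to `[1, n_i, 4]` and wrapped with `tf.RaggedTensor.from_tensor(..., ragged_rank=1)`, so that samples of different lengths can be stacked.

```python
import tensorflow as tf

num_features = 4
x = tf.ragged.constant([
    [[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]],
    dtype=tf.float32,
    inner_shape=(num_features,))

def expand_sample(sample):
    # `sample` is one dense sequence of shape [n_i, num_features].
    expanded = tf.expand_dims(sample, axis=0)  # [1, n_i, num_features]
    # Mark the time axis as ragged so tf.map_fn can combine
    # samples whose time dimensions differ.
    return tf.RaggedTensor.from_tensor(expanded, ragged_rank=1)

y = tf.map_fn(
    expand_sample,
    x,
    fn_output_signature=tf.RaggedTensorSpec(
        shape=[1, None, num_features], dtype=tf.float32, ragged_rank=1))
```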

Reprex (using TensorFlow 2.15.0, Python 3.11):

import tensorflow as tf

num_features = 4
x = tf.ragged.constant([
    [[0,0,0,0],[0,0,0,0],[1,0,0,0],[0,0,0,0]],
    [[0,0,0,0],[0,0,0,0],[0,0,0,0],[1,0,0,0],[0,0,0,0]]],
    dtype=tf.float32,
    inner_shape=(num_features,))

expandDims = tf.keras.layers.Lambda(
    lambda x: tf.expand_dims(x,axis=0))

# Would like y to be a ragged tensor with shape (2,1,None,4).
# Throws error.
y = tf.map_fn(expandDims,x)

# Also throws error.
#y = tf.map_fn(expandDims,x,fn_output_signature = tf.TensorSpec(shape = (1,None,num_features)))

stdout:

2024-01-07 09:52:03.342983: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at ragged_tensor_from_variant_op.cc:333 : INVALID_ARGUMENT: All flat_values must have compatible shapes.  Shape at index 0: [4,4].  Shape at index 1: [5,4].  If you are using tf.map_fn, then you may need to specify an explicit fn_output_signature with appropriate ragged_rank, and/or convert output tensors to RaggedTensors.
Traceback (most recent call last):
  File "c:\Users\apples\Documents\tensforflow probability course\test.py", line 18, in <module>
    y = tf.map_fn(expandDims,x)
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\util\deprecation.py", line 660, in new_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\util\deprecation.py", line 588, in new_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\ops\map_fn.py", line 637, in map_fn_v2
    return map_fn(
           ^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\util\deprecation.py", line 588, in new_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\ops\map_fn.py", line 516, in map_fn
    result_flat = _result_batchable_to_flat(result_batchable,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\ops\map_fn.py", line 607, in _result_batchable_to_flat
    spec._batch(batch_size)._from_compatible_tensor_list(
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\ops\ragged\ragged_tensor.py", line 2601, in _from_compatible_tensor_list   
    result = RaggedTensor._from_variant(  # pylint: disable=protected-access
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\ops\ragged\ragged_tensor.py", line 2028, in _from_variant
    result = gen_ragged_conversion_ops.ragged_tensor_from_variant(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\ops\gen_ragged_conversion_ops.py", line 77, in ragged_tensor_from_variant  
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Program Files\Python\Python311\Lib\site-packages\tensorflow\python\framework\ops.py", line 5883, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__RaggedTensorFromVariant_output_ragged_rank_1_device_/job:localhost/replica:0/task:0/device:CPU:0}} All flat_values must have compatible shapes.  Shape at index 0: [4,4].  Shape at index 1: [5,4].  If you are using tf.map_fn, then you may need to specify an explicit fn_output_signature with appropriate ragged_rank, and/or convert output tensors to RaggedTensors. [Op:RaggedTensorFromVariant] name:

EDIT:

I have given up on using ragged tensors as inputs to tf.keras.layers.Conv1D and have instead decided to pad the ragged tensors so that they become ordinary dense tensors.
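A minimal sketch of that padding approach, assuming zero padding is acceptable: `RaggedTensor.to_tensor()` pads every sample to the length of the longest one, `tf.keras.layers.Conv1D` (with hypothetical `filters=8` and `kernel_size=3`) runs on the dense batch, and `tf.RaggedTensor.from_tensor` with the original `row_lengths()` drops the padded steps again.

```python
import tensorflow as tf

num_features = 4
x = tf.ragged.constant([
    [[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]],
    dtype=tf.float32,
    inner_shape=(num_features,))

# Zero-pad each sample up to the longest one: [2, 5, num_features].
dense = x.to_tensor(default_value=0.0)
lengths = x.row_lengths()  # original number of time steps per sample

# Hypothetical layer settings; "same" keeps the time dimension unchanged.
conv = tf.keras.layers.Conv1D(filters=8, kernel_size=3, padding="same")
y = conv(dense)  # [2, 5, 8]

# Optionally drop the outputs at padded time steps again.
y_ragged = tf.RaggedTensor.from_tensor(y, lengths=lengths)
```

With `padding="same"`, stride 1 and an odd `kernel_size`, the outputs at the real time steps should match convolving each unpadded sample on its own, since in both cases the kernel sees zeros past the end of the sequence.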
