InvalidArgumentError: Graph execution error while running model.fit(...)

37 Views Asked by At

I'm writing a multi-label image segmentation model using FCN-16. When I run model.fit(), I get the following error. How do I inspect what's going wrong? Any pointers on where to start, or on what could be causing the issue, would be appreciated.

The error:

--------------------------------------------------------------------------- 
InvalidArgumentError                      Traceback (most recent call last) Cell In[8], line 35 32 STEPS_PER_EPOCH = TRAINSET_SIZE // BATCH_SIZE 33 VALIDATION_STEPS = VALIDSET_SIZE // BATCH_SIZE 35 model_history = model.fit(dataset['train'], epochs=EPOCHS, 36                           steps_per_epoch=STEPS_PER_EPOCH, 37                           validation_data = dataset["val"], 38                           validation_steps=VALIDATION_STEPS, 39                           callbacks = Callbacks)

File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs) 120     filtered_tb = _process_traceback_frames(e.__traceback__) 121     # To get the full stack trace, call: 122     # `keras.config.disable_traceback_filtering()` 123     raise e.with_traceback(filtered_tb) from None 124 finally: 125     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name) 51 try: 52   ctx.ensure_initialized() 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, 54                                       inputs, attrs, num_outputs) 55 except core._NotOkStatusException as e: 56   if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node ScatterNd defined at (most recent call last): File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main

File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code

File "/opt/conda/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>

File "/opt/conda/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance

File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 701, in start

File "/opt/conda/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start

File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run

File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue

File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 523, in process_one

File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell

File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 767, in execute_request

File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 429, in do_execute

File "/opt/conda/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

File "/opt/conda/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

File "/tmp/ipykernel_33/2894299764.py", line 35, in <module>

File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 118, in error_handler

Here are some relevant snippets of the code:

def parse_image(image_path: str) -> dict:
    """Load an image and build one mask channel per label from its ground truth.

    The image at ``image_path`` is read and decoded as JPEG with
    ``N_CHANNELS`` channels. The ground-truth path is obtained by replacing
    "image" with "image_gt" in ``image_path`` and is decoded the same way.
    For every ``(label_color, label_value)`` pair, a channel is produced that
    holds ``label_value`` where the ground-truth pixel equals ``label_color``
    on all channels and 0 elsewhere; the channels are concatenated along the
    last axis.

    ``label_colors`` and ``N_CHANNELS`` are not defined in this function and
    are presumably module-level values -- verify against the rest of the file.

    NOTE(review): the returned mask has one channel per label and contains
    label ids up to 12. If the loss or metrics expect a single-channel mask of
    class indices in ``[0, N_CLASSES)``, this may be related to the
    ``ScatterNd`` error raised by ``model.fit`` -- confirm.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A dict with the decoded image under ``'image'`` and the stacked
        per-label mask under ``'segmentation_mask'``.
    """
    raw_image = tf.io.read_file(image_path)
    decoded_image = tf.image.decode_jpeg(raw_image, channels=N_CHANNELS)
    image = tf.image.convert_image_dtype(decoded_image, tf.uint8)

    # Debug output of the image's static shape.
    print(image.shape)

    # Locate and decode the ground-truth file that belongs to this image.
    mask_path = tf.strings.regex_replace(image_path, "image", "image_gt")
    raw_mask = tf.io.read_file(mask_path)
    color_mask = tf.image.decode_jpeg(raw_mask, channels=N_CHANNELS)

    # Numeric id written into the mask for each entry of label_colors;
    # adjust this list to handle a different set of labels.
    label_values = [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    label_channels = []
    for label_color, label_value in zip(label_colors, label_values):
        # True where every channel of the pixel matches this label's colour;
        # the trailing axis is restored so the result is a one-channel map.
        is_label = tf.reduce_all(color_mask == label_color, axis=-1)
        is_label = tf.expand_dims(is_label, axis=-1)

        label_fill = tf.cast(label_value, tf.uint8)
        empty_channel = tf.zeros_like(tf.cast(is_label, tf.uint8))
        label_channel = tf.where(is_label, label_fill, empty_channel)

        print(label_channel.shape)
        label_channels.append(label_channel)

    # Stack the per-label channels into a single tensor.
    stacked_mask = tf.concat(label_channels, axis=-1)

    # Reported as problematic by the author, hence left disabled:
    # mask = tf.reduce_max(mask, axis=-1, keepdims=True)  # Ensure mask has only one channel

    return {'image': image, 'segmentation_mask': stacked_mask}

# Generate dataset variables
# List the files in a fixed order and shuffle them exactly once.
# `list_files` shuffles by default and reshuffles on every iteration, which
# would make the take/skip splits below pick different files each epoch and
# leak samples between the train, validation and test sets.
all_dataset = tf.data.Dataset.list_files(traindata_dir + "*.png", shuffle=False)
all_dataset = all_dataset.shuffle(
    buffer_size=all_dataset.cardinality(),  # buffer covers every file
    seed=SEED,
    reshuffle_each_iteration=False,  # keep one fixed order for all epochs
)
all_dataset = all_dataset.map(parse_image)

# First TRAINSET_SIZE samples -> train, next VALIDSET_SIZE -> validation,
# everything after that -> test.
train_dataset = all_dataset.take(TRAINSET_SIZE + VALIDSET_SIZE)
val_dataset = train_dataset.skip(TRAINSET_SIZE)
train_dataset = train_dataset.take(TRAINSET_SIZE)
test_dataset = all_dataset.skip(TRAINSET_SIZE + VALIDSET_SIZE)

# Build the segmentation network on top of a VGG-16 backbone.
# The backbone is instantiated once, below, on the model's own input tensor;
# a separate full `VGG16()` (with its classifier head) is not built first
# because it would be discarded immediately.

# Define input shape
input_shape = (IMG_SIZE_2, IMG_SIZE_1, N_CHANNELS)

# Input
inputs = Input(input_shape)

# VGG network without the dense classifier, using ImageNet weights
vgg16_model = VGG16(include_top=False, weights='imagenet', input_tensor=inputs)

# Encoder feature maps taken from the backbone's pooling layers
pool3 = vgg16_model.get_layer("block3_pool").output
pool4 = vgg16_model.get_layer("block4_pool").output
pool5 = vgg16_model.get_layer("block5_pool").output

# Convolutional replacement for the classifier head, applied to pool5
conv_6 = Conv2D(1024, (7, 7), activation='relu', padding='same', name="conv_6")(pool5)
conv_7 = Conv2D(1024, (1, 1), activation='relu', padding='same', name="conv_7")(conv_6)

# 1x1 projections of the skip features to N_CLASSES channels
conv_8 = Conv2D(N_CLASSES, (1, 1), activation='relu', padding='same', name="conv_8")(pool4)
# NOTE: conv_9 is not connected to the model output below.
conv_9 = Conv2D(N_CLASSES, (1, 1), activation='relu', padding='same', name="conv_9")(pool3)

# Upsample the head by 2 and fuse it with the pool4 projection.
# Assumes the input height/width are such that both tensors have the same
# spatial size (e.g. multiples of 32) -- TODO confirm.
deconv_7 = Conv2DTranspose(N_CLASSES, kernel_size=(2, 2), strides=(2, 2))(conv_7)
add_1 = Add()([deconv_7, conv_8])
# NOTE: deconv_8 is not connected to the model output below.
deconv_8 = Conv2DTranspose(N_CLASSES, kernel_size=(2, 2), strides=(2, 2))(add_1)

# Upsample the fused map by 16 and turn it into per-pixel class probabilities
deconv_10 = Conv2DTranspose(N_CLASSES, kernel_size=(16, 16), strides=(16, 16))(add_1)
output_layer = Activation('softmax')(deconv_10)

model = Model(inputs=vgg16_model.input, outputs=output_layer)
model.summary()

# Training configuration
EPOCHS = 20
# Floor division: a trailing partial batch is not counted as a step.
STEPS_PER_EPOCH = TRAINSET_SIZE // BATCH_SIZE
VALIDATION_STEPS = VALIDSET_SIZE // BATCH_SIZE

# Train the model and keep the object returned by fit().
# `dataset` and `Callbacks` are not defined in this snippet; presumably they
# are created elsewhere in the file -- verify.
model_history = model.fit(
    dataset['train'],
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=dataset["val"],
    validation_steps=VALIDATION_STEPS,
    callbacks=Callbacks,
)
  • I inspected the shapes of the input image and masks, and they look fine. The expected output and target labels' shapes should match as well, and they do.
  • I checked whether the model compiles as expected, and it does, without any visible errors.
  • However, I'm not entirely sure what this error means or what could be causing it, which is why I'm asking for help here.
0

There are 0 best solutions below