Displaying self-attention weights in a bi-LSTM with an attention mechanism


I am really stuck trying to print the output of my deep network together with the attention weights for each sentence. I am working on a sentiment analysis problem, so I would like to check which words are weighted most heavily for the classification.
Here is my self-attention class:

class SelfAttention(tf.keras.layers.Layer):
  def __init__(self, mlp_layers=0, units=0, dropout_rate=0, return_attention=False, **kwargs):
    super(SelfAttention, self).__init__(**kwargs)
    self.mlp_layers = mlp_layers
    self.mlp_units = units
    self.return_attention = return_attention
    self.dropout_rate = dropout_rate
    self.attention_mlp = self.build_mlp()

  def build_mlp(self):
    # Small MLP that scores every timestep; the final Dense(1) produces one
    # unnormalised attention score per token.
    mlp = tf.keras.Sequential()
    for _ in range(self.mlp_layers):
      mlp.add(tf.keras.layers.Dense(self.mlp_units, activation='relu'))
      mlp.add(tf.keras.layers.Dropout(self.dropout_rate))
    mlp.add(tf.keras.layers.Dense(1))
    return mlp

  def call(self, x, mask=None):
    # x: (batch, timesteps, features) -> scores a: (batch, timesteps)
    a = self.attention_mlp(x)
    a = tf.squeeze(a, axis=-1)

    # Push padded positions towards -inf so they receive ~0 attention after softmax
    if mask is not None:
      mask = tf.keras.backend.cast(mask, tf.keras.backend.floatx())
      a -= 100000.0 * (1.0 - mask)

    # Normalise the scores and take the weighted sum over timesteps
    a = tf.keras.backend.expand_dims(tf.keras.backend.softmax(a, axis=-1))  # (batch, timesteps, 1)
    weighted_input = x * a
    result = tf.keras.backend.sum(weighted_input, axis=1)                   # (batch, features)

    if self.return_attention:
      return [result, a]
    return result
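
For reference, here is a minimal sketch of what the layer returns when return_attention=True (the shapes below are made up, just for illustration):

import tensorflow as tf

# Hypothetical batch: 4 sentences, 128 timesteps, 600 features (the bi-LSTM output size)
dummy = tf.random.normal((4, 128, 600))

layer = SelfAttention(mlp_layers=0, return_attention=True)
sentence_vec, attn = layer(dummy)

print(sentence_vec.shape)  # (4, 600)    pooled sentence representation
print(attn.shape)          # (4, 128, 1) one weight per token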

And the code for my network:

import tensorflow as tf
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# vectorizer, embedding_matrix, the Metrics callback and the train/validation
# splits (X_train, y_train_1_hot, X_val, y_val_1_hot) are defined earlier in my notebook.

LSTM_SIZE = 300
DENSE = 1000
MAX_WORDS = 100000
MAX_SEQUENCE_LENGTH = 128
EMBEDDING_DIM = 300

# Define input layer
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string)

# Add layers using functional API
x = vectorizer(inputs)
x = tf.keras.layers.Embedding(MAX_WORDS, EMBEDDING_DIM, weights=[embedding_matrix],
                              input_length=MAX_SEQUENCE_LENGTH, mask_zero=True, trainable=False)(x)
x = tf.keras.layers.Dropout(0.33)(x)
# add a bidirectional lstm layer with 0.33 variational (recurrent) dropout
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(LSTM_SIZE, return_sequences=True, recurrent_dropout=0.33))(x)
x = tf.keras.layers.Dropout(0.33)(x)
# # add a second bidirectional lstm layer with 0.33 variational (recurrent) dropout
# x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(LSTM_SIZE, return_sequences=True, recurrent_dropout=0.33))(x)
x = tf.keras.layers.Dropout(0.33)(x)
x, attention_weights = SelfAttention(mlp_layers=0, return_attention=True)(x)

x = tf.keras.layers.Dense(units=DENSE, activation='relu')(x)
x = tf.keras.layers.Dropout(0.33)(x)
outputs = tf.keras.layers.Dense(len(np.unique(y_train)), activation='softmax')(x)

# Define the model with two outputs: the class probabilities and the attention weights
model2 = tf.keras.Model(inputs=inputs, outputs=[outputs, attention_weights])

model2.summary()

model2.compile(loss='categorical_crossentropy',
               optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
               metrics=["categorical_accuracy"])

# Add an early stopping callback to stop the epochs when we catch the best validation loss
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)
# Callback to save the Keras model or model weights at some frequency.
MLP_checkpoint = ModelCheckpoint(
    'checkpoints/weights1',
    monitor='val_loss',
    mode='min',
    verbose=2,
    save_best_only=True,
    save_weights_only=True,
    save_format='tf'
)

# Train the model
history2 = model2.fit(np.array(X_train), y_train_1_hot,
                      validation_data=(np.array(X_val), y_val_1_hot),
                      batch_size=256,
                      epochs=10,
                      shuffle=True,
                      callbacks=[Metrics(valid_data=(np.array(X_val), y_val_1_hot)),
                                 early_stopping_callback,
                                 MLP_checkpoint])
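
For context, once training works this is roughly how I plan to read the attention weights out and map them back to words (a sketch; it assumes vectorizer is a TextVectorization layer and X_val holds raw strings):

# Sketch: predict on one validation sentence and print the weight of every token
probs, attn = model2.predict(np.array(X_val[:1]))        # attn: (1, 128, 1) in my case
attn = attn.squeeze()                                    # -> (128,)

vocab = vectorizer.get_vocabulary()
token_ids = vectorizer(np.array(X_val[:1])).numpy()[0]   # -> (128,) integer ids

for token_id, weight in zip(token_ids, attn):
  if token_id != 0:                                      # skip padding
    print(f'{vocab[token_id]:<20s} {weight:.4f}')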

The error I get is the following:

ValueError: Shapes (None, 2) and (None, 128) are incompatible

I have tried adding reshapes and concatenations, but the problem remains. My guess is that the (None, 2) is the shape of my one-hot labels and the (None, 128) comes from the attention-weights output, i.e. the categorical cross-entropy loss is also being applied to the second output, but I have not managed to resolve the mismatch.
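
For example, one reshape attempt looked roughly like this (a sketch from memory; the exact variants differed):

# Attempted workaround (sketch): flatten the attention weights to (None, 128) before building the model
attention_weights = tf.keras.layers.Reshape((MAX_SEQUENCE_LENGTH,))(attention_weights)
model2 = tf.keras.Model(inputs=inputs, outputs=[outputs, attention_weights])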

How can I fix my self-attention setup so that the model trains and I can still get the attention weights for each sentence?
