I am trying to build an EfficientNetV2 model with CBAM (Convolutional Block Attention Module) using Keras.
I've implemented it myself, using this GitHub repository as a reference.
Could you confirm that the model I have built is correct?
Note: CBAM is inserted before the GlobalAveragePooling layer, and a skip connection is added around it.
import tensorflow as tf
from tensorflow.keras import backend as K
def channel_attention_module(input: tf.Tensor, ratio=8):
    """Apply CBAM-style channel attention to a feature map.

    The feature map is pooled spatially with both global average pooling and
    global max pooling. Each pooled descriptor goes through the same two-layer
    bottleneck MLP; the two results are added, passed through a sigmoid, and
    the resulting per-channel weights are multiplied into ``input``.

    Args:
        input: Feature-map tensor (not a ``tf.keras.Model``). The last axis is
            read as the channel axis, i.e. a channels-last layout is assumed.
            The name shadows the builtin but is kept so keyword callers
            continue to work.
        ratio: Reduction factor of the bottleneck. The hidden layer has
            ``channel // ratio`` units, clamped to a minimum of one.

    Returns:
        ``input`` rescaled channel-wise by the sigmoid attention weights.

    Raises:
        ValueError: If the channel dimension of ``input`` is not statically
            known, or if ``ratio`` is not positive.
    """
    channel = input.shape[-1]
    if channel is None:
        raise ValueError(
            "channel_attention_module requires a statically known channel "
            "dimension on the last axis of `input`."
        )
    if ratio <= 0:
        raise ValueError(f"`ratio` must be positive, got {ratio!r}.")

    # Without the clamp, a feature map with fewer channels than `ratio` would
    # request a Dense layer with zero units.
    hidden_units = max(channel // ratio, 1)

    # Both pooled descriptors must share the same MLP weights, so the two
    # Dense layers are instantiated once and called twice.
    shared_dense_one = tf.keras.layers.Dense(
        hidden_units,
        activation='relu',
        kernel_initializer='he_normal',
        use_bias=True,
        bias_initializer='zeros',
    )
    shared_dense_two = tf.keras.layers.Dense(
        channel,
        kernel_initializer='he_normal',
        use_bias=True,
        bias_initializer='zeros',
    )

    def _descriptor(pooling_layer):
        """Pool `input`, restore the spatial axes, and run the shared MLP."""
        pooled = pooling_layer(input)
        # Reshape to (1, 1, channel) so the result broadcasts over H and W
        # when multiplied with the feature map.
        pooled = tf.keras.layers.Reshape((1, 1, channel))(pooled)
        return shared_dense_two(shared_dense_one(pooled))

    # Average branch first, then max branch: this keeps the layer creation
    # order (and therefore the auto-generated layer names) stable.
    avg_pool = _descriptor(tf.keras.layers.GlobalAveragePooling2D())
    max_pool = _descriptor(tf.keras.layers.GlobalMaxPooling2D())

    x = tf.keras.layers.Add()([avg_pool, max_pool])
    x = tf.keras.layers.Activation('sigmoid')(x)
    return tf.keras.layers.multiply([input, x])
def spatial_attention_module(input: tf.Tensor, kernel_size=7):
    """Apply CBAM-style spatial attention to a feature map.

    The mean and the maximum over axis 3 are computed at every spatial
    position, concatenated into a two-channel map, and convolved down to a
    single-channel sigmoid mask that is multiplied into ``input``.

    Args:
        input: Feature-map tensor (not a ``tf.keras.Model``). Axis 3 is
            reduced as the channel axis, so a 4-D channels-last layout is
            assumed. The name shadows the builtin but is kept so keyword
            callers continue to work.
        kernel_size: Size of the convolution kernel that turns the pooled
            statistics into the attention mask.

    Returns:
        ``input`` rescaled position-wise by the sigmoid attention mask.
    """
    # Per-position statistics across channels; keepdims retains axis 3 with
    # size 1 so the two maps can be concatenated along it.
    avg_pool = tf.keras.layers.Lambda(
        lambda t: K.mean(t, axis=3, keepdims=True)
    )(input)
    max_pool = tf.keras.layers.Lambda(
        lambda t: K.max(t, axis=3, keepdims=True)
    )(input)
    pooled = tf.keras.layers.Concatenate(axis=3)([avg_pool, max_pool])

    # 'same' padding with stride 1 keeps the spatial size, so the
    # single-channel mask lines up with `input` for the multiplication.
    attention = tf.keras.layers.Conv2D(
        filters=1,
        kernel_size=kernel_size,
        strides=1,
        padding='same',
        activation='sigmoid',
        kernel_initializer='he_normal',
        use_bias=False,
    )(pooled)
    return tf.keras.layers.multiply([input, attention])
# Input size fed to the backbone: height, width, channels.
h, w, d = 300, 300, 3

# ImageNet-pretrained EfficientNetV2-B3 without its classification head.
# NOTE(review): with include_preprocessing=False the backbone does not rescale
# its inputs itself - confirm the data pipeline supplies the expected range.
base_model = tf.keras.applications.EfficientNetV2B3(
    include_top=False,
    weights="imagenet",
    input_shape=(h, w, d),
    include_preprocessing=False,
)

# Keep the raw backbone features for the residual connection around CBAM.
skip_Z = base_model.output

# CBAM: channel attention first, then spatial attention.
x = spatial_attention_module(channel_attention_module(skip_Z))

# Residual add followed by ReLU, then pool and classify into 11 classes.
x = tf.keras.layers.Add()([x, skip_Z])
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(11, activation="softmax")(x)

model = tf.keras.models.Model(inputs=base_model.input, outputs=x)
Below is the output of model.summary():
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 300, 300, 3 0 []
)]
stem_conv (Conv2D) (None, 150, 150, 40 1080 ['input_1[0][0]']
)
stem_bn (BatchNormalization) (None, 150, 150, 40 160 ['stem_conv[0][0]']
)
stem_activation (Activation) (None, 150, 150, 40 0 ['stem_bn[0][0]']
)
...
block6l_add (Add) (None, 10, 10, 232) 0 ['block6l_drop[0][0]',
'block6k_add[0][0]']
top_conv (Conv2D) (None, 10, 10, 1536 356352 ['block6l_add[0][0]']
)
top_bn (BatchNormalization) (None, 10, 10, 1536 6144 ['top_conv[0][0]']
)
top_activation (Activation) (None, 10, 10, 1536 0 ['top_bn[0][0]']
)
global_average_pooling2d (Glob (None, 1536) 0 ['top_activation[0][0]']
alAveragePooling2D)
global_max_pooling2d (GlobalMa (None, 1536) 0 ['top_activation[0][0]']
xPooling2D)
reshape (Reshape) (None, 1, 1, 1536) 0 ['global_average_pooling2d[0][0]'
]
reshape_1 (Reshape) (None, 1, 1, 1536) 0 ['global_max_pooling2d[0][0]']
dense (Dense) (None, 1, 1, 192) 295104 ['reshape[0][0]',
'reshape_1[0][0]']
dense_1 (Dense) (None, 1, 1, 1536) 296448 ['dense[0][0]',
'dense[1][0]']
add (Add) (None, 1, 1, 1536) 0 ['dense_1[0][0]',
'dense_1[1][0]']
activation (Activation) (None, 1, 1, 1536) 0 ['add[0][0]']
multiply (Multiply) (None, 10, 10, 1536 0 ['top_activation[0][0]',
) 'activation[0][0]']
lambda (Lambda) (None, 10, 10, 1) 0 ['multiply[0][0]']
lambda_1 (Lambda) (None, 10, 10, 1) 0 ['multiply[0][0]']
concatenate (Concatenate) (None, 10, 10, 2) 0 ['lambda[0][0]',
'lambda_1[0][0]']
conv2d (Conv2D) (None, 10, 10, 1) 98 ['concatenate[0][0]']
multiply_1 (Multiply) (None, 10, 10, 1536 0 ['multiply[0][0]',
) 'conv2d[0][0]']
add_1 (Add) (None, 10, 10, 1536 0 ['multiply_1[0][0]',
) 'top_activation[0][0]']
activation_1 (Activation) (None, 10, 10, 1536 0 ['add_1[0][0]']
)
global_average_pooling2d_1 (Gl (None, 1536) 0 ['activation_1[0][0]']
obalAveragePooling2D)
dense_2 (Dense) (None, 11) 16907 ['global_average_pooling2d_1[0][0
]']
==================================================================================================
Total params: 13,539,179
Trainable params: 13,429,963
Non-trainable params: 109,216
__________________________________________________________________________________________________