Trying EfficientNetV2 with CBAM

178 Views Asked by At

I am trying to build an EfficientNetV2 model with CBAM (Convolutional Block Attention Module) using Keras.

I've implemented it myself, using this GitHub repository as a reference.

Could you confirm that the model I have built is correct?

Note: The CBAM block is inserted just before the final GlobalAveragePooling2D layer, and its output is merged back with its input through a skip connection.

import tensorflow as tf
from tensorflow.keras import backend as K

def channel_attention_module(input: tf.Tensor, ratio: int = 8) -> tf.Tensor:
    """Re-weight the channels of a feature map with CBAM channel attention.

    The feature map is summarised twice over its spatial dimensions (global
    average pooling and global max pooling). Both summaries go through the
    *same* two-layer Dense bottleneck, the two results are added, squashed
    with a sigmoid, and the resulting per-channel weights are multiplied back
    onto ``input``.

    Args:
        input: Feature-map tensor whose last axis is the channel axis
            (``input.shape[-1]`` is read as the channel count). The name
            shadows the ``input`` builtin but is kept so keyword callers
            keep working.
        ratio: Reduction factor of the bottleneck; the hidden Dense layer
            has ``channel // ratio`` units, but never fewer than one.

    Returns:
        A tensor with the same shape as ``input``, scaled channel-wise by
        attention weights in ``(0, 1)``.

    Raises:
        ValueError: If ``ratio`` is smaller than 1, or if the channel
            dimension of ``input`` is not statically known.
    """
    if ratio < 1:
        raise ValueError(f"ratio must be >= 1, got {ratio!r}")

    channel = input.shape[-1]
    if channel is None:
        raise ValueError(
            "channel_attention_module requires a statically known channel "
            "dimension, but input.shape[-1] is None"
        )

    # Clamp so that a feature map with fewer channels than `ratio` still gets
    # a usable (1-unit) bottleneck instead of a zero-unit Dense layer.
    hidden_units = max(channel // ratio, 1)

    # The two Dense layers are created once and reused for both pooled
    # descriptors, so the average and max branches share their weights.
    shared_dense_one = tf.keras.layers.Dense(hidden_units,
                                             activation='relu',
                                             kernel_initializer='he_normal',
                                             use_bias=True,
                                             bias_initializer='zeros')
    shared_dense_two = tf.keras.layers.Dense(channel,
                                             kernel_initializer='he_normal',
                                             use_bias=True,
                                             bias_initializer='zeros')

    # Average-pooled descriptor, reshaped to (1, 1, channel) so it can later
    # be broadcast against the spatial dimensions of `input`.
    avg_pool = tf.keras.layers.GlobalAveragePooling2D()(input)
    avg_pool = tf.keras.layers.Reshape((1, 1, channel))(avg_pool)
    avg_pool = shared_dense_one(avg_pool)
    avg_pool = shared_dense_two(avg_pool)

    # Max-pooled descriptor, sent through the same shared bottleneck.
    max_pool = tf.keras.layers.GlobalMaxPooling2D()(input)
    max_pool = tf.keras.layers.Reshape((1, 1, channel))(max_pool)
    max_pool = shared_dense_one(max_pool)
    max_pool = shared_dense_two(max_pool)

    # Merge both branches and turn them into per-channel gates.
    x = tf.keras.layers.Add()([avg_pool, max_pool])
    x = tf.keras.layers.Activation('sigmoid')(x)

    return tf.keras.layers.multiply([input, x])

def spatial_attention_module(input: tf.Tensor, kernel_size: int = 7) -> tf.Tensor:
    """Re-weight the spatial positions of a feature map with CBAM spatial attention.

    The feature map is reduced along axis 3 (the channel axis of a 4-D
    channels-last tensor) with both a mean and a max, the two single-channel
    maps are concatenated, and one bias-free convolution with a sigmoid
    activation turns them into a one-channel attention map that is multiplied
    back onto ``input``.

    Args:
        input: 4-D feature-map tensor; axis 3 is reduced, so it is treated
            as the channel axis. The name shadows the ``input`` builtin but
            is kept so keyword callers keep working.
        kernel_size: Size of the convolution kernel that produces the
            attention map.

    Returns:
        A tensor with the same shape as ``input``, scaled position-wise by
        attention weights in ``(0, 1)``.
    """
    # keepdims=True leaves a trailing axis of size 1 so the two maps can be
    # concatenated along axis 3 into a 2-channel descriptor.
    avg_pool = tf.keras.layers.Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(input)
    max_pool = tf.keras.layers.Lambda(lambda x: K.max(x, axis=3, keepdims=True))(input)
    x = tf.keras.layers.Concatenate(axis=3)([avg_pool, max_pool])

    # 'same' padding with stride 1 keeps the spatial size, so the 1-channel
    # output can be broadcast across every channel of `input`.
    x = tf.keras.layers.Conv2D(filters=1,
                               kernel_size=kernel_size,
                               strides=1,
                               padding='same',
                               activation='sigmoid',
                               kernel_initializer='he_normal',
                               use_bias=False)(x)

    return tf.keras.layers.multiply([input, x])


# Input geometry: height, width and number of channels.
h, w, d = (300, 300, 3)

# ImageNet-pretrained EfficientNetV2-B3 used as a headless feature extractor.
# include_preprocessing=False leaves out the built-in preprocessing layer, so
# presumably the caller must feed already-normalised images (check the Keras
# documentation for the expected range).
base_model = tf.keras.applications.EfficientNetV2B3(
    include_top=False,
    weights="imagenet",
    input_shape=(h, w, d),
    include_preprocessing=False,
)

# Keep the untouched backbone features for the residual branch.
skip_Z = base_model.output

# CBAM: channel attention first, then spatial attention on its result.
channel_refined = channel_attention_module(skip_Z)
spatially_refined = spatial_attention_module(channel_refined)

# Residual merge of the attended features with the backbone features,
# followed by a ReLU.
residual_sum = tf.keras.layers.Add()([spatially_refined, skip_Z])
activated = tf.keras.layers.Activation("relu")(residual_sum)

# Classification head: global average pooling and an 11-way softmax.
pooled = tf.keras.layers.GlobalAveragePooling2D()(activated)
x = tf.keras.layers.Dense(11, activation="softmax")(pooled)

model = tf.keras.models.Model(base_model.input, x)

Below is the output of `model.summary()`:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 300, 300, 3  0           []                               
                                )]                                                                
                                                                                                  
 stem_conv (Conv2D)             (None, 150, 150, 40  1080        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 stem_bn (BatchNormalization)   (None, 150, 150, 40  160         ['stem_conv[0][0]']              
                                )                                                                 
                                                                                                  
 stem_activation (Activation)   (None, 150, 150, 40  0           ['stem_bn[0][0]']                
                                )                                                                 

 ...

 block6l_add (Add)              (None, 10, 10, 232)  0           ['block6l_drop[0][0]',           
                                                                  'block6k_add[0][0]']            
                                                                                                  
 top_conv (Conv2D)              (None, 10, 10, 1536  356352      ['block6l_add[0][0]']            
                                )                                                                 
                                                                                                  
 top_bn (BatchNormalization)    (None, 10, 10, 1536  6144        ['top_conv[0][0]']               
                                )                                                                 
                                                                                                  
 top_activation (Activation)    (None, 10, 10, 1536  0           ['top_bn[0][0]']                 
                                )                                                                 
                                                                                                  
 global_average_pooling2d (Glob  (None, 1536)        0           ['top_activation[0][0]']         
 alAveragePooling2D)                                                                              
                                                                                                  
 global_max_pooling2d (GlobalMa  (None, 1536)        0           ['top_activation[0][0]']         
 xPooling2D)                                                                                      
                                                                                                  
 reshape (Reshape)              (None, 1, 1, 1536)   0           ['global_average_pooling2d[0][0]'
                                                                 ]                                
                                                                                                  
 reshape_1 (Reshape)            (None, 1, 1, 1536)   0           ['global_max_pooling2d[0][0]']   
                                                                                                  
 dense (Dense)                  (None, 1, 1, 192)    295104      ['reshape[0][0]',                
                                                                  'reshape_1[0][0]']              
                                                                                                  
 dense_1 (Dense)                (None, 1, 1, 1536)   296448      ['dense[0][0]',                  
                                                                  'dense[1][0]']                  
                                                                                                  
 add (Add)                      (None, 1, 1, 1536)   0           ['dense_1[0][0]',                
                                                                  'dense_1[1][0]']                
                                                                                                  
 activation (Activation)        (None, 1, 1, 1536)   0           ['add[0][0]']                    
                                                                                                  
 multiply (Multiply)            (None, 10, 10, 1536  0           ['top_activation[0][0]',         
                                )                                 'activation[0][0]']             
                                                                                                  
 lambda (Lambda)                (None, 10, 10, 1)    0           ['multiply[0][0]']               
                                                                                                  
 lambda_1 (Lambda)              (None, 10, 10, 1)    0           ['multiply[0][0]']               
                                                                                                  
 concatenate (Concatenate)      (None, 10, 10, 2)    0           ['lambda[0][0]',                 
                                                                  'lambda_1[0][0]']               
                                                                                                  
 conv2d (Conv2D)                (None, 10, 10, 1)    98          ['concatenate[0][0]']            
                                                                                                  
 multiply_1 (Multiply)          (None, 10, 10, 1536  0           ['multiply[0][0]',               
                                )                                 'conv2d[0][0]']                 
                                                                                                  
 add_1 (Add)                    (None, 10, 10, 1536  0           ['multiply_1[0][0]',             
                                )                                 'top_activation[0][0]']         
                                                                                                  
 activation_1 (Activation)      (None, 10, 10, 1536  0           ['add_1[0][0]']                  
                                )                                                                 
                                                                                                  
 global_average_pooling2d_1 (Gl  (None, 1536)        0           ['activation_1[0][0]']           
 obalAveragePooling2D)                                                                            
                                                                                                  
 dense_2 (Dense)                (None, 11)           16907       ['global_average_pooling2d_1[0][0
                                                                 ]']                              
                                                                                                  
==================================================================================================
Total params: 13,539,179
Trainable params: 13,429,963
Non-trainable params: 109,216
__________________________________________________________________________________________________
0

There are 0 best solutions below