Python (SciKeras): 'ValueError: Invalid parameter dropout_rate for estimator KerasRegressor.'

I am trying to build an LSTM model for sales forecasting and tune its hyperparameters with GridSearchCV. Here is my code:

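    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Dense
    from scikeras.wrappers import KerasRegressor
    from sklearn.model_selection import GridSearchCV

    # Create LSTM Model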
    def create_lstm_model():
        model_lstm = Sequential()
        model_lstm.add(LSTM(activation='tanh', 
                            recurrent_activation='sigmoid', 
                            recurrent_dropout=0, 
                            unroll=False, 
                            use_bias=True, 
                            input_shape=(X_train_lstm.shape[1], X_train_lstm.shape[2])))
        model_lstm.add(Dense(1))
        model_lstm.compile(loss='mean_squared_error', optimizer='adam')
        return model_lstm
    
    # Create the parameters for GridSearchCV
    parameters = {'batch_size': [10, 20, 40, 60, 80],
                  'epochs': [10, 50, 100],
                  'units': [32, 64, 128],
                  'dropout_rate': [0.2, 0.4, 0.6]}

    keras_model = KerasRegressor(build_fn = create_lstm_model,
                                 verbose=-1)
    
    lstm_model = GridSearchCV(estimator=keras_model,
                              param_grid=parameters,
                              cv = 4,
                              n_jobs = -1,
                              verbose=-1)

    lstm_model.fit(X_train_lstm, y_train)
    lstm_grid_params = lstm_model.best_estimator_  

    preds = lstm_grid_params.predict(X_test_lstm)

I got the following error when I ran the above code:

    ValueError: Invalid parameter dropout_rate for estimator KerasRegressor.
    This issue can likely be resolved by setting this parameter in the KerasRegressor constructor:
    `KerasRegressor(dropout_rate=0.2)`
    Check the list of available parameters with `estimator.get_params().keys()`

I understand that this error goes away if I set `dropout_rate=0.2` in the `KerasRegressor` constructor, as the message suggests. However, when I do that, I get the same error for `units`. How do I use GridSearchCV to find the optimal value for all of these parameters?
