I am using R for deep learning on the MNIST dataset.
I wrote the following code to load the training and test data, then define and fit the model:
library(keras)
#Obtain data
mnist <- dataset_mnist()
train_data <- mnist$train$x
train_labels <- mnist$train$y
test_data <- mnist$test$x
test_labels <- mnist$test$y
#Reshape & normalize
train_data <- array_reshape(train_data,c(nrow(train_data), 784))
train_data <- train_data / 255
test_data <- array_reshape(test_data,c(nrow(test_data), 784))
test_data <- test_data / 255
#One hot encoding
train_labels <- to_categorical(train_labels, 10)
test_labels <- to_categorical(test_labels, 10)
#Model
model <- keras_model_sequential()
model %>%
  layer_dense(units=128, activation="relu", input_shape=c(784)) %>%
  layer_dropout(rate=0.3) %>%
  layer_dense(units=64, activation="relu") %>%
  layer_dropout(rate=0.2) %>%
  layer_dense(units=10, activation="softmax")
#Compile
model %>% compile(loss="categorical_crossentropy",
                  optimizer="rmsprop",
                  metrics="accuracy")
#Train
history <- model %>% fit(train_data,
                         train_labels,
                         epochs=10,
                         batch_size=784,
                         validation_split=0.2,
                         verbose=2)
#Evaluation and prediction
model %>% evaluate(test_data, test_labels)
pred <- model %>% predict(test_data)
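#predict() returns class probabilities and test_labels is one-hot encoded, so convert both to class labels (0-9) before tabulating
pred_class <- apply(pred, 1, which.max) - 1
actual_class <- apply(test_labels, 1, which.max) - 1
print(table(Predicted=pred_class, Actual=actual_class))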
When I run this in RStudio, the following error occurs:
ValueError: No gradients provided for any variable: (['dense_124/kernel:0', 'dense_124/bias:0', 'dense_123/kernel:0', 'dense_123/bias:0', 'dense_122/kernel:0', 'dense_122/bias:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'dense_124/kernel:0' shape=(784, 128) dtype=float32>), (None, <tf.Variable 'dense_124/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'dense_123/kernel:0' shape=(128, 64) dtype=float32>), (None, <tf.Variable 'dense_123/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'dense_122/kernel:0' shape=(64, 10) dtype=float32>), (None, <tf.Variable 'dense_122/bias:0' shape=(10,) dtype=float32>)).
I think the problem may be a mismatch between the shape of the input data and the shape the model's input layer expects, but I have no idea how to solve this.
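To make the shapes I have in mind concrete, here is a small base-R sketch that uses random stand-in data rather than the real MNIST arrays. The names `x`, `y`, and the helper `one_hot` are made up for illustration. It only mimics what I understand `array_reshape` and `to_categorical` to produce: a 784-column input matrix and a 10-column label matrix with the same number of rows.

```r
# Stand-in for a few MNIST images: 5 images of 28 x 28 pixels (random values, not real data)
set.seed(1)
x <- array(sample(0:255, 5 * 28 * 28, replace = TRUE), dim = c(5, 28, 28))
y <- c(3L, 0L, 9L, 1L, 3L)  # integer class labels in 0..9

# Flatten each image row by row into 784 values and scale to [0, 1].
# apply() returns a 784 x 5 matrix, so transpose to get one image per row.
x_flat <- t(apply(x, 1, function(img) as.vector(t(img)))) / 255

# Hypothetical helper: turn integer labels into a one-hot matrix
one_hot <- function(labels, num_classes) {
  m <- matrix(0, nrow = length(labels), ncol = num_classes)
  m[cbind(seq_along(labels), labels + 1L)] <- 1  # labels are 0-based, R indices are 1-based
  m
}
y_onehot <- one_hot(y, 10)

print(dim(x_flat))    # 5 rows, 784 columns
print(dim(y_onehot))  # 5 rows, 10 columns

# Inputs and labels must have the same number of rows
stopifnot(nrow(x_flat) == nrow(y_onehot))
```

If the real `train_data` and `train_labels` do not show the same pattern when passed to `dim()` just before `fit()`, that would point to the preprocessing rather than the model definition.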
Thanks for any help!