I have a small problem with Java and machine learning. I have trained a model with Keras and it works as expected when I use Python to predict images.
The model was trained on inputs of shape [ width, height, RGB ].
But when I load an image in Java I get [ RGB, width, height ], so I tried to use .reshape() to change the shape. I am clearly messing something up there, because all predictions are wrong afterwards:
ResizeImageTransform rit = new ResizeImageTransform(128, 128);
NativeImageLoader loader = new NativeImageLoader(128, 128, 3, rit);
INDArray features = loader.asMatrix(f); // GIVES ME A SHAPE OF 1, 3, 128, 128
features = features.reshape(1, 128, 128, 3); // GIVES ME THE SHAPE 1, 128, 128, 3 AS NEEDED
INDArray[] prediction = model.output(features); // all predictions wrong
I am not a Java developer and I am trying to get along with the documentation, but I am clearly overlooking something here. Maybe someone can give me a tip about what I am doing wrong...
So now at least 136 images of my test set get flagged, whereas the Python version flags 195 images...
So I guess the normalisation is the problem. I train the model with:
And I use
before the prediction in the test script.
In Java I use
but I am not sure whether the normalisation is the issue or whether I have screwed up the parameters for .permute()...
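To check my understanding of reshape versus permute, here is a minimal plain-Java sketch without ND4J. The class name LayoutDemo and the method chwToHwc are made up for illustration, and the 2x2 "image" is toy data. As far as I understand it, a reshape keeps the underlying buffer order and only changes how the indices are interpreted, while a permute actually moves the channel axis. In ND4J I believe this corresponds to features.permute(0, 2, 3, 1) for going from [1, 3, 128, 128] to [1, 128, 128, 3].

```java
import java.util.Arrays;

public class LayoutDemo {

    // Copies a flat channels-first buffer [C, H, W] into channels-last order [H, W, C].
    // This is what a permute of the axes has to achieve; a plain reshape would not
    // move any values at all.
    static float[] chwToHwc(float[] chw, int channels, int height, int width) {
        float[] hwc = new float[chw.length];
        for (int c = 0; c < channels; c++) {
            for (int h = 0; h < height; h++) {
                for (int w = 0; w < width; w++) {
                    int src = (c * height + h) * width + w;   // index in [C, H, W]
                    int dst = (h * width + w) * channels + c; // index in [H, W, C]
                    hwc[dst] = chw[src];
                }
            }
        }
        return hwc;
    }

    public static void main(String[] args) {
        // Toy 2x2 image with 3 channels, stored channels-first:
        // the whole first plane, then the second, then the third.
        float[] chw = {
            1, 2, 3, 4,         // channel 0
            10, 20, 30, 40,     // channel 1
            100, 200, 300, 400  // channel 2
        };

        // "Reshape only": the buffer is unchanged, so reading it as [H, W, C]
        // would treat (1, 2, 3) as the three channel values of the first pixel.
        System.out.println("reshape only: " + Arrays.toString(chw));

        // Permuted: each pixel's channel values are now stored next to each other.
        System.out.println("permuted:     " + Arrays.toString(chwToHwc(chw, 3, 2, 2)));
    }
}
```

If this is right, the reshape in my first attempt mixed values from neighbouring pixels of one channel into a single "pixel", which would explain why every prediction was wrong.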
Any suggestions?