Apache Spark RandomForestClassifier Predict label for single user input

21 Views Asked by At

I have following dataset, where only title will be features and only category_id (already int) will be label. Ignore category_text for now.

category_id,title,category_text
12321332,"drill bit","drilling"
23432212,"class plug","electrical tools"
34567789,"laptop","computers"

I'm able to train it as follows.

But prediction where I can't make it work. I cannot use tainandtest variable as I'm not just testing the model, I wanted to use it in production like scenario where I send user input like drill bit to model and expecting it to return category_id: 12321332

Problem1: When I try to build my own vector for the input drill bit and try in predict I'm getting Index 9 out of bounds [0, 1)

Problem2: The Predict method returning double, I can't find a good documentation for Java to predict label as string, appreciate any insight on it.

Fully runnable code

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class SearchML {


    public static void main(String...s) {
        SparkSession spark = SparkSession.builder()
                .master("local")
                .appName("RandomForestClassifierExample")
                .getOrCreate();
        StructType schema = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("category_id", DataTypes.IntegerType, false),
                DataTypes.createStructField("title", DataTypes.StringType, false),
                DataTypes.createStructField("category_text", DataTypes.StringType, false)
        });

        var dataFrame = spark.read().format("csv")
        .option("header", "true")
        .option("delimiter", ",")
        .option("mode", "DROPMALFORMED")
        .option("quote", "\"")
        .schema(schema)
        .load("csv_entry_slim.csv")
        .cache();

        dataFrame = dataFrame.groupBy("title", "category_id").agg(functions.collect_list("title").alias("titleArray"));

        CountVectorizer cv = new CountVectorizer()
                .setInputCol("titleArray")
                .setOutputCol("features");
        cv.setMaxDF(1);
        cv.setVocabSize(5000);
        dataFrame = cv.fit(dataFrame).transform(dataFrame);

        var indexer = new StringIndexer()
              .setInputCol("category_id")
              .setOutputCol("label");
        dataFrame = indexer.setHandleInvalid("skip").fit(dataFrame).transform(dataFrame);

        var seed = 5043;
        var tainAndTest = dataFrame.randomSplit(new double[]{0.99999, 0.00001}, seed);

        // train Random Forest model with training data set
        var randomForestClassifier = new RandomForestClassifier()
          .setImpurity("gini")
          .setMaxDepth(3)
          .setNumTrees(20)
          .setFeatureSubsetStrategy("auto")
          .setMaxBins(10)
          .setSeed(seed);
        var randomForestModel = randomForestClassifier.fit(tainAndTest[0]);


        // PREDICTION WHERE I CAN'T MAKE IT WORK. I CANNOT USE tainAndTest AS I'M NOT JUST TESTING THE MODEL, I WANTED TO USE IT IN PRODUCTION LIKE SCENARIO WHERE I SEND USER INPUT LIKE drill bit TO MODEL AND EXPECTING IT TO RETURN CATEGORY_ID: 12321332
        List<String> stringAsList = new ArrayList<>();
        stringAsList.add("drill bit");

        StructType schemaTest = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("title", DataTypes.StringType, false)
        });

        JavaSparkContext sparkContext = new JavaSparkContext(spark.sparkContext());

        JavaRDD<org.apache.spark.sql.Row> rowRDD = sparkContext.parallelize(stringAsList).map((String row) -> RowFactory.create(row));
        Dataset<org.apache.spark.sql.Row> userInputDataFrame = spark.sqlContext().createDataFrame(rowRDD, schemaTest).toDF();

        userInputDataFrame = userInputDataFrame.groupBy("title").agg(functions.collect_list("title").alias("titleArray"));
        userInputDataFrame = cv.fit(userInputDataFrame).transform(userInputDataFrame);
        System.out.println(userInputDataFrame);

        Vector v = (Vector) scala.collection.JavaConverters.seqAsJavaList(userInputDataFrame.collectAsList().get(0).toSeq()).get(2);
        var testRes = randomForestModel.predict(v);

        System.out.println(testRes);
    }
}
0

There are 0 best solutions below