Given: A big Dataset (1 billion+ records) from a DeltaTable on Databricks
I want to split this Dataset into roughly 1,000 partitions, depending on some properties of each record. Then I want to fit a Spark ML Pipeline on each of these partitions.
My first idea was to use Dataset.groupByKey followed by mapGroups. However, mapGroups hands each group to me as an Iterator rather than a Dataset, and Pipeline.fit only accepts a Dataset as input.
dataset.groupByKey(item => (item.propertyOne, item.propertyTwo))(product)
  .mapGroups((key, group) => {
    val pipeline: Pipeline = // ...
    pipeline.fit(group) // Doesn't work since group is of type Iterator, not Dataset
  })(product)
  .foreach(_.save(some_path))
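For context, the mismatch shows up directly in the (simplified) signatures of the two methods involved:

// org.apache.spark.sql.KeyValueGroupedDataset[K, V]
def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]

// org.apache.spark.ml.Pipeline
def fit(dataset: Dataset[_]): PipelineModel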
The alternative I have now is to iterate over the possible property values, filter the original Dataset for each value, and pass the filtered Dataset to Pipeline.fit. This works, but it is painfully slow. I have also played around with multithreading these fits (sketched after the snippet below), but to no avail.
// 1000 possible values
allPossiblePropertyValues.foreach { v =>
  pipeline.fit(dataset.filter(_.property == v)).save(some_path)
}
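By multithreading I mean submitting several of these fit calls as concurrent Spark jobs from the driver, roughly along these lines (a simplified sketch; it assumes allPossiblePropertyValues is a Scala Seq, and basePath is a placeholder for the output location):

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

// Each Future submits its own Spark job; jobs started from different
// driver threads can run concurrently on the cluster.
val futures = allPossiblePropertyValues.map { v =>
  Future {
    val model = pipeline.fit(dataset.filter(_.property == v))
    model.write.overwrite().save(s"$basePath/$v") // basePath: placeholder output location
  }
}
Await.result(Future.sequence(futures), Duration.Inf)

Note that the default global ExecutionContext is sized to the number of driver cores, so the actual concurrency stays fairly low unless a larger thread pool is configured.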