Problem
I'd like to compute some statistics from request data, grouped by values in the top layer and values in a nested layer. The main problem is that with explode-join and a triple groupby, the code is too slow on big data (~100 GB).
Sample data:
import pyspark.sql.types as T

rows = [
    {"id": 1, "typeId": 1, "items": [
        {"itemType": 1, "flag": False, "event": None},
        {"itemType": 3, "flag": True, "event": [{"info1": ""}, {"info1": ""}]},
        {"itemType": 3, "flag": True, "event": [{"info1": ""}, {"info1": ""}]},
    ]},
    {"id": 2, "typeId": 2, "items": None},
    {"id": 3, "typeId": 1, "items": [
        {"itemType": 1, "flag": False, "event": None},
        {"itemType": 6, "flag": False, "event": [{"info1": ""}]},
        {"itemType": 6, "flag": False, "event": None},
    ]},
    {"id": 4, "typeId": 2, "items": [
        {"itemType": 1, "flag": True, "event": [{"info1": ""}]},
    ]},
    {"id": 5, "typeId": 3, "items": None},
]

schema = T.StructType([
    T.StructField("id", T.IntegerType(), False),
    T.StructField("typeId", T.IntegerType()),
    T.StructField("items", T.ArrayType(T.StructType([
        T.StructField("itemType", T.IntegerType()),
        T.StructField("flag", T.BooleanType()),
        T.StructField("event", T.ArrayType(T.StructType([
            T.StructField("info1", T.StringType()),
        ]))),
    ])), True),
])

df = spark.createDataFrame(rows, schema)
df.printSchema()
df.printSchema()
The resulting structure:
root
|-- id: integer (nullable = false)
|-- typeId: integer (nullable = true)
|-- items: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- itemType: integer (nullable = true)
| | |-- flag: boolean (nullable = true)
| | |-- event: array (nullable = true)
| | | |-- element: struct (containsNull = true)
| | | | |-- info1: string (nullable = true)
I'd like to perform these calculations for each typeId and each items.itemType:
- total of rows (requests)
- total of rows (requests) that contain at least one item
- total of rows (requests) that contain at least one item with items.flag == True
- total of items
- total of flagged items (items.flag == True)
- total of events on items (sum(size("items.event")))
Code
Getting the total number of requests for every typeId is simple; in the real analysis, layer1_groups contains more category columns:
import pyspark.sql.functions as F
layer1_groups = ["typeId"]
# get count for groups in top layer
totaldf = df.groupby(layer1_groups).agg(F.count(F.lit(1)).alias("requests"))
For later computations (e.g. ratios against numbers computed on the nested group), join these totals back to the original dataframe:
df = df.join(totaldf, layer1_groups)
Explode items to allow grouping by the nested items.itemType:
exploded_df = df.withColumn("I", F.explode_outer("items")).select("*","I.*").drop("items","I")
# add the per-item event count (greatest clamps size()'s -1 for NULL arrays to 0)
exploded_df = exploded_df.withColumn("eSize", F.greatest(F.size("event"), F.lit(0)))
Group per request first (groupby "id"), because in later computations I want to count a request only if it contains flagged items, etc.:
layer2_groups = ["itemType"]
each_requests = exploded_df.groupby(["id", *layer1_groups, *layer2_groups]).agg(
    F.first("requests").alias("requests"),
    F.count(F.lit(1)).alias("ItemCount"),
    F.sum(F.col("flag").cast(T.ByteType())).alias("fItemCount"),
    F.sum("eSize").alias("eCount"),
)
The final grouping drops the "id" column:
# final results: aggregate without the per-request "id" grouping
requests_results = each_requests.groupby([*layer1_groups, *layer2_groups]).agg(
    F.first("requests").alias("requests"),
    F.count_if(F.col("ItemCount") > 0).alias("requestsWithItems"),
    F.count_if(F.col("fItemCount") > 0).alias("requestsWith_fItems"),
    F.sum("ItemCount").alias("ItemCount"),
    F.sum("fItemCount").alias("fItemCount"),
    F.sum("eCount").alias("eCount"),
)
requests_results.show()
The result is:
+------+--------+--------+-----------------+-------------------+---------+----------+------+
|typeId|itemType|requests|requestsWithItems|requestsWith_fItems|ItemCount|fItemCount|eCount|
+------+--------+--------+-----------------+-------------------+---------+----------+------+
| 1| 1| 2| 2| 0| 2| 0| 0|
| 1| 3| 2| 1| 1| 2| 2| 4|
| 1| 6| 2| 1| 0| 2| 0| 1|
| 2| 1| 2| 1| 1| 1| 1| 1|
| 2| NULL| 2| 1| 0| 1| NULL| 0|
| 3| NULL| 1| 1| 0| 1| NULL| 0|
+------+--------+--------+-----------------+-------------------+---------+----------+------+
Whole code
Gist: https://gist.github.com/vanheck/bfcadf7396d765ddd2fff5f544fd7cf2
Question
Is there some way to make this code faster? Or can I avoid the explode function to obtain these stats?
Answer
You do not need explode to get the statistics you need. I tried the approach below on my local machine and it worked. I have pasted my result; please adjust as per your requirements.