Calculate rolling weighted mean in PySpark with the rangeBetween function


I want to calculate a rolling weighted average similar to what the pandas df.ewm function does with span=15, but over 15 days rather than 15 rows (df.ewm weights the last rows, not the dates going back from the current row). There is no direct function for this in PySpark. I have looked at some solutions online, but most use rowsBetween and do not normalize for missing dates in between. Is there a way to do this in PySpark? My current code, shown after the short pandas sketch below, uses rowsBetween and only works when there are no missing dates in between. I also want to skip weekends (Saturday and Sunday) when calculating the weighted average.
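
For reference, this is roughly the pandas behaviour I am starting from (a minimal sketch with made-up data; by default ewm decays per row and ignores gaps in the dates):

import pandas as pd

pdf = pd.DataFrame(
    {'value': [10, 20, 10, 40, 20]},
    index=pd.to_datetime(['2024-02-20', '2024-02-21', '2024-02-22',
                          '2024-02-23', '2024-02-24']))
# ewm(span=15) uses alpha = 2 / (15 + 1) and weights row by row,
# so missing dates are not taken into account
pdf['ewm'] = pdf['value'].ewm(span=15).mean()

And my current PySpark attempt: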

from pyspark.sql.functions import udf, col, collect_list
from pyspark.sql.types import FloatType
from pyspark.sql.window import Window

def ewa(columnPartitionValues):
  # Reverse so the most recent value comes first and gets the largest weight
  columnPartitionValues = columnPartitionValues[::-1]
  alpha = 2 / (5 + 1)  # decay factor for a span of 5 rows
  weighted_sum = 0.0
  total_weight = 0.0
  for i, v in enumerate(columnPartitionValues):
    weighted_sum += ((1 - alpha) ** i) * v
    total_weight += (1 - alpha) ** i
  return weighted_sum / total_weight  # normalized weighted average

ewa_udf = udf(ewa, FloatType())

data = [(1, '2024-02-20', 10),
        (1, '2024-02-21', 20),
        (1, '2024-02-22', 10),
        (1, '2024-02-23', 40),
        (1, '2024-02-24', 20),
        (1, '2024-02-25', 60),
        (1, '2024-02-26', 70),
        (1, '2024-02-27', 80),
        (2, '2024-02-28', 90),
        (2, '2024-02-29', 100),
        (2, '2024-03-01', 110),
        (2, '2024-03-02', 120),
        (2, '2024-03-03', 130),
        (2, '2024-03-04', 140)]

columns = ["id", "date", "value"]

# Create DataFrame
df = spark.createDataFrame(data, columns)

df = df.withColumn("date", col("date").cast("date"))

# Row-based window: current row plus the previous 4 rows per id
wind = Window.partitionBy('id').orderBy('date').rowsBetween(-4, 0)
df = df.withColumn('col_list', collect_list('value').over(wind))

display(df)
df1 = df.withColumn('ewma', ewa_udf('col_list'))
display(df1)

Here I have only used a window of the current row plus the previous 4 rows as an example. Is there a way to do this with rangeBetween so that missing dates are taken into account, except for Saturdays and Sundays?
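
To make it concrete, this is the kind of date-based window I have in mind (just a sketch; the day-number ordering and the 15-day range are my assumptions, and it still does not skip weekends or weight by the actual day gap, which is the part I am unsure about):

from pyspark.sql.functions import datediff, lit, to_date

# Order by a day number so rangeBetween counts days instead of rows
day_num = datediff(col('date'), to_date(lit('1970-01-01')))
wind_days = Window.partitionBy('id').orderBy(day_num).rangeBetween(-14, 0)

# Collect the values of the last 15 calendar days (current day included)
df_days = df.withColumn('col_list', collect_list('value').over(wind_days))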
