For each value in one column, find the closest value in another column of the same PySpark dataframe


We have a PySpark dataframe containing rate codes that we have to use to give discounted offers to our customers.


The columns are:

- ratecode - the actual rate code
- weeklyrate - the weekly dollar amount the customer will pay
- area - area of residence
- frequency -
- offer1 - the first discounted offer to the customer
- offer2 - the second discounted offer to the customer

The problem is to find, for each row, the "ratecode" whose "weeklyrate" is closest to "offer1" (saved as "offer1Ratecode") and to "offer2" (saved as "offer2Ratecode").

Explanation:

  1. for "offer1" = 4.4, "offer1Ratecode" is R1, because the closest "weeklyrate" to 4.4 is 5.5, and 5.5 corresponds to "ratecode" R1
  2. for "offer1" = 6.0, "offer1Ratecode" is R2, because the closest "weeklyrate" to 6.0 is 6.2, and 6.2 corresponds to "ratecode" R2
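
The same rule in plain Python, as a minimal sketch (values copied from the sample input used in the answer below), just to illustrate what "closest" means here:

rates = [('R1', 5.5), ('R2', 6.2), ('R3', 7.5), ('R4', 5.6),
         ('R5', 7.3), ('R6', 8.4), ('R7', 9.1), ('R8', 6.8)]

def closest_code(offer):
    # pick the ratecode whose weeklyrate has the smallest absolute distance to the offer
    return min(rates, key=lambda rc: abs(rc[1] - offer))[0]

print(closest_code(4.4))  # R1 (closest weeklyrate is 5.5)
print(closest_code(6.0))  # R2 (closest weeklyrate is 6.2)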

There is 1 best solution below

ZygD (BEST ANSWER)

Input:

df = spark.createDataFrame(
    [('R1', 5.5, 4.4, 3.85),
     ('R2', 6.2, 4.96, 4.34),
     ('R3', 7.5, 6.0, 5.25),
     ('R4', 5.6, 4.48, 3.92),
     ('R5', 7.3, 5.84, 5.11),
     ('R6', 8.4, 6.72, 5.88),
     ('R7', 9.1, 7.28, 6.37),
     ('R8', 6.8, 5.44, 4.76)],
    ['ratecode', 'weeklyrate', 'offer1', 'offer2'])

One way would be using crossJoin and groupBy:

from pyspark.sql import functions as F

def closest(col):
    # For each group, collect (diff, weeklyrate, ratecode) structs, sort them
    # (structs sort field by field, so the smallest diff comes first) and take
    # the ratecode of the first element, i.e. the one closest to the offer.
    return F.array_sort(F.collect_list(F.struct(
        F.abs(F.col(f'b.{col}') - F.col('a.weeklyrate')).alias('diff'),
        'a.weeklyrate',
        'a.ratecode',
    )))[0]['ratecode'].alias(f'{col}Ratecode')

# Pair every row ('b') with every candidate rate ('a'), then group back to one
# row per original record and pick the closest ratecode for each offer.
df = df.select('weeklyrate', 'ratecode').alias('a').crossJoin(df.alias('b'))
df = df.groupBy(*[f'b.{c}' for c in df.select('b.*').columns]).agg(
    closest('offer1'),
    closest('offer2'),
)
df.show()
# +--------+----------+------+------+--------------+--------------+
# |ratecode|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
# +--------+----------+------+------+--------------+--------------+
# |      R3|       7.5|   6.0|  5.25|            R2|            R1|
# |      R2|       6.2|  4.96|  4.34|            R1|            R1|
# |      R1|       5.5|   4.4|  3.85|            R1|            R1|
# |      R4|       5.6|  4.48|  3.92|            R1|            R1|
# |      R7|       9.1|  7.28|  6.37|            R5|            R2|
# |      R6|       8.4|  6.72|  5.88|            R8|            R4|
# |      R8|       6.8|  5.44|  4.76|            R1|            R1|
# |      R5|       7.3|  5.84|  5.11|            R4|            R1|
# +--------+----------+------+------+--------------+--------------+
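
On Spark 3.3+, where F.min_by is available in pyspark.sql.functions, the array sorting can be skipped: after the same crossJoin, min_by picks the ratecode with the smallest absolute difference directly. This is only a sketch in the spirit of the approach above, starting again from the original input df; it is not part of the original answer:

from pyspark.sql import functions as F

joined = df.select('weeklyrate', 'ratecode').alias('a').crossJoin(df.alias('b'))
result = joined.groupBy(*[f'b.{c}' for c in joined.select('b.*').columns]).agg(
    # for each original row ('b'), keep the candidate ratecode ('a') whose
    # weeklyrate is nearest to the offer (requires pyspark >= 3.3)
    F.min_by('a.ratecode', F.abs(F.col('b.offer1') - F.col('a.weeklyrate'))).alias('offer1Ratecode'),
    F.min_by('a.ratecode', F.abs(F.col('b.offer2') - F.col('a.weeklyrate'))).alias('offer2Ratecode'),
)
result.show()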

Another option is to use window functions and transform:

from pyspark.sql import functions as F, Window as W

def closest(col):
    # Collect the whole (weeklyrate, ratecode) list once over an unbounded
    # window, attach each entry's distance to the offer, sort and take the
    # ratecode with the smallest distance.
    return F.array_sort(F.transform(
        F.collect_list(F.struct('weeklyrate', 'ratecode')).over(W.orderBy()),
        lambda x: F.struct(
            F.abs(F.col(col) - x['weeklyrate']).alias('diff'),
            x['weeklyrate'].alias('weeklyrate'),
            x['ratecode'].alias('ratecode'),
        )
    ))[0]['ratecode'].alias(f'{col}Ratecode')

df = df.select('*', closest('offer1'), closest('offer2'))
df.show()
# +--------+----------+------+------+--------------+--------------+
# |ratecode|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
# +--------+----------+------+------+--------------+--------------+
# |      R1|       5.5|   4.4|  3.85|            R1|            R1|
# |      R2|       6.2|  4.96|  4.34|            R1|            R1|
# |      R3|       7.5|   6.0|  5.25|            R2|            R1|
# |      R4|       5.6|  4.48|  3.92|            R1|            R1|
# |      R5|       7.3|  5.84|  5.11|            R4|            R1|
# |      R6|       8.4|  6.72|  5.88|            R8|            R4|
# |      R7|       9.1|  7.28|  6.37|            R5|            R2|
# |      R8|       6.8|  5.44|  4.76|            R1|            R1|
# +--------+----------+------+------+--------------+--------------+
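
A note on the trade-offs: both versions materialize the full rate table for every row (the first through the cross join, the second through collect_list over a window with no partitioning, which pulls all rows into a single partition), so they are fine for a small rate-code table but will not scale to a very large one. Ties are broken by the struct sort order (diff, then weeklyrate, then ratecode), so when two rates are equally close the lower weeklyrate wins.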