Polars groupby map UDF using multiple columns as parameter


I have a numba UDF:

from typing import Union

import numba
import numpy as np


@numba.jit(nopython=True)
def generate_sample_numba(cumulative_dollar_volume: np.ndarray, dollar_tau: Union[int, np.ndarray]) -> np.ndarray:
    """Assign a bar index to every row, using numba for speed.

    A bar is closed by the first row whose cumulative dollar volume has grown
    by at least dollar_tau since the previous bar was closed.
    """
    covered_dollar_volume = 0
    bar_index = 0
    bar_index_array = np.zeros_like(cumulative_dollar_volume, dtype=np.uint32)

    # Turn a scalar threshold into one threshold per row.
    if isinstance(dollar_tau, int):
        dollar_tau = np.array([dollar_tau] * len(cumulative_dollar_volume))

    for i in range(len(cumulative_dollar_volume)):
        bar_index_array[i] = bar_index
        if cumulative_dollar_volume[i] >= covered_dollar_volume + dollar_tau[i]:
            bar_index += 1
            covered_dollar_volume = cumulative_dollar_volume[i]
    return bar_index_array

The UDF takes two inputs:

  1. cumulative_dollar_volume: a numpy array holding the values of a single group (one ticker on one date) from the group_by.
  2. dollar_tau: the threshold, which is either an integer or a numpy array.
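To see what the loop computes for each of the two input configurations, here is a minimal sketch that reproduces the same logic in plain NumPy, without numba, so it can be run and checked anywhere. The name `generate_sample_reference` is hypothetical, and broadcasting a scalar threshold with `np.broadcast_to` is a substitution for the list-based conversion in the original UDF, not the author's code.

```python
import numpy as np


def generate_sample_reference(cumulative_dollar_volume, dollar_tau):
    """Pure-NumPy version of the bar-index loop (no JIT), for checking results."""
    cum = np.asarray(cumulative_dollar_volume)
    # Accept either a scalar threshold or one threshold per row.
    tau = np.broadcast_to(np.asarray(dollar_tau), cum.shape)
    bar_index_array = np.zeros(cum.shape, dtype=np.uint32)
    covered = 0
    bar_index = 0
    for i in range(len(cum)):
        # The row that crosses the threshold still belongs to the current bar...
        bar_index_array[i] = bar_index
        if cum[i] >= covered + tau[i]:
            # ...and the following row starts a new one.
            bar_index += 1
            covered = cum[i]
    return bar_index_array


cum = np.array([3, 6, 9, 12, 15])
print(generate_sample_reference(cum, 5))                          # scalar threshold
print(generate_sample_reference(cum, np.array([5, 5, 2, 5, 5])))  # per-row thresholds
```

With a constant threshold of 5 the rows are assigned bars [0, 0, 1, 1, 2]; lowering the third row's threshold to 2 closes a bar one row earlier and gives [0, 0, 1, 2, 2].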

In this question, I am particularly interested in the numpy array configuration. This post explains the idea behind the generate_sample_numba function well.

I want to reproduce the following pandas result with Polars:

data["bar_index"] = data.groupby(["ticker", "date"]).apply(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].values, x["dollar_tau"].values)).explode().values.astype(int)
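The pandas line applies the UDF to each (ticker, date) group, flattens the per-group arrays with explode(), and writes the result back by position through .values. That positional assignment appears to line up with the original rows only if data is already sorted by ticker and date, since groupby emits groups in sorted key order by default. The NumPy sketch below illustrates the same apply-and-concatenate idea on contiguous groups; the helper `apply_per_contiguous_group` is hypothetical, and the two keys are collapsed into one string per row for simplicity.

```python
import numpy as np


def apply_per_contiguous_group(keys, values, func):
    """Apply func to each run of equal keys and concatenate the results.

    The output lines up with the input rows because groups are processed in
    order of appearance and each group is assumed to be one contiguous run.
    """
    keys = np.asarray(keys)
    values = np.asarray(values)
    # Index of the first row of every group except the first one.
    starts = np.flatnonzero(keys[1:] != keys[:-1]) + 1
    return np.concatenate([func(chunk) for chunk in np.split(values, starts)])


# Hypothetical (ticker, date) keys collapsed into a single string per row.
keys = ["AAPL|d1", "AAPL|d1", "AAPL|d1", "MSFT|d1", "MSFT|d1"]
dollar_volume = [1, 2, 3, 10, 20]
print(apply_per_contiguous_group(keys, dollar_volume, np.cumsum))
```

Using np.cumsum as the per-group function mirrors the cum_sum().over(["ticker", "date"]) step used to build cumulative_dollar_volume: each group restarts its running total.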

Apparently, the best option in Polars is group_by().agg(pl.col().map_batches()):

cqt_sample = cqt_sample.with_columns(
    (pl.col("price") * pl.col("size")).alias("dollar_volume")).with_columns(
    pl.col("dollar_volume").cum_sum().over(["ticker", "date"]).alias("cumulative_dollar_volume"),
    pl.lit(1_000_000).alias("dollar_tau")
    )

(cqt_sample
    .group_by(["ticker", "date"])
    .agg(pl.col(["cumulative_dollar_volume", "dollar_tau"])
         .map_batches(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].to_numpy(), 1_000_000))
                      )#.alias("bar_index")
                      )#.explode("bar_index")

but map_batches() seems to produce some strange results.

However, when I use an integer dollar_tau with a single input column, it works fine:

(cqt_sample
    .group_by(["ticker", "date"])
    .agg(pl.col("cumulative_dollar_volume")
         .map_batches(lambda x: generate_sample_numba(x.to_numpy(), 1_000_000))
         .alias("bar_index"))
    .explode("bar_index"))

Answer by Hericks (accepted):

As suggested in the comments, you'll need to call pl.Expr.map_batches on a struct column that contains all information needed by the function. Inside the function, you then pick the struct apart to obtain the desired information.

(
    data
    .group_by(["ticker", "date"])
    .agg(
        pl.struct("cumulative_dollar_volume", "dollar_tau")
        .map_batches(
            lambda x: generate_sample_numba(
                x.struct.field("cumulative_dollar_volume").to_numpy(),
                dollar_tau=x.struct.field("dollar_tau").to_numpy(),
            )
        )
        .alias("bar_index")
    )
    .explode("bar_index")
)

Answer by Kevin Li:

The original approach applies map_batches within group_by, which generates a new, aggregated DataFrame. If you still need access to the other columns of the original DataFrame, a window function (.over()) is a much better choice than merging the newly generated column(s) back into the original one:

udf_expression = (
    pl.struct(["cumulative_volume", self.volume_tau])
    .map_batches(lambda x: self.generate_sample_numba(x.struct.field("cumulative_volume").to_numpy(), x.struct.field(self.volume_tau).to_numpy()))
    .over([self.identifier_col, self.date_col])
    .alias("bar_index")
    )

data = (data
        .filter(self.cqt_filter)
        .with_columns(
            pl.col(self.size_col).cum_sum().over([self.identifier_col, self.date_col]).alias("cumulative_volume")
            )
        .with_columns(udf_expression))
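The practical difference between the two answers is row alignment: group_by().agg() returns one row per group (hence the explode()), while a window expression with .over() returns one value per original row, in the original order. The sketch below is a rough NumPy analog of that window behavior; `apply_over` is a hypothetical helper for illustration, not a Polars API.

```python
import numpy as np


def apply_over(keys, values, func):
    """Apply func per group and scatter the results back to the original rows.

    Unlike an aggregation (one row per group), the output has one value per
    input row, in input order, even when a group's rows are not contiguous.
    """
    keys = np.asarray(keys)
    values = np.asarray(values)
    out = np.empty(len(values), dtype=np.float64)
    for key in np.unique(keys):
        mask = keys == key  # this group's rows, kept in their original order
        out[mask] = func(values[mask])
    return out


keys = ["AAPL", "MSFT", "AAPL", "MSFT"]
size = [1, 10, 2, 20]
print(apply_over(keys, size, np.cumsum))
```

Here the AAPL and MSFT rows are interleaved, yet each output value still sits on its own input row, which is why the window version needs no join back to the original frame.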