In Polars, how do you multiply a column of floats with a column of lists?

Given an example dataframe where we have column 'b' containing lists, and each list has the same length (so it also could be converted to arrays)

df_test = pl.DataFrame({'a': [1., 2., 3.], 'b': [[2,2,2], [3,3,3], [4,4,4]]})
df_test
shape: (3, 2)
┌─────┬───────────┐
│ a   ┆ b         │
│ --- ┆ ---       │
│ f64 ┆ list[i64] │
╞═════╪═══════════╡
│ 1.0 ┆ [2, 2, 2] │
│ 2.0 ┆ [3, 3, 3] │
│ 3.0 ┆ [4, 4, 4] │
└─────┴───────────┘

How do I end up with

shape: (3, 3)
┌─────┬───────────┬────────────────────┐
│ a   ┆ b         ┆ new                │
│ --- ┆ ---       ┆ ---                │
│ f64 ┆ list[i64] ┆ list[f64]          │
╞═════╪═══════════╪════════════════════╡
│ 1.0 ┆ [2, 2, 2] ┆ [2.0, 2.0, 2.0]    │
│ 2.0 ┆ [3, 3, 3] ┆ [6.0, 6.0, 6.0]    │
│ 3.0 ┆ [4, 4, 4] ┆ [12.0, 12.0, 12.0] │
└─────┴───────────┴────────────────────┘

without using map_rows?

The best way I could think of was to use map_rows, which is similar to apply in pandas. According to the docs it is not the most efficient approach, but it works:

df_temp = df_test.map_rows(lambda x: ([x[0] * i for i in x[1]],))
df_temp.columns = ['new']
df_test = df_test.hstack(df_temp)

2 Answers

Answer by ouroboros1

Edit: adjusted answer to make sure that it works with duplicate values in column 'a'.


Here's one approach:

Data

N.B. Below, 'a' is changed from [1., 2., 3.] to [1., 1., 3.] to illustrate why the temporary column 'idx' is needed for the group_by.

import polars as pl

# changing 'a' from `[1., 2., 3.]` to `[1., 1., 3.]` to exemplify need temp `idx`
df_test = pl.DataFrame({'a': [1., 1., 3.], 
                        'b': [[2,2,2], [3,3,3], [4,4,4]]})
df_test

shape: (3, 2)
┌─────┬───────────┐
│ a   ┆ b         │
│ --- ┆ ---       │
│ f64 ┆ list[i64] │
╞═════╪═══════════╡
│ 1.0 ┆ [2, 2, 2] │
│ 1.0 ┆ [3, 3, 3] │
│ 3.0 ┆ [4, 4, 4] │
└─────┴───────────┘

Code

df_new = (
    df_test.with_columns(idx=pl.arange(0, pl.len()))
           .explode('b')
           .with_columns(new=(pl.col('a') * pl.col('b')))
           .group_by(['idx', 'a'], maintain_order=True)
           .agg(pl.col("b"), pl.col("new"))
           .drop('idx')
)

df_new

shape: (3, 3)
┌─────┬───────────┬────────────────────┐
│ a   ┆ b         ┆ new                │
│ --- ┆ ---       ┆ ---                │
│ f64 ┆ list[i64] ┆ list[f64]          │
╞═════╪═══════════╪════════════════════╡
│ 1.0 ┆ [2, 2, 2] ┆ [2.0, 2.0, 2.0]    │
│ 1.0 ┆ [3, 3, 3] ┆ [3.0, 3.0, 3.0]    │
│ 3.0 ┆ [4, 4, 4] ┆ [12.0, 12.0, 12.0] │
└─────┴───────────┴────────────────────┘

Explanation

  • First, create a column 'idx' (using pl.DataFrame.with_columns, pl.arange, and pl.len) to keep track of each row. I.e., we use this column to differentiate between rows that have the same value in 'a'.
  • Now, use pl.DataFrame.explode to get the list values for 'b' into separate rows.
  • Next, chain pl.DataFrame.with_columns to multiply column 'a' by column 'b', assigning the result to 'new'.
  • Finally, we want to get back the lists: use pl.DataFrame.group_by on columns 'idx' and 'a', adding maintain_order=True to keep the data in the correct order, and apply groupby.agg on columns 'b' and 'new'.
  • Clean up by dropping 'idx' (using pl.DataFrame.drop).
Answer by Hericks

Unfortunately, polars doesn't support referencing named columns within pl.Expr.list.eval. Otherwise, that would've been the go-to solution.

I think the solution by @ouroboros1 is already on the right track: explode the column, perform the operation, and implode again. However, it can be simplified quite a bit, as follows.

(
    df_test
    .with_columns(
        (
            pl.col("b").explode() * pl.col("a")
        )
        .implode().over(pl.int_range(pl.len()))
        .alias("new")
    )
)
shape: (3, 3)
┌─────┬───────────┬────────────────────┐
│ a   ┆ b         ┆ new                │
│ --- ┆ ---       ┆ ---                │
│ f64 ┆ list[i64] ┆ list[f64]          │
╞═════╪═══════════╪════════════════════╡
│ 1.0 ┆ [2, 2, 2] ┆ [2.0, 2.0, 2.0]    │
│ 2.0 ┆ [3, 3, 3] ┆ [6.0, 6.0, 6.0]    │
│ 3.0 ┆ [4, 4, 4] ┆ [12.0, 12.0, 12.0] │
└─────┴───────────┴────────────────────┘

In particular, this avoids explicitly creating and dropping an index column, as well as the pl.DataFrame.group_by().agg() construct, which would become tedious to handle if the DataFrame had more columns.