Left join of 2 exploded polars columns

Consider the following DataFrame:

import polars as pl

df = pl.DataFrame({
    "a": [[1, 2], [3]],
    "b": [
        [{"id": 1, "x": 1}, {"id": 3, "x": 3}],
        [{"id": 3, "x": 4}],
    ],
})

That looks like:

+------+---------------------+
|a     |b                    |
+------+---------------------+
|[1, 2]|[{1,1}, {3,3}]       |
|[3]   |[{3,4}]              |
+------+---------------------+

How can I

  • get one row for each element of the flattened a column, and
  • if the list of dicts in b contains that a element as id,
  • then have the corresponding x value in column b,
  • otherwise have b be null?

Current approach

I .explode both a and b, then .filter on matching ids (which amounts to an INNER JOIN):

df.explode("a").explode("b").filter(
    pl.col("a") == pl.col("b").struct.field('id')
).select(
    pl.col("a"),
    pl.col("b").struct.field("x")
)

Unfortunately, I only get the (expected) inner-join result:

+-+----+
|a|b   |
+-+----+
|1|1   |
|3|4   |
+-+----+

What I am aiming for instead is the full "LEFT JOIN" result:

+-+----+
|a|b   |
+-+----+
|1|1   |
|2|null|
|3|4   |
+-+----+

How can I efficiently get the desired result when the DataFrame is structured like this?

Answer by jqurious (marked as best answer)

As you mentioned, you can also use a left join explicitly:

df_a = df.explode("a")
df_b = df_a.explode("b").unnest("b").filter(pl.col("a") == pl.col("id"))

df_a.select("a").join(df_b, on="a", how="left")
shape: (3, 3)
┌─────┬──────┬──────┐
│ a   ┆ id   ┆ x    │
│ --- ┆ ---  ┆ ---  │
│ i64 ┆ i64  ┆ i64  │
╞═════╪══════╪══════╡
│ 1   ┆ 1    ┆ 1    │
│ 2   ┆ null ┆ null │
│ 3   ┆ 3    ┆ 4    │
└─────┴──────┴──────┘

Regarding your current filtering logic:

you would need to keep a single row for each a value that has no match at all, e.g.:

is_id = pl.col("a") == pl.col("id")

is_id | (is_id.not_().all().over("a") & pl.col("a").is_first_distinct())

But this is quite awkward, and you would still need an extra step to null out the x values of the rows that were kept without a match.

Answer by Hericks

You can do the following.

  1. Explode columns a and b one after the other, which pairs every element of a with every struct in b from the same row.
  2. Unnest column b into columns id and x.
  3. For each group defined by a, compute the corresponding value of x as follows.
    • pl.when(pl.col("a") == pl.col("id")).then("x") yields x where a matches id and null otherwise (there are still several values within each group).
    • pl.Expr.sort sorts these values and, by default, places null values first.
    • Hence, pl.Expr.last selects the non-null value if one exists and null otherwise.
(
    df
    .explode("a").explode("b").unnest("b")
    .group_by("a", maintain_order=True)
    .agg(
        pl.when(pl.col("a") == pl.col("id")).then("x").sort().last()
    )
)

Output:

shape: (3, 2)
┌─────┬──────┐
│ a   ┆ x    │
│ --- ┆ ---  │
│ i64 ┆ i64  │
╞═════╪══════╡
│ 1   ┆ 1    │
│ 2   ┆ null │
│ 3   ┆ 4    │
└─────┴──────┘