How to filter a dataframe by row_id/row_number

I am looking to get a subset of rows from a dataframe based on row_id/row_number, similar to pyarrow.Table.take. For example, given the dataframe below:

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "integer": [1, 2, 3, 4, 5],
        "date": [
            datetime(2022, 1, 1),
            datetime(2022, 1, 2),
            datetime(2022, 1, 3),
            datetime(2022, 1, 4),
            datetime(2022, 1, 5),
        ],
        "float": [4.0, 5.0, 6.0, 7.0, 8.0],
    }
)

print(df)

shape: (5, 3)
┌─────────┬─────────────────────┬───────┐
│ integer ┆ date                ┆ float │
│ ---     ┆ ---                 ┆ ---   │
│ i64     ┆ datetime[μs]        ┆ f64   │
╞═════════╪═════════════════════╪═══════╡
│ 1       ┆ 2022-01-01 00:00:00 ┆ 4.0   │
│ 2       ┆ 2022-01-02 00:00:00 ┆ 5.0   │
│ 3       ┆ 2022-01-03 00:00:00 ┆ 6.0   │
│ 4       ┆ 2022-01-04 00:00:00 ┆ 7.0   │
│ 5       ┆ 2022-01-05 00:00:00 ┆ 8.0   │
└─────────┴─────────────────────┴───────┘

I am looking for a take-like function, e.g. df.take([0, 4]), which gives the dataframe below.

shape: (2, 3)
┌─────────┬─────────────────────┬───────┐
│ integer ┆ date                ┆ float │
│ ---     ┆ ---                 ┆ ---   │
│ i64     ┆ datetime[μs]        ┆ f64   │
╞═════════╪═════════════════════╪═══════╡
│ 1       ┆ 2022-01-01 00:00:00 ┆ 4.0   │
│ 5       ┆ 2022-01-05 00:00:00 ┆ 8.0   │
└─────────┴─────────────────────┴───────┘

The row numbers are produced by another process and handed over to me. I tried df.select(pl.all().take([take_indices])) and noticed that it was slower than running the filter directly, i.e. df.filter(filter_expr). Please note that I am doing this over extremely large datasets (> 100m rows).

Edit: Thanks for the answer. Using df[[take_indices]] worked. However, I am still curious why filter outperforms both the select.gather and the square bracket approaches. Timings on my dataset with 50m rows:

select.gather: 0.5 s
square_bracket: 0.32 s [in line with mozway's timings]
filter: 0.18 s

Answer by mozway (best answer):

df[[0, 4]] selects the rows at indices 0 and 4.

Since take is deprecated, the equivalent of your proposed code uses gather:

df.select(pl.all().gather([0, 4]))

Output:

shape: (2, 3)
┌─────────┬─────────────────────┬───────┐
│ integer ┆ date                ┆ float │
│ ---     ┆ ---                 ┆ ---   │
│ i64     ┆ datetime[μs]        ┆ f64   │
╞═════════╪═════════════════════╪═══════╡
│ 1       ┆ 2022-01-01 00:00:00 ┆ 4.0   │
│ 5       ┆ 2022-01-05 00:00:00 ┆ 8.0   │
└─────────┴─────────────────────┴───────┘

Timing on 500k rows:

# df.select(pl.all().gather([0, 4]))
145 µs ± 9.43 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# df[[0,4]]
122 µs ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Timing on 5M rows:

# df.select(pl.all().gather([0, 4]))
150 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# df[[0,4]]
117 µs ± 17.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)