Pandera Python: Use DataFrameModel for validating multiindex columns

I'm trying to validate a dataframe with MultiIndex columns using DataFrameModel.


import pandas as pd
import pandera as pa
from pandera import Column, DataFrameSchema
from pandera.typing import Index, Series

Test dataframe:

df = pd.DataFrame(
    {
        ("here", "state"): ["NY", "FL", "GA", "CA"],
        ("here", "city"): ["New York", "Miami", "Atlanta", "San Francisco"],
        ("here", "price"): [8, 12, 10, 16],
    }
)

When I use:

zone_validator = DataFrameSchema(
    {
        ("here", "state"): Column(str),
        ("here", "price"): Column(int),
    }
)

zone_validator.validate(df)
print(df)

Everything works. But when I follow the documentation (https://pandera.readthedocs.io/en/stable/dataframe_models.html#multiindex), I get an error:

class MultiIndexSchema(pa.DataFrameModel):
    state: Series[str]
    city: Series[str]
    price: Series[int]

    class Config:
        # provide multi index options in the config
        multiindex_name = "here"
        multiindex_strict = True
        multiindex_coerce = True


schema = MultiIndexSchema
schema.validate(df)

This raises:

raise schema_error from original_exc
pandera.errors.SchemaError: column 'state' not in dataframe

For reference, print(df) shows:

   here                     
  state           city price
0    NY       New York     8
1    FL          Miami    12
2    GA        Atlanta    10
3    CA  San Francisco    16

What am I doing wrong?

Answer by Arigion:

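You are using the MultiIndex options meant for the row index to validate MultiIndex columns. The multiindex_* settings in Config describe an index built from several Index fields; they do not apply to column labels. For MultiIndex columns, give each field an alias equal to the full column tuple, like this: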

import pandas as pd
import pandera as pa
from pandera.typing import Series

df = pd.DataFrame(
    {
        ("here", "state"): ["NY", "FL", "GA", "CA"],
        ("here", "city"): ["New York", "Miami", "Atlanta", "San Francisco"],
        ("here", "price"): [8, 12, 10, 16],
    }
)


class MultiIndexSchema(pa.DataFrameModel):
    here_state: Series[str] = pa.Field(alias=("here", "state"))
    here_city: Series[str] = pa.Field(alias=("here", "city"))
    here_price: Series[int] = pa.Field(alias=("here", "price"))


schema = MultiIndexSchema
schema.validate(df)